diff --git a/.gitattributes b/.gitattributes
index 9ca4c9045e8b6ec32c3ba6519c9b01976ec0d99f..f67edbf743e0c002d40727db1eae7818ae9278f4 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -43,3 +43,5 @@ Metaworld/zarr_path:[[:space:]]data/metaworld_disassemble_expert.zarr/data/point
Metaworld/zarr_path:[[:space:]]/data/haojun/datasets/3d-dp/metaworld_hammer_expert.zarr/data/depth/0.0.0 filter=lfs diff=lfs merge=lfs -text
Metaworld/zarr_path:[[:space:]]data/metaworld_door-lock_expert.zarr/data/point_cloud/16.0.0 filter=lfs diff=lfs merge=lfs -text
Metaworld/zarr_path:[[:space:]]/data/haojun/datasets/3d-dp/metaworld_drawer-open_expert.zarr/data/point_cloud/5.0.0 filter=lfs diff=lfs merge=lfs -text
+Metaworld/zarr_path:[[:space:]]data/metaworld_door-close_expert.zarr/data/point_cloud/12.0.0 filter=lfs diff=lfs merge=lfs -text
+Metaworld/zarr_path:[[:space:]]data/metaworld_door-lock_expert.zarr/data/point_cloud/7.0.0 filter=lfs diff=lfs merge=lfs -text
diff --git a/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_dial.xml b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_dial.xml
new file mode 100644
index 0000000000000000000000000000000000000000..3b2d3dc3aa3e52aca2882f7b213dd3753732e15e
--- /dev/null
+++ b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_dial.xml
@@ -0,0 +1,25 @@
+<!-- [XML markup stripped during extraction; 25-line body of sawyer_dial.xml not recoverable] -->
diff --git a/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_lock.xml b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_lock.xml
new file mode 100644
index 0000000000000000000000000000000000000000..baa86c3eb73294751e4efdf803ab280399cbedf2
--- /dev/null
+++ b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_lock.xml
@@ -0,0 +1,26 @@
+<!-- [XML markup stripped during extraction; 26-line body of sawyer_door_lock.xml not recoverable] -->
diff --git a/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_pull.xml b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_pull.xml
new file mode 100644
index 0000000000000000000000000000000000000000..8c026691e0346e844de4aca455741d075e8e287e
--- /dev/null
+++ b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_pull.xml
@@ -0,0 +1,23 @@
+<!-- [XML markup stripped during extraction; 23-line body of sawyer_door_pull.xml not recoverable] -->
diff --git a/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_faucet.xml b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_faucet.xml
new file mode 100644
index 0000000000000000000000000000000000000000..a4c70d22783957612024bbe4e694618a04d28b47
--- /dev/null
+++ b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_faucet.xml
@@ -0,0 +1,35 @@
+<!-- [XML markup stripped during extraction; 35-line body of sawyer_faucet.xml not recoverable] -->
diff --git a/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_laptop.xml b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_laptop.xml
new file mode 100644
index 0000000000000000000000000000000000000000..4c12990f7bffeaf40b50e45aa90efca67a0ba0db
--- /dev/null
+++ b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_laptop.xml
@@ -0,0 +1,22 @@
+<!-- [XML markup stripped during extraction; 22-line body of sawyer_laptop.xml not recoverable] -->
diff --git a/Metaworld/zarr_path: data/metaworld_door-close_expert.zarr/data/point_cloud/12.0.0 b/Metaworld/zarr_path: data/metaworld_door-close_expert.zarr/data/point_cloud/12.0.0
new file mode 100644
index 0000000000000000000000000000000000000000..a36a25cd9059075eb53d5862b6557fbfca81a929
--- /dev/null
+++ b/Metaworld/zarr_path: data/metaworld_door-close_expert.zarr/data/point_cloud/12.0.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:902da15cc5607a827c4a6cf9b7c396bacd2bd244963aa4e36f400f543b23497b
+size 1231019
diff --git a/Metaworld/zarr_path: data/metaworld_door-lock_expert.zarr/data/point_cloud/7.0.0 b/Metaworld/zarr_path: data/metaworld_door-lock_expert.zarr/data/point_cloud/7.0.0
new file mode 100644
index 0000000000000000000000000000000000000000..145c8b9a6b9a5dbae2337d32a8762931535418d6
--- /dev/null
+++ b/Metaworld/zarr_path: data/metaworld_door-lock_expert.zarr/data/point_cloud/7.0.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:32e76433f75e66cccd7b4db7962d3fc1ae4fe8915409efc427235456347323c4
+size 1213234
diff --git a/dexart-release/assets/sapien/102697/cues.txt b/dexart-release/assets/sapien/102697/cues.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fc91228deb45f4f0076d8ee2a82dbf3490ff8e4c
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/cues.txt
@@ -0,0 +1,5 @@
+link_0 hinge lever button
+link_1 slider pump_lid lid
+link_2 hinge lid lid
+link_3 hinge seat seat
+link_4 static base_body base_body
diff --git a/dexart-release/assets/sapien/102697/meta.json b/dexart-release/assets/sapien/102697/meta.json
new file mode 100644
index 0000000000000000000000000000000000000000..73d99eeda0b2619fb673d5d2a2fac81a7ab6066b
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/meta.json
@@ -0,0 +1 @@
+{"user_id": "haozhu", "model_cat": "Toilet", "model_id": "db252ecd6286a334733badcb2e574996-0", "version": "1", "anno_id": "2697", "time_in_sec": "31"}
\ No newline at end of file
diff --git a/dexart-release/assets/sapien/102697/mobility.urdf b/dexart-release/assets/sapien/102697/mobility.urdf
new file mode 100644
index 0000000000000000000000000000000000000000..66322627aea3fc2cef1f786c1aff0a34a0b14508
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/mobility.urdf
@@ -0,0 +1,502 @@
+<!-- [URDF markup stripped during extraction; 502-line body of mobility.urdf not recoverable] -->
\ No newline at end of file
diff --git a/dexart-release/assets/sapien/102697/mobility_v2.json b/dexart-release/assets/sapien/102697/mobility_v2.json
new file mode 100644
index 0000000000000000000000000000000000000000..425a87c4126d0dc8c85dca23d592be3dae58ba10
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/mobility_v2.json
@@ -0,0 +1 @@
+[{"id":0,"parent":4,"joint":"hinge","name":"lever","parts":[{"id":1,"name":"button"}],"jointData":{"axis":{"origin":[-0.387179130380233,0.5229989412652956,-0.2586584187924479],"direction":[-0.91844856752952,-0.00006467369864908537,0.39554042096893877]},"limit":{"a":0,"b":30,"noLimit":false}}},{"id":1,"parent":4,"joint":"slider","name":"pump_lid","parts":[{"id":2,"name":"lid"}],"jointData":{"axis":{"origin":[0,0,0],"direction":[0,1,0]},"limit":{"a":0,"b":0.02,"noLimit":false,"rotates":false,"noRotationLimit":false,"rotationLimit":0}}},{"id":2,"parent":4,"joint":"hinge","name":"lid","parts":[{"id":3,"name":"lid"}],"jointData":{"axis":{"origin":[0,0.05794601786221419,-0.06205458525801076],"direction":[1,0,0]},"limit":{"a":0,"b":-94,"noLimit":false}}},{"id":3,"parent":4,"joint":"hinge","name":"seat","parts":[{"id":4,"name":"seat"}],"jointData":{"axis":{"origin":[0,0.05794601786221419,-0.06205458525801076],"direction":[1,0,0]},"limit":{"a":0,"b":-94,"noLimit":false}}},{"id":4,"parent":-1,"joint":"static","name":"base_body","parts":[{"id":5,"name":"base_body"}],"jointData":{}}]
\ No newline at end of file
diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_1_12.mtl b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_12.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..0a8d7de97ba5b99f083fcf82402f12b2d42e820b
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_12.mtl
@@ -0,0 +1,12 @@
+# Blender MTL File: 'None'
+# Material Count: 1
+
+newmtl Shape.14067
+Ns 400.000000
+Ka 0.400000 0.400000 0.400000
+Kd 0.336000 0.232000 0.584000
+Ks 0.250000 0.250000 0.250000
+Ke 0.000000 0.000000 0.000000
+Ni 1.000000
+d 0.500000
+illum 2
diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.mtl b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..9b787bc663cb512a8d4bca7d408e8e384c3b8619
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.mtl
@@ -0,0 +1,12 @@
+# Blender MTL File: 'None'
+# Material Count: 1
+
+newmtl Shape.14069
+Ns 400.000000
+Ka 0.400000 0.400000 0.400000
+Kd 0.296000 0.784000 0.192000
+Ks 0.250000 0.250000 0.250000
+Ke 0.000000 0.000000 0.000000
+Ni 1.000000
+d 0.500000
+illum 2
diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.obj b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.obj
new file mode 100644
index 0000000000000000000000000000000000000000..5815a3adc15fdcd79d4d8e4f3f64b8d93d911c3a
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.obj
@@ -0,0 +1,109 @@
+# Blender v2.79 (sub 0) OBJ File: ''
+# www.blender.org
+mtllib 102697_link_1_14.mtl
+o Shape_IndexedFaceSet.014_Shape_IndexedFaceSet.14061
+v 0.381286 0.716318 -0.278277
+v 0.467366 0.669121 -0.458755
+v 0.467366 0.666342 -0.458755
+v 0.314641 0.669121 -0.486500
+v 0.314641 0.666342 -0.275488
+v 0.314641 0.716318 -0.486500
+v 0.442356 0.716318 -0.486520
+v 0.406252 0.666342 -0.275488
+v 0.314641 0.716318 -0.275488
+v 0.453478 0.702433 -0.433757
+v 0.464367 0.666416 -0.486207
+v 0.397925 0.702433 -0.275488
+v 0.386833 0.666342 -0.486500
+v 0.447917 0.713539 -0.453197
+v 0.433456 0.670746 -0.350149
+v 0.411813 0.705207 -0.311600
+v 0.453478 0.710760 -0.486520
+v 0.459025 0.702433 -0.455986
+v 0.397925 0.713539 -0.300506
+v 0.409033 0.674673 -0.283855
+v 0.430634 0.693883 -0.355457
+v 0.461805 0.669121 -0.436525
+v 0.447917 0.710760 -0.439314
+vn 0.6201 0.7564 0.2082
+vn -1.0000 0.0000 0.0000
+vn 0.0000 1.0000 0.0000
+vn -0.0002 0.0000 -1.0000
+vn 0.0000 -1.0000 0.0000
+vn 0.0000 0.0000 1.0000
+vn 0.9941 0.0000 -0.1086
+vn 0.0406 0.2433 0.9691
+vn -0.0385 -0.9992 -0.0132
+vn 0.0010 -1.0000 -0.0028
+vn 0.1781 0.9826 0.0522
+vn -0.0001 -0.0002 -1.0000
+vn 0.9633 0.2356 -0.1285
+vn -0.0000 -0.0004 -1.0000
+vn 0.0038 -0.0061 -1.0000
+vn 0.4470 0.8945 -0.0000
+vn 0.7133 0.6982 0.0608
+vn 0.9470 0.2175 0.2363
+vn 0.9624 0.2498 -0.1067
+vn 0.5705 0.7506 0.3332
+vn 0.2834 0.9545 0.0928
+vn 0.6574 0.6887 0.3057
+vn 0.8546 0.1972 0.4804
+vn 0.9385 0.0321 0.3438
+vn 0.8973 0.2493 0.3642
+vn 0.9155 0.2217 0.3357
+vn 0.8943 0.3344 0.2974
+vn 0.9251 0.1885 0.3296
+vn 0.9350 0.1837 0.3034
+vn 0.9701 0.0000 0.2427
+vn 0.8018 -0.5344 0.2674
+vn 0.9470 0.2170 0.2369
+vn 0.8672 -0.4031 0.2922
+vn 0.9325 0.2086 0.2948
+vn 0.7291 0.6431 0.2341
+vn 0.7174 0.6831 0.1368
+vn 0.7541 0.6292 0.1882
+vn 0.5142 0.8410 0.1683
+usemtl Shape.14069
+s off
+f 19//1 16//1 23//1
+f 4//2 5//2 6//2
+f 6//3 1//3 7//3
+f 4//4 6//4 7//4
+f 5//5 3//5 8//5
+f 5//6 8//6 9//6
+f 1//3 6//3 9//3
+f 6//2 5//2 9//2
+f 2//7 3//7 11//7
+f 9//6 8//6 12//6
+f 1//8 9//8 12//8
+f 5//9 4//9 13//9
+f 3//5 5//5 13//5
+f 11//10 3//10 13//10
+f 7//11 1//11 14//11
+f 4//12 7//12 17//12
+f 2//13 11//13 17//13
+f 13//14 4//14 17//14
+f 11//15 13//15 17//15
+f 7//16 14//16 17//16
+f 17//17 14//17 18//17
+f 10//18 2//18 18//18
+f 2//19 17//19 18//19
+f 1//20 12//20 19//20
+f 14//21 1//21 19//21
+f 12//22 16//22 19//22
+f 12//23 8//23 20//23
+f 8//24 15//24 20//24
+f 16//25 12//25 20//25
+f 16//26 20//26 21//26
+f 10//27 16//27 21//27
+f 20//28 15//28 21//28
+f 21//29 15//29 22//29
+f 3//30 2//30 22//30
+f 8//31 3//31 22//31
+f 2//32 10//32 22//32
+f 15//33 8//33 22//33
+f 10//34 21//34 22//34
+f 16//35 10//35 23//35
+f 18//36 14//36 23//36
+f 10//37 18//37 23//37
+f 14//38 19//38 23//38
diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_1_5.mtl b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_5.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..1d543567c45ae9dae1065722296e1ef27ad60a92
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_5.mtl
@@ -0,0 +1,12 @@
+# Blender MTL File: 'None'
+# Material Count: 1
+
+newmtl Shape.14060
+Ns 400.000000
+Ka 0.400000 0.400000 0.400000
+Kd 0.576000 0.288000 0.088000
+Ks 0.250000 0.250000 0.250000
+Ke 0.000000 0.000000 0.000000
+Ni 1.000000
+d 0.500000
+illum 2
diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_3_0.obj b/dexart-release/assets/sapien/102697/new_objs/102697_link_3_0.obj
new file mode 100644
index 0000000000000000000000000000000000000000..02ce7be22a24a10eeb20da679572050bce75f8d6
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_3_0.obj
@@ -0,0 +1,217 @@
+# Blender v2.79 (sub 0) OBJ File: ''
+# www.blender.org
+mtllib 102697_link_3_0.mtl
+o Shape_IndexedFaceSet_Shape_IndexedFaceSet.16309
+v -0.318776 0.060655 0.107693
+v -0.063870 0.047615 -0.076335
+v -0.006959 0.042434 0.076198
+v -0.006959 0.088979 -0.019441
+v -0.309458 0.088979 0.114981
+v -0.182789 0.086387 -0.042715
+v -0.257765 0.042434 0.006414
+v -0.312052 0.042434 0.114981
+v -0.167256 0.073461 0.114981
+v -0.006959 0.042434 -0.068570
+v -0.006959 0.073461 0.060669
+v -0.006959 0.088979 -0.065989
+v -0.262922 0.086387 0.008996
+v -0.151785 0.042434 0.114981
+v -0.154809 0.063235 -0.062488
+v -0.255171 0.088979 0.114981
+v -0.265516 0.065707 -0.001350
+v -0.058683 0.083799 -0.073734
+v -0.155045 0.046362 -0.055734
+v -0.006959 0.055369 0.078780
+v -0.317239 0.052781 0.091707
+v -0.309458 0.083799 0.086544
+v -0.069027 0.088979 -0.065989
+v -0.006959 0.070873 -0.078916
+v -0.159536 0.083799 -0.058243
+v -0.107812 0.042434 -0.063407
+v -0.257765 0.045027 -0.001350
+v -0.159811 0.051020 -0.059069
+v -0.306895 0.088979 0.096890
+v -0.149191 0.060536 0.114981
+v -0.056120 0.078632 -0.076335
+v -0.314646 0.045027 0.096890
+v -0.006959 0.047615 -0.076335
+v -0.319802 0.076049 0.114981
+v -0.260328 0.078632 -0.001350
+v -0.312052 0.068285 0.081361
+v -0.250915 0.087266 0.009661
+v -0.006959 0.070873 0.065852
+v -0.304301 0.050193 0.068434
+v -0.268110 0.070873 0.003833
+v -0.006959 0.060536 0.076198
+v -0.006959 0.086387 -0.071152
+v -0.006959 0.078632 0.034833
+v -0.167107 0.053534 -0.056233
+v -0.162264 0.065707 -0.059140
+v -0.317239 0.045027 0.114981
+v -0.198290 0.081220 -0.037551
+v -0.247420 0.042434 -0.001350
+vn -0.2189 -0.8734 -0.4350
+vn 0.0000 -1.0000 0.0000
+vn 0.0000 0.0000 1.0000
+vn 1.0000 0.0000 0.0000
+vn -0.0000 1.0000 0.0000
+vn 0.2540 -0.1893 0.9485
+vn -0.0785 0.9576 -0.2772
+vn -0.0984 0.7614 -0.6407
+vn -0.0223 -0.8995 -0.4364
+vn -0.1804 -0.1959 -0.9639
+vn -0.1633 -0.5893 -0.7912
+vn -0.1721 -0.6838 -0.7091
+vn -0.3065 -0.7412 -0.5972
+vn -0.5117 0.8123 -0.2799
+vn 0.2435 0.3404 0.9082
+vn 0.2453 -0.0352 0.9688
+vn -0.1444 0.0361 -0.9889
+vn -0.0504 0.0126 -0.9986
+vn -0.1621 0.1636 -0.9731
+vn -0.1399 0.3890 -0.9106
+vn -0.2162 -0.9704 -0.1081
+vn -0.4800 -0.8321 -0.2779
+vn 0.0000 -0.8318 -0.5550
+vn 0.0000 -0.1103 -0.9939
+vn -0.9962 -0.0274 -0.0823
+vn -0.7761 0.6209 -0.1100
+vn -0.7791 0.6162 -0.1155
+vn -0.3952 0.6848 -0.6123
+vn -0.9554 0.1496 -0.2548
+vn -0.9305 0.2462 -0.2713
+vn -0.0669 0.9924 -0.1037
+vn -0.0356 0.9974 -0.0631
+vn -0.0693 0.9955 -0.0640
+vn -0.0240 0.9991 -0.0350
+vn 0.1545 0.8754 0.4580
+vn 0.1519 0.8843 0.4415
+vn -0.8035 -0.3012 -0.5135
+vn -0.7750 -0.5093 -0.3742
+vn -0.6037 -0.7165 -0.3495
+vn -0.8716 -0.0235 -0.4897
+vn -0.8750 -0.0297 -0.4832
+vn -0.7861 0.4152 -0.4579
+vn -0.7546 0.4201 -0.5041
+vn -0.7121 0.2858 -0.6413
+vn -0.8694 0.0560 -0.4909
+vn -0.8355 0.2946 -0.4637
+vn 0.2448 0.4334 0.8673
+vn 0.2223 0.6897 0.6891
+vn 0.0000 0.8937 -0.4487
+vn -0.0128 0.8231 -0.5677
+vn 0.0237 0.4474 -0.8940
+vn 0.0214 0.4580 -0.8887
+vn 0.1009 0.9773 0.1863
+vn 0.1037 0.9753 0.1952
+vn -0.4960 -0.1859 -0.8482
+vn -0.3894 -0.0969 -0.9159
+vn -0.4480 -0.3961 -0.8015
+vn -0.3790 0.1027 -0.9197
+vn -0.4224 -0.0481 -0.9051
+vn -0.4884 -0.0141 -0.8725
+vn -0.9926 -0.1156 -0.0385
+vn -0.4462 -0.8926 -0.0640
+vn -0.9107 -0.3918 -0.1305
+vn -0.9960 -0.0823 0.0336
+vn -0.4229 0.5449 -0.7240
+vn -0.4253 0.5896 -0.6867
+vn -0.5000 0.2007 -0.8425
+vn -0.4868 0.0799 -0.8698
+vn -0.4738 0.1147 -0.8731
+vn -0.1247 -0.9517 -0.2806
+vn -0.2313 -0.9228 -0.3082
+usemtl Shape.16319
+s off
+f 27//1 19//1 48//1
+f 7//2 3//2 8//2
+f 5//3 8//3 9//3
+f 4//4 3//4 10//4
+f 3//2 7//2 10//2
+f 3//4 4//4 11//4
+f 5//5 4//5 12//5
+f 4//4 10//4 12//4
+f 8//2 3//2 14//2
+f 9//3 8//3 14//3
+f 4//5 5//5 16//5
+f 5//3 9//3 16//3
+f 3//4 11//4 20//4
+f 14//6 3//6 20//6
+f 5//5 12//5 23//5
+f 12//4 10//4 24//4
+f 6//7 23//7 25//7
+f 23//8 18//8 25//8
+f 2//9 10//9 26//9
+f 10//2 7//2 26//2
+f 15//10 2//10 28//10
+f 2//11 26//11 28//11
+f 26//12 19//12 28//12
+f 19//13 27//13 28//13
+f 13//14 22//14 29//14
+f 5//5 23//5 29//5
+f 9//3 14//3 30//3
+f 20//15 9//15 30//15
+f 14//16 20//16 30//16
+f 2//17 15//17 31//17
+f 24//18 2//18 31//18
+f 15//19 25//19 31//19
+f 25//20 18//20 31//20
+f 7//21 8//21 32//21
+f 27//22 7//22 32//22
+f 10//23 2//23 33//23
+f 2//24 24//24 33//24
+f 24//4 10//4 33//4
+f 8//3 5//3 34//3
+f 21//25 1//25 34//25
+f 5//26 29//26 34//26
+f 29//27 22//27 34//27
+f 13//28 6//28 35//28
+f 21//29 34//29 36//29
+f 34//30 22//30 36//30
+f 6//31 13//31 37//31
+f 23//32 6//32 37//32
+f 13//33 29//33 37//33
+f 29//34 23//34 37//34
+f 16//35 9//35 38//35
+f 11//36 16//36 38//36
+f 20//4 11//4 38//4
+f 17//37 27//37 39//37
+f 32//38 21//38 39//38
+f 27//39 32//39 39//39
+f 36//40 17//40 39//40
+f 21//41 36//41 39//41
+f 22//42 13//42 40//42
+f 13//43 35//43 40//43
+f 35//44 17//44 40//44
+f 17//45 36//45 40//45
+f 36//46 22//46 40//46
+f 9//47 20//47 41//47
+f 38//48 9//48 41//48
+f 20//4 38//4 41//4
+f 23//49 12//49 42//49
+f 18//50 23//50 42//50
+f 12//4 24//4 42//4
+f 24//51 31//51 42//51
+f 31//52 18//52 42//52
+f 11//4 4//4 43//4
+f 4//53 16//53 43//53
+f 16//54 11//54 43//54
+f 27//55 17//55 44//55
+f 15//56 28//56 44//56
+f 28//57 27//57 44//57
+f 25//58 15//58 45//58
+f 15//59 44//59 45//59
+f 44//60 17//60 45//60
+f 1//61 21//61 46//61
+f 32//62 8//62 46//62
+f 21//63 32//63 46//63
+f 34//64 1//64 46//64
+f 8//3 34//3 46//3
+f 6//65 25//65 47//65
+f 35//66 6//66 47//66
+f 17//67 35//67 47//67
+f 45//68 17//68 47//68
+f 25//69 45//69 47//69
+f 26//2 7//2 48//2
+f 19//70 26//70 48//70
+f 7//71 27//71 48//71
diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_3_5.obj b/dexart-release/assets/sapien/102697/new_objs/102697_link_3_5.obj
new file mode 100644
index 0000000000000000000000000000000000000000..074a59b319d2140a08c5e2802a1d71a0ca212920
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_3_5.obj
@@ -0,0 +1,250 @@
+# Blender v2.79 (sub 0) OBJ File: ''
+# www.blender.org
+mtllib 102697_link_3_5.mtl
+o Shape_IndexedFaceSet.005_Shape_IndexedFaceSet.16314
+v -0.006959 0.055369 0.663225
+v -0.035413 0.042434 0.789889
+v -0.006985 0.088979 0.766613
+v -0.224145 0.088979 0.575274
+v -0.112985 0.042434 0.575296
+v -0.244834 0.042434 0.624424
+v -0.216380 0.083799 0.676129
+v -0.128515 0.070873 0.575274
+v -0.154338 0.047615 0.743337
+v -0.061288 0.088979 0.779539
+v -0.006959 0.042434 0.792466
+v -0.006985 0.073461 0.683903
+v -0.006959 0.042434 0.668378
+v -0.260364 0.088979 0.590820
+v -0.030254 0.068285 0.800239
+v -0.221565 0.065707 0.678706
+v -0.268103 0.042434 0.575274
+v -0.146599 0.081220 0.745913
+v -0.081951 0.055369 0.784714
+v -0.006985 0.088979 0.789889
+v -0.270683 0.060536 0.598550
+v -0.102641 0.042434 0.764037
+v -0.206061 0.055369 0.696785
+v -0.150775 0.086767 0.728688
+v -0.087137 0.076049 0.779539
+v -0.154338 0.065707 0.745913
+v -0.262944 0.081220 0.603725
+v -0.273262 0.081220 0.575274
+v -0.262944 0.047615 0.606323
+v -0.032834 0.047615 0.797640
+v -0.006985 0.068294 0.800239
+v -0.244834 0.088979 0.619249
+v -0.030254 0.083799 0.795064
+v -0.006985 0.070878 0.676129
+v -0.087137 0.050193 0.782138
+v -0.006985 0.081215 0.722659
+v -0.213800 0.078632 0.683881
+v -0.242254 0.052781 0.645102
+v -0.180186 0.081220 0.575274
+v -0.143993 0.083799 0.745913
+v -0.079372 0.083799 0.779539
+v -0.209848 0.046828 0.679494
+v -0.262944 0.088979 0.575274
+v -0.009565 0.047615 0.660627
+v -0.092296 0.068285 0.779539
+v -0.159523 0.070873 0.740739
+v -0.275868 0.068285 0.583047
+v -0.273262 0.045027 0.575274
+v -0.032834 0.055369 0.800239
+v -0.125910 0.068285 0.575274
+v -0.006985 0.063123 0.668378
+v -0.074186 0.045027 0.782138
+v -0.204443 0.044694 0.679584
+vn -0.3528 -0.8802 0.3176
+vn 0.0000 -1.0000 -0.0000
+vn -0.0000 1.0000 -0.0000
+vn 1.0000 0.0000 0.0000
+vn 0.0000 0.0000 -1.0000
+vn 1.0000 0.0008 0.0000
+vn 1.0000 0.0006 0.0001
+vn -0.6739 -0.1041 0.7314
+vn -0.8426 0.1891 0.5042
+vn -0.7059 0.6605 0.2560
+vn -0.0513 -0.8223 0.5667
+vn 1.0000 0.0006 0.0013
+vn -0.1903 0.9645 0.1830
+vn -0.0704 0.9943 0.0806
+vn -0.5935 0.7366 0.3242
+vn -0.6160 0.6947 0.3714
+vn -0.0901 0.8767 0.4726
+vn 0.0988 0.4453 0.8899
+vn -0.0001 0.3164 0.9486
+vn 1.0000 0.0023 -0.0008
+vn 0.2531 0.9181 -0.3050
+vn -0.3097 -0.7486 0.5862
+vn -0.4927 -0.1227 0.8615
+vn 1.0000 0.0017 -0.0003
+vn 0.1515 0.9734 -0.1719
+vn 1.0000 0.0019 -0.0004
+vn 0.1659 0.9670 -0.1935
+vn -0.7073 0.1481 0.6912
+vn -0.7933 0.3506 0.4977
+vn -0.8130 0.2852 0.5077
+vn -0.8528 0.0077 0.5221
+vn -0.7975 -0.2017 0.5686
+vn -0.6081 -0.6770 0.4146
+vn -0.8486 -0.2184 0.4819
+vn 0.1702 0.9644 -0.2025
+vn 0.1908 0.9528 -0.2362
+vn -0.2498 0.9330 0.2591
+vn -0.1520 0.9623 0.2257
+vn -0.5596 0.5915 0.5805
+vn -0.5675 0.5734 0.5909
+vn -0.2916 0.2921 0.9108
+vn -0.4189 0.4197 0.8052
+vn -0.2076 0.7248 0.6569
+vn -0.2872 0.3031 0.9086
+vn -0.2434 0.8497 0.4677
+vn -0.4183 0.4227 0.8039
+vn -0.5264 -0.7109 0.4664
+vn -0.5764 -0.6997 0.4220
+vn -0.6036 -0.6544 0.4554
+vn -0.5981 0.7953 0.0993
+vn 0.6344 0.0453 -0.7717
+vn 0.8968 -0.1637 -0.4110
+vn 0.5171 -0.6211 -0.5890
+vn -0.3140 0.1256 0.9411
+vn -0.3097 0.2058 0.9283
+vn -0.4499 0.2990 0.8416
+vn -0.4707 0.2348 0.8505
+vn -0.4460 0.0014 0.8950
+vn -0.4762 -0.0095 0.8793
+vn -0.5346 0.2667 0.8020
+vn -0.6914 0.0290 0.7219
+vn -0.6165 0.4454 0.6493
+vn -0.7048 0.1501 0.6933
+vn -0.8834 0.2281 0.4094
+vn -0.8746 0.3668 0.3172
+vn -0.9481 0.0000 -0.3179
+vn -0.4393 -0.8740 0.2080
+vn -0.4713 -0.8521 0.2276
+vn -0.8850 -0.3363 0.3221
+vn -0.9562 -0.1834 0.2281
+vn -0.3008 0.0601 0.9518
+vn 0.1250 -0.3153 0.9407
+vn 0.1424 -0.2848 0.9480
+vn 0.0000 0.0000 1.0000
+vn -0.2970 -0.1701 0.9396
+vn -0.2748 -0.3056 0.9117
+vn 0.5887 0.2937 -0.7531
+vn 0.0001 -0.0008 -1.0000
+vn 0.9999 0.0100 -0.0100
+vn 0.5062 0.6097 -0.6100
+vn 0.5596 0.4600 -0.6894
+vn 0.5341 0.5376 -0.6524
+vn -0.1298 -0.9323 0.3377
+vn -0.1701 -0.7922 0.5861
+vn -0.3013 -0.7554 0.5819
+vn -0.2439 -0.6115 0.7527
+vn -0.1402 -0.9798 0.1428
+vn -0.1678 -0.9699 0.1763
+vn -0.3550 -0.8867 0.2963
+usemtl Shape.16324
+s off
+f 9//1 42//1 53//1
+f 5//2 2//2 6//2
+f 3//3 4//3 10//3
+f 2//2 5//2 11//2
+f 1//4 11//4 13//4
+f 11//2 5//2 13//2
+f 10//3 4//3 14//3
+f 5//2 6//2 17//2
+f 4//5 8//5 17//5
+f 1//6 3//6 20//6
+f 3//3 10//3 20//3
+f 11//7 1//7 20//7
+f 6//2 2//2 22//2
+f 23//8 9//8 26//8
+f 21//9 16//9 27//9
+f 4//5 17//5 28//5
+f 27//10 14//10 28//10
+f 2//11 11//11 30//11
+f 11//12 20//12 31//12
+f 10//3 14//3 32//3
+f 7//13 24//13 32//13
+f 24//14 10//14 32//14
+f 14//15 27//15 32//15
+f 27//16 7//16 32//16
+f 20//17 10//17 33//17
+f 31//18 20//18 33//18
+f 15//19 31//19 33//19
+f 12//20 1//20 34//20
+f 8//21 12//21 34//21
+f 9//22 22//22 35//22
+f 26//23 9//23 35//23
+f 3//24 1//24 36//24
+f 4//25 3//25 36//25
+f 1//26 12//26 36//26
+f 12//27 4//27 36//27
+f 16//28 23//28 37//28
+f 7//29 27//29 37//29
+f 27//30 16//30 37//30
+f 16//31 21//31 38//31
+f 23//32 16//32 38//32
+f 29//33 6//33 38//33
+f 21//34 29//34 38//34
+f 8//5 4//5 39//5
+f 4//35 12//35 39//35
+f 12//36 8//36 39//36
+f 24//37 7//37 40//37
+f 10//38 24//38 40//38
+f 7//39 37//39 40//39
+f 37//40 18//40 40//40
+f 25//41 15//41 41//41
+f 18//42 25//42 41//42
+f 33//43 10//43 41//43
+f 15//44 33//44 41//44
+f 10//45 40//45 41//45
+f 40//46 18//46 41//46
+f 9//47 23//47 42//47
+f 38//48 6//48 42//48
+f 23//49 38//49 42//49
+f 14//3 4//3 43//3
+f 4//5 28//5 43//5
+f 28//50 14//50 43//50
+f 5//51 1//51 44//51
+f 1//52 13//52 44//52
+f 13//53 5//53 44//53
+f 19//54 15//54 45//54
+f 15//55 25//55 45//55
+f 25//56 18//56 45//56
+f 18//57 26//57 45//57
+f 35//58 19//58 45//58
+f 26//59 35//59 45//59
+f 26//60 18//60 46//60
+f 23//61 26//61 46//61
+f 18//62 37//62 46//62
+f 37//63 23//63 46//63
+f 21//64 27//64 47//64
+f 27//65 28//65 47//65
+f 47//66 28//66 48//66
+f 17//67 6//67 48//67
+f 28//5 17//5 48//5
+f 6//68 29//68 48//68
+f 29//69 21//69 48//69
+f 21//70 47//70 48//70
+f 15//71 19//71 49//71
+f 30//72 11//72 49//72
+f 11//73 31//73 49//73
+f 31//74 15//74 49//74
+f 19//75 35//75 49//75
+f 35//76 30//76 49//76
+f 1//77 5//77 50//77
+f 5//78 17//78 50//78
+f 17//5 8//5 50//5
+f 34//79 1//79 51//79
+f 8//80 34//80 51//80
+f 1//81 50//81 51//81
+f 50//82 8//82 51//82
+f 22//83 2//83 52//83
+f 2//84 30//84 52//84
+f 35//85 22//85 52//85
+f 30//86 35//86 52//86
+f 6//87 22//87 53//87
+f 22//88 9//88 53//88
+f 42//89 6//89 53//89
diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_3_6.obj b/dexart-release/assets/sapien/102697/new_objs/102697_link_3_6.obj
new file mode 100644
index 0000000000000000000000000000000000000000..cf34a50e2465905761098e362a44b1b53d414349
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_3_6.obj
@@ -0,0 +1,169 @@
+# Blender v2.79 (sub 0) OBJ File: ''
+# www.blender.org
+mtllib 102697_link_3_6.mtl
+o Shape_IndexedFaceSet.006_Shape_IndexedFaceSet.16315
+v 0.267139 0.042434 0.557173
+v 0.174056 0.055369 0.316720
+v 0.220597 0.088979 0.557173
+v 0.117181 0.047615 0.557173
+v 0.321419 0.088979 0.316720
+v 0.176650 0.042434 0.316720
+v 0.192149 0.073461 0.316720
+v 0.272305 0.086387 0.557173
+v 0.324013 0.042434 0.316720
+v 0.132681 0.070873 0.549382
+v 0.311086 0.052781 0.466641
+v 0.272305 0.088979 0.316720
+v 0.124942 0.042434 0.557173
+v 0.308492 0.081220 0.461463
+v 0.137869 0.073461 0.557173
+v 0.282638 0.068285 0.549382
+v 0.298159 0.088979 0.469231
+v 0.334346 0.068290 0.319286
+v 0.308492 0.045027 0.458873
+v 0.124942 0.065707 0.554560
+v 0.334346 0.052781 0.316697
+v 0.184389 0.070873 0.321875
+v 0.285232 0.050193 0.539048
+v 0.324013 0.083799 0.358079
+v 0.267139 0.088979 0.557173
+v 0.137869 0.073461 0.551971
+v 0.119775 0.057957 0.554560
+v 0.174056 0.045027 0.316720
+v 0.316253 0.042434 0.383948
+v 0.331774 0.047610 0.321875
+v 0.329180 0.083803 0.316697
+v 0.272305 0.045027 0.557173
+v 0.220597 0.088979 0.551971
+v 0.119775 0.045027 0.551971
+v 0.313680 0.088979 0.386514
+v 0.282638 0.042434 0.520946
+v 0.179222 0.063128 0.316720
+vn -0.0002 0.0002 -1.0000
+vn 0.0000 0.0000 1.0000
+vn 0.0000 -1.0000 -0.0000
+vn 0.0000 1.0000 0.0000
+vn 0.8767 0.3663 0.3117
+vn 0.9441 0.1404 0.2983
+vn 0.6868 0.6920 0.2223
+vn 0.9787 0.1197 0.1671
+vn -0.3480 0.2786 0.8951
+vn -0.5230 0.8497 0.0660
+vn 0.0000 -0.0022 -1.0000
+vn 0.9879 -0.0256 0.1532
+vn -0.6125 0.7781 -0.1392
+vn 0.9409 -0.0559 0.3340
+vn 0.8012 -0.5355 0.2670
+vn 0.9537 0.2610 0.1497
+vn 0.4429 0.8828 0.1562
+vn -0.1899 0.9808 -0.0438
+vn -0.1844 0.9829 0.0000
+vn -0.4464 0.8948 0.0000
+vn -0.3652 0.9271 -0.0843
+vn -0.4071 0.9087 -0.0925
+vn -0.9578 0.1845 -0.2206
+vn -0.4906 0.3271 0.8076
+vn -0.9731 0.0000 -0.2302
+vn -0.0001 -0.0001 -1.0000
+vn 0.8869 -0.4394 0.1424
+vn 0.6665 -0.6663 -0.3344
+vn 0.9361 -0.3202 0.1452
+vn 0.5256 -0.8486 0.0607
+vn 0.6247 -0.7755 0.0915
+vn 0.0000 0.0044 -1.0000
+vn -0.0003 0.0014 -1.0000
+vn 0.7031 0.1171 -0.7014
+vn 0.9363 0.3313 0.1169
+vn -0.0001 0.0000 -1.0000
+vn 0.6020 0.0000 0.7985
+vn 0.8242 -0.1871 0.5345
+vn -0.1842 0.9821 -0.0405
+vn -0.5501 -0.8240 0.1356
+vn -0.3801 -0.9213 -0.0817
+vn -0.8628 -0.4647 -0.1991
+vn -0.6977 -0.6980 -0.1610
+vn 0.6532 0.7472 0.1226
+vn 0.6915 0.7120 0.1216
+vn 0.5539 0.8303 0.0614
+vn 0.6054 0.7923 0.0757
+vn 0.6261 -0.7451 0.2297
+vn 0.2367 -0.9698 0.0581
+vn 0.4406 -0.8777 0.1885
+vn 0.6228 -0.7475 0.2311
+vn -0.5476 0.6851 -0.4804
+vn -0.7588 0.6260 -0.1800
+vn -0.8168 0.5439 -0.1923
+vn -0.8165 0.5444 -0.1922
+vn -0.0002 0.0001 -1.0000
+usemtl Shape.16325
+s off
+f 7//1 31//1 37//1
+f 1//2 3//2 4//2
+f 3//2 1//2 8//2
+f 1//3 6//3 9//3
+f 3//4 5//4 12//4
+f 1//2 4//2 13//2
+f 6//3 1//3 13//3
+f 4//2 3//2 15//2
+f 14//5 8//5 16//5
+f 11//6 14//6 16//6
+f 5//4 3//4 17//4
+f 8//7 14//7 17//7
+f 14//8 11//8 18//8
+f 4//9 15//9 20//9
+f 15//10 10//10 20//10
+f 9//11 6//11 21//11
+f 18//12 11//12 21//12
+f 20//13 10//13 22//13
+f 11//14 16//14 23//14
+f 19//15 11//15 23//15
+f 14//16 18//16 24//16
+f 3//2 8//2 25//2
+f 17//4 3//4 25//4
+f 8//17 17//17 25//17
+f 12//18 7//18 26//18
+f 15//19 3//19 26//19
+f 10//20 15//20 26//20
+f 7//21 22//21 26//21
+f 22//22 10//22 26//22
+f 2//23 4//23 27//23
+f 4//24 20//24 27//24
+f 4//25 2//25 28//25
+f 21//26 6//26 28//26
+f 1//3 9//3 29//3
+f 11//27 19//27 30//27
+f 9//28 21//28 30//28
+f 21//29 11//29 30//29
+f 29//30 9//30 30//30
+f 19//31 29//31 30//31
+f 12//32 5//32 31//32
+f 7//33 12//33 31//33
+f 18//34 21//34 31//34
+f 24//35 18//35 31//35
+f 28//36 2//36 31//36
+f 21//36 28//36 31//36
+f 8//2 1//2 32//2
+f 16//37 8//37 32//37
+f 23//38 16//38 32//38
+f 3//4 12//4 33//4
+f 26//19 3//19 33//19
+f 12//39 26//39 33//39
+f 13//40 4//40 34//40
+f 6//41 13//41 34//41
+f 4//42 28//42 34//42
+f 28//43 6//43 34//43
+f 5//4 17//4 35//4
+f 17//44 14//44 35//44
+f 14//45 24//45 35//45
+f 31//46 5//46 35//46
+f 24//47 31//47 35//47
+f 19//48 23//48 36//48
+f 1//3 29//3 36//3
+f 29//49 19//49 36//49
+f 32//50 1//50 36//50
+f 23//51 32//51 36//51
+f 22//52 7//52 37//52
+f 20//53 22//53 37//53
+f 2//54 27//54 37//54
+f 27//55 20//55 37//55
+f 31//56 2//56 37//56
diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_4_11.obj b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_11.obj
new file mode 100644
index 0000000000000000000000000000000000000000..240ae654bc988106c6fe7646f940f1bd7dfc5537
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_11.obj
@@ -0,0 +1,113 @@
+# Blender v2.79 (sub 0) OBJ File: ''
+# www.blender.org
+mtllib 102697_link_4_11.mtl
+o Shape_IndexedFaceSet.011_Shape_IndexedFaceSet.7497
+v -0.323117 -0.045453 0.089563
+v -0.106232 0.036714 0.031262
+v -0.106232 -0.014976 0.031262
+v -0.207356 0.064630 -0.188472
+v -0.342601 0.058882 0.201139
+v -0.106232 -0.022362 -0.168161
+v -0.269912 -0.053693 0.229603
+v -0.099529 0.061840 -0.178611
+v -0.268728 0.058882 0.223327
+v -0.217054 -0.014976 -0.168161
+v -0.114215 -0.048395 -0.086531
+v -0.354324 -0.050570 0.218569
+v -0.106232 -0.051916 0.009150
+v -0.349976 0.051485 0.178989
+v -0.239154 0.021942 0.215931
+v -0.224454 0.036714 -0.168161
+v -0.170368 -0.047738 -0.078248
+v -0.331607 -0.019798 0.104076
+v -0.246554 -0.022362 0.223327
+v -0.202304 -0.022362 -0.168161
+v -0.126343 0.050010 0.012102
+v -0.246554 0.044099 0.223327
+v -0.224454 -0.000205 -0.168161
+vn -0.9305 0.0000 -0.3663
+vn 0.9995 0.0000 0.0319
+vn 0.0870 -0.1295 -0.9878
+vn -0.0040 0.9999 0.0134
+vn 0.0243 0.9996 0.0176
+vn -0.2829 0.1803 0.9420
+vn -0.1273 0.0565 0.9902
+vn 0.9995 -0.0157 0.0262
+vn 0.9966 -0.0810 -0.0135
+vn 0.8380 -0.5383 -0.0897
+vn -0.4515 0.8806 -0.1437
+vn -0.9523 0.1448 0.2687
+vn 0.8116 0.0000 0.5842
+vn -0.8922 0.3024 -0.3355
+vn -0.0793 -0.9951 -0.0587
+vn -0.0330 -0.9990 -0.0300
+vn -0.0269 -0.9992 -0.0280
+vn -0.0169 -0.9992 -0.0354
+vn -0.9542 -0.1811 -0.2380
+vn -0.9782 -0.0375 -0.2042
+vn -0.9321 0.1193 -0.3421
+vn 0.7252 -0.4335 0.5349
+vn 0.7686 -0.3286 0.5489
+vn 0.8076 -0.0366 0.5886
+vn -0.0000 -0.2274 -0.9738
+vn -0.4300 -0.8588 -0.2785
+vn -0.1163 -0.2322 -0.9657
+vn 0.0000 -0.9527 -0.3038
+vn -0.2235 -0.9559 -0.1904
+vn -0.0489 -0.9657 -0.2552
+vn 0.4655 0.8769 0.1198
+vn 0.3830 0.8970 0.2205
+vn 0.1907 0.9778 0.0875
+vn 0.5197 0.7795 0.3497
+vn 0.0368 0.0552 0.9978
+vn 0.8066 0.0736 0.5865
+vn 0.2595 0.0000 0.9657
+vn 0.7069 0.0000 0.7073
+vn -0.8241 -0.4128 -0.3879
+vn -0.3724 -0.1866 -0.9091
+vn -0.7650 0.0000 -0.6440
+vn -0.9238 -0.0961 -0.3705
+usemtl Shape.7503
+s off
+f 18//1 16//1 23//1
+f 2//2 3//2 8//2
+f 6//3 4//3 8//3
+f 4//4 5//4 9//4
+f 8//5 4//5 9//5
+f 9//6 5//6 12//6
+f 7//7 9//7 12//7
+f 8//8 3//8 13//8
+f 6//9 8//9 13//9
+f 11//10 6//10 13//10
+f 5//11 4//11 14//11
+f 12//12 5//12 14//12
+f 3//13 2//13 15//13
+f 14//14 4//14 16//14
+f 12//15 1//15 17//15
+f 7//16 12//16 17//16
+f 13//17 7//17 17//17
+f 11//18 13//18 17//18
+f 1//19 12//19 18//19
+f 12//20 14//20 18//20
+f 14//21 16//21 18//21
+f 7//22 13//22 19//22
+f 13//23 3//23 19//23
+f 3//24 15//24 19//24
+f 4//25 6//25 20//25
+f 1//26 10//26 20//26
+f 10//27 4//27 20//27
+f 6//28 11//28 20//28
+f 17//29 1//29 20//29
+f 11//30 17//30 20//30
+f 2//31 8//31 21//31
+f 9//32 2//32 21//32
+f 8//33 9//33 21//33
+f 2//34 9//34 22//34
+f 9//35 7//35 22//35
+f 15//36 2//36 22//36
+f 7//37 19//37 22//37
+f 19//38 15//38 22//38
+f 10//39 1//39 23//39
+f 4//40 10//40 23//40
+f 16//41 4//41 23//41
+f 1//42 18//42 23//42
diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_4_19.obj b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_19.obj
new file mode 100644
index 0000000000000000000000000000000000000000..59dbb889a91f621e683a90a6ac304eb46adb8789
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_19.obj
@@ -0,0 +1,105 @@
+# Blender v2.79 (sub 0) OBJ File: ''
+# www.blender.org
+mtllib 102697_link_4_19.mtl
+o Shape_IndexedFaceSet.019_Shape_IndexedFaceSet.7505
+v 0.418781 0.627924 -0.508767
+v 0.428328 0.622978 -0.508489
+v -0.056475 0.145332 -0.517066
+v -0.076686 0.620168 -0.567029
+v 0.366468 0.620168 -0.567029
+v -0.076686 0.125411 -0.567029
+v -0.054501 0.599577 -0.517235
+v 0.320187 0.115689 -0.504922
+v 0.322183 0.125411 -0.567029
+v 0.418075 0.269311 -0.508853
+v 0.411683 0.362817 -0.537486
+v -0.067659 0.626136 -0.530364
+v 0.388636 0.635034 -0.507943
+v 0.389069 0.133686 -0.508902
+v 0.351689 0.214056 -0.567029
+v 0.314606 0.110688 -0.506236
+v 0.413343 0.608120 -0.537486
+v 0.390810 0.209582 -0.537486
+v 0.432230 0.376647 -0.508872
+v 0.373857 0.605402 -0.567029
+v 0.359079 0.620168 -0.507943
+v -0.067415 0.625895 -0.551722
+vn 0.0041 0.9694 -0.2455
+vn 0.4154 0.7751 -0.4762
+vn 0.0000 0.0000 -1.0000
+vn -0.0322 0.0005 0.9995
+vn -0.9254 -0.0111 0.3788
+vn -0.9710 -0.0000 0.2391
+vn -0.7030 0.0033 0.7112
+vn -0.0527 0.4214 0.9053
+vn 0.0729 0.1956 0.9780
+vn 0.2100 0.9266 -0.3119
+vn 0.0149 0.0038 0.9999
+vn 0.0612 -0.0134 0.9980
+vn -0.0967 -0.9105 0.4021
+vn -0.0476 -0.2038 0.9778
+vn 0.0000 -0.9719 -0.2354
+vn 0.2833 -0.9396 -0.1922
+vn 0.1599 -0.4139 0.8961
+vn 0.5278 0.6137 -0.5872
+vn 0.8262 -0.1125 -0.5520
+vn 0.6491 -0.2811 -0.7069
+vn 0.8763 -0.1873 -0.4438
+vn 0.5947 -0.0810 -0.7998
+vn 0.5773 -0.1922 -0.7936
+vn 0.0375 -0.0010 0.9993
+vn 0.0502 -0.0064 0.9987
+vn 0.8317 -0.1098 -0.5442
+vn 0.8852 0.0147 -0.4650
+vn 0.8135 -0.0055 -0.5815
+vn 0.4969 -0.0281 -0.8673
+vn 0.5624 0.2814 -0.7775
+vn 0.5992 -0.0041 -0.8006
+vn -0.0228 0.0077 0.9997
+vn -0.0249 0.0495 0.9985
+vn -0.0031 0.0062 1.0000
+vn 0.0000 0.9366 -0.3504
+vn -0.5068 0.8619 -0.0155
+vn -0.0189 0.9998 -0.0115
+usemtl Shape.7511
+s off
+f 13//1 5//1 22//1
+f 1//2 2//2 5//2
+f 4//3 5//3 6//3
+f 7//4 3//4 8//4
+f 6//3 5//3 9//3
+f 6//5 3//5 12//5
+f 4//6 6//6 12//6
+f 3//7 7//7 12//7
+f 12//8 7//8 13//8
+f 2//9 1//9 13//9
+f 1//10 5//10 13//10
+f 8//11 2//11 13//11
+f 10//12 8//12 14//12
+f 9//3 5//3 15//3
+f 3//13 6//13 16//13
+f 8//14 3//14 16//14
+f 6//15 9//15 16//15
+f 9//16 14//16 16//16
+f 14//17 8//17 16//17
+f 5//18 2//18 17//18
+f 11//19 10//19 18//19
+f 14//20 9//20 18//20
+f 10//21 14//21 18//21
+f 15//22 11//22 18//22
+f 9//23 15//23 18//23
+f 2//24 8//24 19//24
+f 8//25 10//25 19//25
+f 10//26 11//26 19//26
+f 17//27 2//27 19//27
+f 11//28 17//28 19//28
+f 11//29 15//29 20//29
+f 15//3 5//3 20//3
+f 5//30 17//30 20//30
+f 17//31 11//31 20//31
+f 7//32 8//32 21//32
+f 13//33 7//33 21//33
+f 8//34 13//34 21//34
+f 5//35 4//35 22//35
+f 4//36 12//36 22//36
+f 12//37 13//37 22//37
diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_4_3.mtl b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_3.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..4813824b1d555853c436a16aa53832e48462157e
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_3.mtl
@@ -0,0 +1,12 @@
+# Blender MTL File: 'None'
+# Material Count: 1
+
+newmtl Shape.7495
+Ns 400.000000
+Ka 0.400000 0.400000 0.400000
+Kd 0.168000 0.496000 0.216000
+Ks 0.250000 0.250000 0.250000
+Ke 0.000000 0.000000 0.000000
+Ni 1.000000
+d 0.500000
+illum 2
diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_4_33.mtl b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_33.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..eae730d995bafc1fc4d40dae64eacd7c06c8a02f
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_33.mtl
@@ -0,0 +1,12 @@
+# Blender MTL File: 'None'
+# Material Count: 1
+
+newmtl Shape.7525
+Ns 400.000000
+Ka 0.400000 0.400000 0.400000
+Kd 0.272000 0.624000 0.536000
+Ks 0.250000 0.250000 0.250000
+Ke 0.000000 0.000000 0.000000
+Ni 1.000000
+d 0.500000
+illum 2
diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_4_4.mtl b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_4.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..396e5bcf8d40c1142017984f4aad04c4a616b409
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_4.mtl
@@ -0,0 +1,12 @@
+# Blender MTL File: 'None'
+# Material Count: 1
+
+newmtl Shape.7496
+Ns 400.000000
+Ka 0.400000 0.400000 0.400000
+Kd 0.720000 0.472000 0.504000
+Ks 0.250000 0.250000 0.250000
+Ke 0.000000 0.000000 0.000000
+Ni 1.000000
+d 0.500000
+illum 2
diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_4_8.mtl b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_8.mtl
new file mode 100644
index 0000000000000000000000000000000000000000..26d5f949ad7bc15b162251a226af5e1cd2305870
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_8.mtl
@@ -0,0 +1,12 @@
+# Blender MTL File: 'None'
+# Material Count: 1
+
+newmtl Shape.7500
+Ns 400.000000
+Ka 0.400000 0.400000 0.400000
+Kd 0.184000 0.536000 0.280000
+Ks 0.250000 0.250000 0.250000
+Ke 0.000000 0.000000 0.000000
+Ni 1.000000
+d 0.500000
+illum 2
diff --git a/dexart-release/assets/sapien/102697/result.json b/dexart-release/assets/sapien/102697/result.json
new file mode 100644
index 0000000000000000000000000000000000000000..176894da5199288fa5953af8e86e86f548d373fc
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/result.json
@@ -0,0 +1 @@
+[{"text": "Toilet", "name": "Toilet", "id": 0, "children": [{"text": "button", "name": "button", "id": 1, "objs": ["original-8"]}, {"text": "lid", "name": "lid", "id": 2, "objs": ["original-4", "original-3"]}, {"text": "lid", "name": "lid", "id": 3, "objs": ["original-21", "original-20"]}, {"text": "seat", "name": "seat", "id": 4, "objs": ["original-18", "original-17"]}, {"text": "base_body", "name": "base_body", "id": 5, "objs": ["original-13", "original-6", "original-15", "original-11", "original-7", "original-10", "original-14"]}]}]
\ No newline at end of file
diff --git a/dexart-release/assets/sapien/102697/result_original.json b/dexart-release/assets/sapien/102697/result_original.json
new file mode 100644
index 0000000000000000000000000000000000000000..ab433e563f0bc74457d8b63bac00218605d84ed8
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/result_original.json
@@ -0,0 +1 @@
+[{"id": 0, "name": "Toilet", "text": "Toilet", "children": [{"id": 1, "name": "button", "text": "button", "objs": ["original-8"]}, {"id": 2, "name": "lid", "text": "lid", "objs": ["original-4", "original-3"]}, {"id": 3, "name": "lid", "text": "lid", "objs": ["original-21", "original-20"]}, {"id": 4, "name": "seat", "text": "seat", "objs": ["original-18", "original-17"]}]}]
\ No newline at end of file
diff --git a/dexart-release/assets/sapien/102697/semantics.txt b/dexart-release/assets/sapien/102697/semantics.txt
new file mode 100644
index 0000000000000000000000000000000000000000..78e04a0f6d97e6d0c1c1436663e811d6b6d19b6a
--- /dev/null
+++ b/dexart-release/assets/sapien/102697/semantics.txt
@@ -0,0 +1,5 @@
+link_0 hinge lever
+link_1 slider pump_lid
+link_2 hinge lid
+link_3 hinge seat
+link_4 static toilet_body
diff --git a/dexart-release/dexart.egg-info/PKG-INFO b/dexart-release/dexart.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..71b72913fe316b9a164c464ffc857ae04b7fdb21
--- /dev/null
+++ b/dexart-release/dexart.egg-info/PKG-INFO
@@ -0,0 +1,12 @@
+Metadata-Version: 2.1
+Name: dexart
+Version: 0.1.0
+Summary: UNKNOWN
+Home-page: https://github.com/Kami-code/dexart-release
+Author: Xiaolong Wang's Lab
+License: UNKNOWN
+Platform: UNKNOWN
+License-File: LICENSE
+
+UNKNOWN
+
diff --git a/dexart-release/dexart.egg-info/SOURCES.txt b/dexart-release/dexart.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..688f57b67ad7ec4a33139851809bf05ee3492d42
--- /dev/null
+++ b/dexart-release/dexart.egg-info/SOURCES.txt
@@ -0,0 +1,51 @@
+LICENSE
+README.md
+setup.py
+dexart.egg-info/PKG-INFO
+dexart.egg-info/SOURCES.txt
+dexart.egg-info/dependency_links.txt
+dexart.egg-info/requires.txt
+dexart.egg-info/top_level.txt
+stable_baselines3/__init__.py
+stable_baselines3/pickle_utils.py
+stable_baselines3/a2c/__init__.py
+stable_baselines3/a2c/a2c.py
+stable_baselines3/a2c/policies.py
+stable_baselines3/common/__init__.py
+stable_baselines3/common/base_class.py
+stable_baselines3/common/buffers.py
+stable_baselines3/common/callbacks.py
+stable_baselines3/common/distributions.py
+stable_baselines3/common/env_util.py
+stable_baselines3/common/evaluation.py
+stable_baselines3/common/logger.py
+stable_baselines3/common/monitor.py
+stable_baselines3/common/noise.py
+stable_baselines3/common/on_policy_algorithm.py
+stable_baselines3/common/policies.py
+stable_baselines3/common/preprocessing.py
+stable_baselines3/common/running_mean_std.py
+stable_baselines3/common/save_util.py
+stable_baselines3/common/torch_layers.py
+stable_baselines3/common/type_aliases.py
+stable_baselines3/common/utils.py
+stable_baselines3/common/vec_env/__init__.py
+stable_baselines3/common/vec_env/base_vec_env.py
+stable_baselines3/common/vec_env/dummy_vec_env.py
+stable_baselines3/common/vec_env/maniskill2_utils_common.py
+stable_baselines3/common/vec_env/maniskill2_utils_wrappers_obs.py
+stable_baselines3/common/vec_env/maniskill2_vec_env.py
+stable_baselines3/common/vec_env/maniskill2_wrapper_obs.py
+stable_baselines3/common/vec_env/stacked_observations.py
+stable_baselines3/common/vec_env/subproc_vec_env.py
+stable_baselines3/common/vec_env/util.py
+stable_baselines3/common/vec_env/vec_check_nan.py
+stable_baselines3/common/vec_env/vec_extract_dict_obs.py
+stable_baselines3/common/vec_env/vec_frame_stack.py
+stable_baselines3/common/vec_env/vec_monitor.py
+stable_baselines3/common/vec_env/vec_normalize.py
+stable_baselines3/common/vec_env/vec_transpose.py
+stable_baselines3/common/vec_env/vec_video_recorder.py
+stable_baselines3/ppo/__init__.py
+stable_baselines3/ppo/policies.py
+stable_baselines3/ppo/ppo.py
\ No newline at end of file
diff --git a/dexart-release/dexart.egg-info/dependency_links.txt b/dexart-release/dexart.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/dexart-release/dexart.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/dexart-release/dexart.egg-info/requires.txt b/dexart-release/dexart.egg-info/requires.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3df94ffddf65c396377019387eec478fce9480df
--- /dev/null
+++ b/dexart-release/dexart.egg-info/requires.txt
@@ -0,0 +1,4 @@
+transforms3d
+sapien==2.2.1
+numpy
+open3d>=0.15.1
diff --git a/dexart-release/dexart.egg-info/top_level.txt b/dexart-release/dexart.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..108131853cb412de9a5f163e52796658bcef9fb8
--- /dev/null
+++ b/dexart-release/dexart.egg-info/top_level.txt
@@ -0,0 +1 @@
+stable_baselines3
diff --git a/dexart-release/examples/gen_demonstration_expert.py b/dexart-release/examples/gen_demonstration_expert.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ae7c8e49ddbd765f3572d19c0ce210096e50488
--- /dev/null
+++ b/dexart-release/examples/gen_demonstration_expert.py
@@ -0,0 +1,238 @@
+import os
+import sys
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+import argparse
+import zarr
+import torch
+import numpy as np
+import torch.nn.functional as F
+import pytorch3d.ops as torch3d_ops
+from dexart.env.task_setting import TRAIN_CONFIG, RANDOM_CONFIG
+from dexart.env.create_env import create_env
+from stable_baselines3 import PPO
+# from examples.train import get_3d_policy_kwargs
+from train import get_3d_policy_kwargs
+from tqdm import tqdm
+from termcolor import cprint
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--task_name', type=str, required=True)
+ parser.add_argument('--checkpoint_path', type=str, required=True)
+ parser.add_argument('--num_episodes', type=int, default=10, help='number of total episodes')
+ parser.add_argument('--use_test_set', dest='use_test_set', action='store_true', default=False)
+ parser.add_argument('--root_dir', type=str, default='data', help='directory to save data')
+ parser.add_argument('--img_size', type=int, default=84, help='image size')
+ parser.add_argument('--num_points', type=int, default=1024, help='number of points in point cloud')
+ args = parser.parse_args()
+ return args
+
+def downsample_with_fps(points: np.ndarray, num_points: int = 512):
+ # fast point cloud sampling using torch3d
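+    # expected shapes (based on how this is called below): points is (N, 3) or
+    # (N, 3+C); the returned (num_points, 3+C) array keeps any extra channels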
+ points = torch.from_numpy(points).unsqueeze(0).cuda()
+ num_points = torch.tensor([num_points]).cuda()
+    # select the FPS indices using only the xyz coordinates
+ _, sampled_indices = torch3d_ops.sample_farthest_points(points=points[...,:3], K=num_points)
+ points = points.squeeze(0).cpu().numpy()
+ points = points[sampled_indices.squeeze(0).cpu().numpy()]
+ return points
+
+def main():
+ args = parse_args()
+ task_name = args.task_name
+ use_test_set = args.use_test_set
+ checkpoint_path = args.checkpoint_path
+
+
+ save_dir = os.path.join(args.root_dir, 'dexart_'+args.task_name+'_expert.zarr')
+ if os.path.exists(save_dir):
+        cprint('Data already exists at {}'.format(save_dir), 'red')
+        cprint('Overwrite it? (y/n)', 'red')
+        # user_input = input()
+        user_input = 'y'  # hard-coded to overwrite; use the input() above to prompt interactively
+ if user_input == 'y':
+ cprint('Overwriting {}'.format(save_dir), 'red')
+ os.system('rm -rf {}'.format(save_dir))
+ else:
+ cprint('Exiting', 'red')
+ return
+ os.makedirs(save_dir, exist_ok=True)
+
+
+    if use_test_set:
+        indices = TRAIN_CONFIG[task_name]['unseen']
+        cprint(f"using unseen instances {indices}", 'yellow')
+    else:
+        indices = TRAIN_CONFIG[task_name]['seen']
+        cprint(f"using seen instances {indices}", 'yellow')
+
+ rand_pos = RANDOM_CONFIG[task_name]['rand_pos']
+ rand_degree = RANDOM_CONFIG[task_name]['rand_degree']
+ env = create_env(task_name=task_name,
+ use_visual_obs=True,
+ use_gui=False,
+ is_eval=True,
+ pc_noise=True,
+ pc_seg=True,
+                     index=indices,
+ img_type='robot',
+ rand_pos=rand_pos,
+ rand_degree=rand_degree)
+
+ policy = PPO.load(checkpoint_path, env, 'cuda:0',
+ policy_kwargs=get_3d_policy_kwargs(extractor_name='smallpn'),
+ check_obs_space=False, force_load=True)
+
+ eval_instances = len(env.instance_list)
+ num_episodes = args.num_episodes
+ cprint(f"generate {num_episodes} episodes in total", 'yellow')
+
+ success_list = []
+ reward_list = []
+
+ total_count = 0
+ img_arrays = []
+ point_cloud_arrays = []
+ depth_arrays = []
+ state_arrays = []
+ imagin_robot_arrays = []
+ action_arrays = []
+ episode_ends_arrays = []
+
+
+ with tqdm(total=num_episodes) as pbar:
+ num_success = 0
+ while num_success < num_episodes:
+
+ # obs dict keys: 'instance_1-seg_gt', 'instance_1-point_cloud',
+ # 'instance_1-rgb', 'imagination_robot', 'state', 'oracle_state'
+ obs = env.reset()
+ eval_success = False
+ reward_sum = 0
+
+ img_arrays_sub = []
+ point_cloud_arrays_sub = []
+ depth_arrays_sub = []
+ state_arrays_sub = []
+ imagin_robot_arrays_sub = []
+ action_arrays_sub = []
+ total_count_sub = 0
+ for j in range(env.horizon):
+
+ if isinstance(obs, dict):
+ for key, value in obs.items():
+ obs[key] = value[np.newaxis, :]
+ else:
+ obs = obs[np.newaxis, :]
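+                # PPO.predict returns an (action, next_hidden_state) tuple;
+                # only the action is used for this feed-forward policy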
+ action = policy.predict(observation=obs, deterministic=True)[0]
+
+ # fetch data
+ total_count_sub += 1
+ obs_state = obs['state'][0] # (32)
+ obs_imagin_robot = obs['imagination_robot'][0] # (96,7)
+ obs_point_cloud = obs['instance_1-point_cloud'][0] # (1024,3)
+ obs_depth = obs['instance_1-depth'][0] # (84,84)
+
+ if obs_point_cloud.shape[0] > args.num_points:
+ obs_point_cloud = downsample_with_fps(obs_point_cloud, num_points=args.num_points)
+ obs_image = obs['instance_1-rgb'][0] # (84,84,3), [0,1]
+
+
+ # to 0-255
+ obs_image = (obs_image*255).astype(np.uint8)
+
+ # interpolate to target image size
+ if obs_image.shape[0] != args.img_size:
+ obs_image = F.interpolate(torch.from_numpy(obs_image).permute(2,0,1).unsqueeze(0),
+ size=args.img_size).squeeze().permute(1,2,0).numpy()
+ # save data
+ img_arrays_sub.append(obs_image)
+ imagin_robot_arrays_sub.append(obs_imagin_robot)
+ point_cloud_arrays_sub.append(obs_point_cloud)
+ depth_arrays_sub.append(obs_depth)
+ state_arrays_sub.append(obs_state)
+ action_arrays_sub.append(action)
+
+ # step
+ obs, reward, done, _ = env.step(action)
+ reward_sum += reward
+ if env.is_eval_done:
+ eval_success = True
+ if done:
+ break
+
+ if eval_success:
+ total_count += total_count_sub
+                episode_ends_arrays.append(total_count)  # cumulative step count == exclusive end index of this episode
+ reward_list.append(reward_sum)
+ success_list.append(int(eval_success))
+
+ img_arrays.extend(img_arrays_sub)
+ imagin_robot_arrays.extend(imagin_robot_arrays_sub)
+ point_cloud_arrays.extend(point_cloud_arrays_sub)
+ depth_arrays.extend(depth_arrays_sub)
+ state_arrays.extend(state_arrays_sub)
+ action_arrays.extend(action_arrays_sub)
+
+ num_success += 1
+
+ pbar.update(1)
+ pbar.set_description(f"reward = {reward_sum}, success = {eval_success}")
+ else:
+ print("episode failed. continue.")
+ continue
+
+
+ cprint(f"reward_mean = {np.mean(reward_list)}, success rate = {np.mean(success_list)}", 'yellow')
+
+ ###############################
+ # save data
+ ###############################
+ # create zarr file
+ zarr_root = zarr.group(save_dir)
+ zarr_data = zarr_root.create_group('data')
+ zarr_meta = zarr_root.create_group('meta')
+ # save img, state, action arrays into data, and episode ends arrays into meta
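+    # resulting layout (one row per env step; episodes delimited by meta/episode_ends):
+    #   data/img uint8, data/state, data/imagin_robot, data/point_cloud,
+    #   data/depth, data/action   (all stacked along a shared time axis)
+    #   meta/episode_ends int64   (cumulative step counts, one entry per episode)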
+ img_arrays = np.stack(img_arrays, axis=0)
+ if img_arrays.shape[1] == 3: # make channel last
+ img_arrays = np.transpose(img_arrays, (0,2,3,1))
+ state_arrays = np.stack(state_arrays, axis=0)
+ imagin_robot_arrays = np.stack(imagin_robot_arrays, axis=0)
+ point_cloud_arrays = np.stack(point_cloud_arrays, axis=0)
+ depth_arrays = np.stack(depth_arrays, axis=0)
+ action_arrays = np.stack(action_arrays, axis=0)
+ episode_ends_arrays = np.array(episode_ends_arrays)
+
+ compressor = zarr.Blosc(cname='zstd', clevel=3, shuffle=1)
+ img_chunk_size = (env.horizon, img_arrays.shape[1], img_arrays.shape[2], img_arrays.shape[3])
+ imagin_robot_chunk_size = (env.horizon, imagin_robot_arrays.shape[1], imagin_robot_arrays.shape[2])
+ point_cloud_chunk_size = (env.horizon, point_cloud_arrays.shape[1], point_cloud_arrays.shape[2])
+ depth_chunk_size = (env.horizon, depth_arrays.shape[1], depth_arrays.shape[2])
+ state_chunk_size = (env.horizon, state_arrays.shape[1])
+ action_chunk_size = (env.horizon, action_arrays.shape[1])
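+    # chunking every array by env.horizon keeps roughly one episode per compressed
+    # chunk (exactly one only when an episode runs the full horizon), so reading a
+    # single trajectory later touches few chunks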
+ zarr_data.create_dataset('img', data=img_arrays, chunks=img_chunk_size, dtype='uint8', overwrite=True, compressor=compressor)
+ zarr_data.create_dataset('state', data=state_arrays, chunks=state_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
+ zarr_data.create_dataset('imagin_robot', data=imagin_robot_arrays, chunks=imagin_robot_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
+ zarr_data.create_dataset('point_cloud', data=point_cloud_arrays, chunks=point_cloud_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
+ zarr_data.create_dataset('depth', data=depth_arrays, chunks=depth_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
+ zarr_data.create_dataset('action', data=action_arrays, chunks=action_chunk_size, dtype='float32', overwrite=True, compressor=compressor)
+ zarr_meta.create_dataset('episode_ends', data=episode_ends_arrays, dtype='int64', overwrite=True, compressor=compressor)
+
+ # print shape
+ cprint(f'img shape: {img_arrays.shape}, range: [{np.min(img_arrays)}, {np.max(img_arrays)}]', 'green')
+ cprint(f'imagin_robot shape: {imagin_robot_arrays.shape}, range: [{np.min(imagin_robot_arrays)}, {np.max(imagin_robot_arrays)}]', 'green')
+ cprint(f'point_cloud shape: {point_cloud_arrays.shape}, range: [{np.min(point_cloud_arrays)}, {np.max(point_cloud_arrays)}]', 'green')
+ cprint(f'depth shape: {depth_arrays.shape}, range: [{np.min(depth_arrays)}, {np.max(depth_arrays)}]', 'green')
+ cprint(f'state shape: {state_arrays.shape}, range: [{np.min(state_arrays)}, {np.max(state_arrays)}]', 'green')
+ cprint(f'action shape: {action_arrays.shape}, range: [{np.min(action_arrays)}, {np.max(action_arrays)}]', 'green')
+ cprint(f'Saved zarr file to {save_dir}', 'green')
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/dexart-release/examples/train.py b/dexart-release/examples/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f78b55bcf88c0353b79e4024e4ce2a86c57fdda
--- /dev/null
+++ b/dexart-release/examples/train.py
@@ -0,0 +1,124 @@
+import os
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+import random
+from collections import OrderedDict
+import torch.nn as nn
+import argparse
+from dexart.env.create_env import create_env
+from dexart.env.task_setting import TRAIN_CONFIG, IMG_CONFIG, RANDOM_CONFIG
+from stable_baselines3.common.torch_layers import PointNetImaginationExtractorGP
+from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
+from stable_baselines3.ppo import PPO
+import torch
+
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+def get_3d_policy_kwargs(extractor_name):
+ feature_extractor_class = PointNetImaginationExtractorGP
+ feature_extractor_kwargs = {"pc_key": "instance_1-point_cloud", "gt_key": "instance_1-seg_gt",
+ "extractor_name": extractor_name,
+ "imagination_keys": [f'imagination_{key}' for key in IMG_CONFIG['robot'].keys()],
+ "state_key": "state"}
+
+ policy_kwargs = {
+ "features_extractor_class": feature_extractor_class,
+ "features_extractor_kwargs": feature_extractor_kwargs,
+ "net_arch": [dict(pi=[64, 64], vf=[64, 64])],
+ "activation_fn": nn.ReLU,
+ }
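+    # the point-cloud extractor yields the shared feature vector; net_arch then adds
+    # separate (64, 64) MLP heads for the policy (pi) and value function (vf), using
+    # the list-of-dict form expected by older stable-baselines3 versions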
+ return policy_kwargs
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--n', type=int, default=10)
+ parser.add_argument('--workers', type=int, default=1)
+ parser.add_argument('--lr', type=float, default=3e-4)
+ parser.add_argument('--ep', type=int, default=10)
+ parser.add_argument('--bs', type=int, default=10)
+ parser.add_argument('--seed', type=int, default=100)
+ parser.add_argument('--iter', type=int, default=1000)
+ parser.add_argument('--freeze', dest='freeze', action='store_true', default=False)
+ parser.add_argument('--task_name', type=str, default="laptop")
+ parser.add_argument('--extractor_name', type=str, default="smallpn")
+ parser.add_argument('--pretrain_path', type=str, default=None)
+ args = parser.parse_args()
+
+ task_name = args.task_name
+ extractor_name = args.extractor_name
+ seed = args.seed if args.seed >= 0 else random.randint(0, 100000)
+ pretrain_path = args.pretrain_path
+ horizon = 200
+ env_iter = args.iter * horizon * args.n
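+    # total env steps: args.iter rollouts, each collecting args.n episodes of
+    # `horizon` steps across the vectorized workers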
+ print(f"freeze: {args.freeze}")
+
+ rand_pos = RANDOM_CONFIG[task_name]['rand_pos']
+ rand_degree = RANDOM_CONFIG[task_name]['rand_degree']
+
+
+ def create_env_fn():
+ seen_indices = TRAIN_CONFIG[task_name]['seen']
+ environment = create_env(task_name=task_name,
+ use_visual_obs=True,
+ use_gui=False,
+ is_eval=False,
+ pc_noise=True,
+ index=seen_indices,
+ img_type='robot',
+ rand_pos=rand_pos,
+ rand_degree=rand_degree
+ )
+ return environment
+
+
+ def create_eval_env_fn():
+ unseen_indices = TRAIN_CONFIG[task_name]['unseen']
+ environment = create_env(task_name=task_name,
+ use_visual_obs=True,
+ use_gui=False,
+ is_eval=True,
+ pc_noise=True,
+ index=unseen_indices,
+ img_type='robot',
+ rand_pos=rand_pos,
+ rand_degree=rand_degree)
+ return environment
+
+
+ env = SubprocVecEnv([create_env_fn] * args.workers, "spawn") # train on a list of envs.
+
+ model = PPO("PointCloudPolicy", env, verbose=1,
+ n_epochs=args.ep,
+ n_steps=(args.n // args.workers) * horizon,
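+ # each worker collects (n // workers) episodes of `horizon` steps per update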
+ learning_rate=args.lr,
+ batch_size=args.bs,
+ seed=seed,
+ policy_kwargs=get_3d_policy_kwargs(extractor_name=extractor_name),
+ min_lr=args.lr,
+ max_lr=args.lr,
+ adaptive_kl=0.02,
+ target_kl=0.2,
+ )
+
+ if pretrain_path is not None:
+ state_dict: OrderedDict = torch.load(pretrain_path)
+ model.policy.features_extractor.extractor.load_state_dict(state_dict, strict=False)
+ print("load pretrained model: ", pretrain_path)
+
+ rollout = int(model.num_timesteps / (horizon * args.n))
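+ # number of rollouts already completed (non-zero when resuming from a loaded
+ # model); passed to learn() as iter_start so iteration counting continues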
+
+ # after loading or initializing the model, freeze the extractor if requested
+ if args.freeze:
+ model.policy.features_extractor.extractor.eval()
+ for param in model.policy.features_extractor.extractor.parameters():
+ param.requires_grad = False
+ print("freeze model!")
+
+ model.learn(
+ total_timesteps=int(env_iter),
+ reset_num_timesteps=False,
+ iter_start=rollout,
+ callback=None
+ )
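+
+ # Example invocation (hypothetical values):
+ # python examples/train.py --task_name laptop --extractor_name smallpn --n 10 --workers 10 --iter 1000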
diff --git a/dexart-release/examples/utils.py b/dexart-release/examples/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbe23a9a962a904ca60660165af003f226fb8b09
--- /dev/null
+++ b/dexart-release/examples/utils.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+import numpy as np
+from dexart.env.task_setting import ROBUSTNESS_INIT_CAMERA_CONFIG
+import open3d as o3d
+
+def visualize_observation(obs, use_seg=False, img_type=None):
+ def visualize_pc_with_seg_label(cloud):
+ pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(cloud[:, :3]))
+
+ def seg_to_color(feature):
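+ # map each one-hot segmentation row of shape (N, K) to an RGB color of
+ # shape (N, 3) using a fixed 20-color palette (one color per label index)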
+ color = np.zeros((feature.shape[0], 3))
+ COLOR20 = np.array(
+ [[230, 25, 75], [60, 180, 75], [255, 225, 25], [0, 130, 200], [245, 130, 48],
+ [145, 30, 180], [70, 240, 240], [240, 50, 230], [210, 245, 60], [250, 190, 190],
+ [0, 128, 128], [230, 190, 255], [170, 110, 40], [255, 250, 200], [128, 0, 0],
+ [170, 255, 195], [128, 128, 0], [255, 215, 180], [0, 0, 128], [128, 128, 128]]) / 255
+ for i in range(feature.shape[0]):
+ for j in range(feature.shape[1]):
+ if feature[i, j] == 1:
+ color[i, :] = COLOR20[j, :]
+ return color
+
+ color = seg_to_color(cloud[:, 3:])
+ pc.colors = o3d.utility.Vector3dVector(color)
+ return pc
+
+ pc = obs["instance_1-point_cloud"]
+ if use_seg:
+ gt_seg = obs["instance_1-seg_gt"]
+ pc = np.concatenate([pc, gt_seg], axis=1)
+ pc = visualize_pc_with_seg_label(pc)
+ if img_type == "robot":
+ robot_pc = obs["imagination_robot"]
+ pc += visualize_pc_with_seg_label(robot_pc)
+ else:
+ raise NotImplementedError
+ return pc
+
+
+def get_viewpoint_camera_parameter():
+ robustness_init_camera_config = ROBUSTNESS_INIT_CAMERA_CONFIG['laptop']
+ r = robustness_init_camera_config['r']
+ phi = robustness_init_camera_config['phi']
+ theta = robustness_init_camera_config['theta']
+ center = robustness_init_camera_config['center']
+
+ x0, y0, z0 = center
+ # phi in [0, pi/2]
+ # theta in [0, 2 * pi]
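+ # spherical-to-Cartesian: the camera sits on a sphere of radius r around `center`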
+ x = x0 + r * np.sin(phi) * np.cos(theta)
+ y = y0 + r * np.sin(phi) * np.sin(theta)
+ z = z0 + r * np.cos(phi)
+
+ cam_pos = np.array([x, y, z])
+ forward = np.array([x0 - x, y0 - y, z0 - z])
+ forward /= np.linalg.norm(forward)
+
+ left = np.cross([0, 0, 1], forward)
+ left = left / np.linalg.norm(left)
+
+ up = np.cross(forward, left)
+ mat44 = np.eye(4)
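+ # the rotation columns are the camera's forward/left/up axes in the world frame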
+ mat44[:3, :3] = np.stack([forward, left, up], axis=1)
+ mat44[:3, 3] = cam_pos
+ return cam_pos, center, up, mat44
\ No newline at end of file
diff --git a/dexart-release/stable_baselines3/a2c/__init__.py b/dexart-release/stable_baselines3/a2c/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9996494b658e9bef6efccfc9cdad269da4609d
--- /dev/null
+++ b/dexart-release/stable_baselines3/a2c/__init__.py
@@ -0,0 +1,2 @@
+from stable_baselines3.a2c.a2c import A2C
+from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
diff --git a/dexart-release/stable_baselines3/a2c/a2c.py b/dexart-release/stable_baselines3/a2c/a2c.py
new file mode 100644
index 0000000000000000000000000000000000000000..13adf680001deaaae21659d09b22cdeaffd7a775
--- /dev/null
+++ b/dexart-release/stable_baselines3/a2c/a2c.py
@@ -0,0 +1,207 @@
+from typing import Any, Dict, Optional, Type, Union
+
+import torch as th
+from gym import spaces
+from torch.nn import functional as F
+
+from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
+from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
+from stable_baselines3.common.utils import explained_variance
+
+
+class A2C(OnPolicyAlgorithm):
+ """
+ Advantage Actor Critic (A2C)
+
+ Paper: https://arxiv.org/abs/1602.01783
+ Code: This implementation borrows code from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
+ and Stable Baselines (https://github.com/hill-a/stable-baselines)
+
+ Introduction to A2C: https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752
+
+ :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
+ :param env: The environment to learn from (if registered in Gym, can be str)
+ :param learning_rate: The learning rate, it can be a function
+ of the current progress remaining (from 1 to 0)
+ :param n_steps: The number of steps to run for each environment per update
+ (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
+ :param gamma: Discount factor
+ :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
+ Equivalent to classic advantage when set to 1.
+ :param ent_coef: Entropy coefficient for the loss calculation
+ :param vf_coef: Value function coefficient for the loss calculation
+ :param max_grad_norm: The maximum value for the gradient clipping
+ :param rms_prop_eps: RMSProp epsilon. It stabilizes square root computation in denominator
+ of RMSProp update
+ :param use_rms_prop: Whether to use RMSprop (default) or Adam as optimizer
+ :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
+ instead of action noise exploration (default: False)
+ :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
+ Default: -1 (only sample at the beginning of the rollout)
+ :param normalize_advantage: Whether or not to normalize the advantage
+ :param tensorboard_log: the log location for tensorboard (if None, no logging)
+ :param create_eval_env: Whether to create a second environment that will be
+ used for evaluating the agent periodically. (Only available when passing string for the environment)
+ :param policy_kwargs: additional arguments to be passed to the policy on creation
+ :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
+ :param seed: Seed for the pseudo random generators
+ :param device: Device (cpu, cuda, ...) on which the code should be run.
+ Setting it to auto, the code will be run on the GPU if possible.
+ :param _init_setup_model: Whether or not to build the network at the creation of the instance
+ """
+
+ policy_aliases: Dict[str, Type[BasePolicy]] = {
+ "MlpPolicy": ActorCriticPolicy,
+ "CnnPolicy": ActorCriticCnnPolicy,
+ "MultiInputPolicy": MultiInputActorCriticPolicy,
+ }
+
+ def __init__(
+ self,
+ policy: Union[str, Type[ActorCriticPolicy]],
+ env: Union[GymEnv, str],
+ learning_rate: Union[float, Schedule] = 7e-4,
+ n_steps: int = 5,
+ gamma: float = 0.99,
+ gae_lambda: float = 1.0,
+ ent_coef: float = 0.0,
+ vf_coef: float = 0.5,
+ max_grad_norm: float = 0.5,
+ rms_prop_eps: float = 1e-5,
+ use_rms_prop: bool = True,
+ use_sde: bool = False,
+ sde_sample_freq: int = -1,
+ normalize_advantage: bool = False,
+ tensorboard_log: Optional[str] = None,
+ create_eval_env: bool = False,
+ policy_kwargs: Optional[Dict[str, Any]] = None,
+ verbose: int = 0,
+ seed: Optional[int] = None,
+ device: Union[th.device, str] = "auto",
+ _init_setup_model: bool = True,
+ ):
+
+ super().__init__(
+ policy,
+ env,
+ learning_rate=learning_rate,
+ n_steps=n_steps,
+ gamma=gamma,
+ gae_lambda=gae_lambda,
+ ent_coef=ent_coef,
+ vf_coef=vf_coef,
+ max_grad_norm=max_grad_norm,
+ use_sde=use_sde,
+ sde_sample_freq=sde_sample_freq,
+ tensorboard_log=tensorboard_log,
+ policy_kwargs=policy_kwargs,
+ verbose=verbose,
+ device=device,
+ create_eval_env=create_eval_env,
+ seed=seed,
+ _init_setup_model=False,
+ supported_action_spaces=(
+ spaces.Box,
+ spaces.Discrete,
+ spaces.MultiDiscrete,
+ spaces.MultiBinary,
+ ),
+ )
+
+ self.normalize_advantage = normalize_advantage
+
+ # Update optimizer inside the policy if we want to use RMSProp
+ # (original implementation) rather than Adam
+ if use_rms_prop and "optimizer_class" not in self.policy_kwargs:
+ self.policy_kwargs["optimizer_class"] = th.optim.RMSprop
+ self.policy_kwargs["optimizer_kwargs"] = dict(alpha=0.99, eps=rms_prop_eps, weight_decay=0)
+
+ if _init_setup_model:
+ self._setup_model()
+
+ def train(self) -> None:
+ """
+ Update policy using the currently gathered
+ rollout buffer (one gradient step over whole data).
+ """
+ # Switch to train mode (this affects batch norm / dropout)
+ self.policy.set_training_mode(True)
+
+ # Update optimizer learning rate
+ self._update_learning_rate(self.policy.optimizer)
+
+ # This will only loop once (get all data in one go)
+ for rollout_data in self.rollout_buffer.get(batch_size=None):
+
+ actions = rollout_data.actions
+ if isinstance(self.action_space, spaces.Discrete):
+ # Convert discrete action from float to long
+ actions = actions.long().flatten()
+
+ values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
+ values = values.flatten()
+
+ # Normalize advantage (not present in the original implementation)
+ advantages = rollout_data.advantages
+ if self.normalize_advantage:
+ advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+
+ # Policy gradient loss
+ policy_loss = -(advantages * log_prob).mean()
+
+ # Value loss using the TD(gae_lambda) target
+ value_loss = F.mse_loss(rollout_data.returns, values)
+
+ # Entropy loss favors exploration
+ if entropy is None:
+ # Approximate entropy when no analytical form
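+ # (-log_prob is a one-sample Monte-Carlo estimate of the entropy)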
+ entropy_loss = -th.mean(-log_prob)
+ else:
+ entropy_loss = -th.mean(entropy)
+
+ loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
+
+ # Optimization step
+ self.policy.optimizer.zero_grad()
+ loss.backward()
+
+ # Clip grad norm
+ th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+ self.policy.optimizer.step()
+
+ explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
+
+ self._n_updates += 1
+ self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
+ self.logger.record("train/explained_variance", explained_var)
+ self.logger.record("train/entropy_loss", entropy_loss.item())
+ self.logger.record("train/policy_loss", policy_loss.item())
+ self.logger.record("train/value_loss", value_loss.item())
+ if hasattr(self.policy, "log_std"):
+ self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
+
+ def learn(
+ self,
+ total_timesteps: int,
+ callback: MaybeCallback = None,
+ log_interval: int = 100,
+ eval_env: Optional[GymEnv] = None,
+ eval_freq: int = -1,
+ n_eval_episodes: int = 5,
+ tb_log_name: str = "A2C",
+ eval_log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+ ) -> "A2C":
+
+ return super().learn(
+ total_timesteps=total_timesteps,
+ callback=callback,
+ log_interval=log_interval,
+ eval_env=eval_env,
+ eval_freq=eval_freq,
+ n_eval_episodes=n_eval_episodes,
+ tb_log_name=tb_log_name,
+ eval_log_path=eval_log_path,
+ reset_num_timesteps=reset_num_timesteps,
+ )
diff --git a/dexart-release/stable_baselines3/a2c/policies.py b/dexart-release/stable_baselines3/a2c/policies.py
new file mode 100644
index 0000000000000000000000000000000000000000..7299b34df478f6392cf5182842dacd83373aed0f
--- /dev/null
+++ b/dexart-release/stable_baselines3/a2c/policies.py
@@ -0,0 +1,7 @@
+# This file is here just to define MlpPolicy/CnnPolicy
+# that work for A2C
+from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
+
+MlpPolicy = ActorCriticPolicy
+CnnPolicy = ActorCriticCnnPolicy
+MultiInputPolicy = MultiInputActorCriticPolicy
diff --git a/dexart-release/stable_baselines3/common/__init__.py b/dexart-release/stable_baselines3/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dexart-release/stable_baselines3/common/base_class.py b/dexart-release/stable_baselines3/common/base_class.py
new file mode 100644
index 0000000000000000000000000000000000000000..31266e76b9f00b3c5c2e52e3add465aec4f3e9e4
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/base_class.py
@@ -0,0 +1,835 @@
+"""Abstract base classes for RL algorithms."""
+
+import io
+import pathlib
+import time
+from abc import ABC, abstractmethod
+from collections import deque
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
+
+import gym
+import numpy as np
+import torch as th
+
+from stable_baselines3.common import utils
+from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
+from stable_baselines3.common.env_util import is_wrapped
+from stable_baselines3.common.logger import Logger
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.noise import ActionNoise
+from stable_baselines3.common.policies import BasePolicy
+from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
+from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
+from stable_baselines3.common.utils import (
+ check_for_correct_spaces,
+ get_device,
+ get_schedule_fn,
+ get_system_info,
+ set_random_seed,
+ update_learning_rate,
+)
+from stable_baselines3.common.vec_env import (
+ DummyVecEnv,
+ VecEnv,
+ VecNormalize,
+ VecTransposeImage,
+ is_vecenv_wrapped,
+ unwrap_vec_normalize,
+)
+
+
+def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymEnv]:
+ """If env is a string, make the environment; otherwise, return env.
+
+ :param env: The environment to learn from.
+ :param verbose: logging verbosity
+ :return: A Gym (vector) environment.
+ """
+ if isinstance(env, str):
+ if verbose >= 1:
+ print(f"Creating environment from the given name '{env}'")
+ env = gym.make(env)
+ return env
+
+
+class BaseAlgorithm(ABC):
+ """
+ The base of RL algorithms
+
+ :param policy: Policy object
+ :param env: The environment to learn from
+ (if registered in Gym, can be str. Can be None for loading trained models)
+ :param learning_rate: learning rate for the optimizer,
+ it can be a function of the current progress remaining (from 1 to 0)
+ :param policy_kwargs: Additional arguments to be passed to the policy on creation
+ :param tensorboard_log: the log location for tensorboard (if None, no logging)
+ :param verbose: The verbosity level: 0 none, 1 training information, 2 debug
+ :param device: Device on which the code should run.
+ By default, it will try to use a Cuda compatible device and fallback to cpu
+ if it is not possible.
+ :param support_multi_env: Whether the algorithm supports training
+ with multiple environments (as in A2C)
+ :param create_eval_env: Whether to create a second environment that will be
+ used for evaluating the agent periodically. (Only available when passing string for the environment)
+ :param monitor_wrapper: When creating an environment, whether to wrap it
+ or not in a Monitor wrapper.
+ :param seed: Seed for the pseudo random generators
+ :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
+ instead of action noise exploration (default: False)
+ :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
+ Default: -1 (only sample at the beginning of the rollout)
+ :param supported_action_spaces: The action spaces supported by the algorithm.
+ """
+
+ # Policy aliases (see _get_policy_from_name())
+ policy_aliases: Dict[str, Type[BasePolicy]] = {}
+
+ def __init__(
+ self,
+ policy: Type[BasePolicy],
+ env: Union[GymEnv, str, None],
+ learning_rate: Union[float, Schedule],
+ policy_kwargs: Optional[Dict[str, Any]] = None,
+ tensorboard_log: Optional[str] = None,
+ verbose: int = 0,
+ device: Union[th.device, str] = "auto",
+ support_multi_env: bool = False,
+ create_eval_env: bool = False,
+ monitor_wrapper: bool = True,
+ seed: Optional[int] = None,
+ use_sde: bool = False,
+ sde_sample_freq: int = -1,
+ supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
+ ):
+ if isinstance(policy, str):
+ self.policy_class = self._get_policy_from_name(policy)
+ else:
+ self.policy_class = policy
+
+ self.device = get_device(device)
+ if verbose > 0:
+ print(f"Using {self.device} device")
+
+ self.env = None # type: Optional[GymEnv]
+ # get VecNormalize object if needed
+ self._vec_normalize_env = unwrap_vec_normalize(env)
+ self.verbose = verbose
+ self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
+ self.observation_space = None # type: Optional[gym.spaces.Space]
+ self.action_space = None # type: Optional[gym.spaces.Space]
+ self.n_envs = None
+ self.num_timesteps = 0
+ # Used for updating schedules
+ self._total_timesteps = 0
+ # Used for computing fps, it is updated at each call of learn()
+ self._num_timesteps_at_start = 0
+ self.eval_env = None
+ self.seed = seed
+ self.action_noise = None # type: Optional[ActionNoise]
+ self.start_time = None
+ self.policy = None
+ self.learning_rate = learning_rate
+ self.tensorboard_log = tensorboard_log
+ self.lr_schedule = None # type: Optional[Schedule]
+ self._last_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
+ self._last_episode_starts = None # type: Optional[np.ndarray]
+ # When using VecNormalize:
+ self._last_original_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
+ self._episode_num = 0
+ # Used for gSDE only
+ self.use_sde = use_sde
+ self.sde_sample_freq = sde_sample_freq
+ # Track the training progress remaining (from 1 to 0)
+ # this is used to update the learning rate
+ self._current_progress_remaining = 1
+ # Buffers for logging
+ self.ep_info_buffer = None # type: Optional[deque]
+ self.ep_success_buffer = None # type: Optional[deque]
+ # For logging (and TD3 delayed updates)
+ self._n_updates = 0 # type: int
+ # The logger object
+ self._logger = None # type: Logger
+ # Whether the user passed a custom logger or not
+ self._custom_logger = False
+
+ # Create and wrap the env if needed
+ if env is not None:
+ if isinstance(env, str):
+ if create_eval_env:
+ self.eval_env = maybe_make_env(env, self.verbose)
+
+ env = maybe_make_env(env, self.verbose)
+ env = self._wrap_env(env, self.verbose, monitor_wrapper)
+
+ self.observation_space = env.observation_space
+ self.action_space = env.action_space
+ self.n_envs = env.num_envs
+ self.env = env
+
+ if supported_action_spaces is not None:
+ assert isinstance(self.action_space, supported_action_spaces), (
+ f"The algorithm only supports {supported_action_spaces} as action spaces "
+ f"but {self.action_space} was provided"
+ )
+
+ if not support_multi_env and self.n_envs > 1:
+ raise ValueError(
+ "Error: the model does not support multiple envs; it requires " "a single vectorized environment."
+ )
+
+ # Catch common mistake: using MlpPolicy/CnnPolicy instead of MultiInputPolicy
+ if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, gym.spaces.Dict):
+ raise ValueError(f"You must use `MultiInputPolicy` when working with dict observation space, not {policy}")
+
+ if self.use_sde and not isinstance(self.action_space, gym.spaces.Box):
+ raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")
+
+ @staticmethod
+ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> VecEnv:
+ """ "
+ Wrap environment with the appropriate wrappers if needed.
+ For instance, to have a vectorized environment
+ or to re-order the image channels.
+
+ :param env:
+ :param verbose:
+ :param monitor_wrapper: Whether to wrap the env in a ``Monitor`` when possible.
+ :return: The wrapped environment.
+ """
+ if not isinstance(env, VecEnv):
+ if not is_wrapped(env, Monitor) and monitor_wrapper:
+ if verbose >= 1:
+ print("Wrapping the env with a `Monitor` wrapper")
+ env = Monitor(env)
+ if verbose >= 1:
+ print("Wrapping the env in a DummyVecEnv.")
+ env = DummyVecEnv([lambda: env])
+
+ # Make sure that dict-spaces are not nested (not supported)
+ check_for_nested_spaces(env.observation_space)
+
+ if not is_vecenv_wrapped(env, VecTransposeImage):
+ wrap_with_vectranspose = False
+ if isinstance(env.observation_space, gym.spaces.Dict):
+ # If even one of the keys is a image-space in need of transpose, apply transpose
+ # If the image spaces are not consistent (for instance one is channel first,
+ # the other channel last), VecTransposeImage will throw an error
+ for space in env.observation_space.spaces.values():
+ wrap_with_vectranspose = wrap_with_vectranspose or (
+ is_image_space(space) and not is_image_space_channels_first(space)
+ )
+ else:
+ wrap_with_vectranspose = is_image_space(env.observation_space) and not is_image_space_channels_first(
+ env.observation_space
+ )
+
+ if wrap_with_vectranspose:
+ if verbose >= 1:
+ print("Wrapping the env in a VecTransposeImage.")
+ env = VecTransposeImage(env)
+
+ return env
+
+ @abstractmethod
+ def _setup_model(self) -> None:
+ """Create networks, buffer and optimizers."""
+
+ def set_logger(self, logger: Logger) -> None:
+ """
+ Setter for the logger object.
+
+ .. warning::
+
+ When passing a custom logger object,
+ this will overwrite ``tensorboard_log`` and ``verbose`` settings
+ passed to the constructor.
+ """
+ self._logger = logger
+ # User defined logger
+ self._custom_logger = True
+
+ @property
+ def logger(self) -> Logger:
+ """Getter for the logger object."""
+ return self._logger
+
+ def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]:
+ """
+ Return the environment that will be used for evaluation.
+
+ :param eval_env:
+ :return:
+ """
+ if eval_env is None:
+ eval_env = self.eval_env
+
+ if eval_env is not None:
+ eval_env = self._wrap_env(eval_env, self.verbose)
+ assert eval_env.num_envs == 1
+ return eval_env
+
+ def _setup_lr_schedule(self) -> None:
+ """Transform to callable if needed."""
+ self.lr_schedule = get_schedule_fn(self.learning_rate)
+
+ def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
+ """
+ Compute current progress remaining (starts at 1 and ends at 0)
+
+ :param num_timesteps: current number of timesteps
+ :param total_timesteps:
+ """
+ self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)
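+ # e.g. 25,000 of 100,000 timesteps done -> progress remaining = 0.75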
+
+ def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
+ """
+ Update the optimizers learning rate using the current learning rate schedule
+ and the current progress remaining (from 1 to 0).
+
+ :param optimizers:
+ An optimizer or a list of optimizers.
+ """
+ # Log the current learning rate
+ self.logger.record("train/learning_rate", self.lr_schedule(self._current_progress_remaining))
+
+ if not isinstance(optimizers, list):
+ optimizers = [optimizers]
+ for optimizer in optimizers:
+ update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining))
+
+ def _excluded_save_params(self) -> List[str]:
+ """
+ Returns the names of the parameters that should be excluded from being
+ saved by pickling. E.g. replay buffers are skipped by default
+ as they take up a lot of space. PyTorch variables should be excluded
+ with this so they can be stored with ``th.save``.
+
+ :return: List of parameters that should be excluded from being saved with pickle.
+ """
+ return [
+ "policy",
+ "device",
+ "env",
+ "eval_env",
+ "replay_buffer",
+ "rollout_buffer",
+ "_vec_normalize_env",
+ "_episode_storage",
+ "_logger",
+ "_custom_logger",
+ ]
+
+ def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
+ """
+ Get a policy class from its name representation.
+
+ The goal here is to standardize policy naming, e.g.
+ all algorithms can call upon "MlpPolicy" or "CnnPolicy",
+ and they receive respective policies that work for them.
+
+ :param policy_name: Alias of the policy
+ :return: A policy class (type)
+ """
+
+ if policy_name in self.policy_aliases:
+ return self.policy_aliases[policy_name]
+ else:
+ raise ValueError(f"Policy {policy_name} unknown")
+
+ def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+ """
+ Get the name of the torch variables that will be saved with
+ PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default
+ pickling strategy. This is to handle device placement correctly.
+
+ Names can point to specific variables under classes, e.g.
+ "policy.optimizer" would point to ``optimizer`` object of ``self.policy``
+ if this object.
+
+ :return:
+ List of Torch variables whose state dicts to save (e.g. th.nn.Modules),
+ and list of other Torch variables to store with ``th.save``.
+ """
+ state_dicts = ["policy"]
+
+ return state_dicts, []
+
+ def _init_callback(
+ self,
+ callback: MaybeCallback,
+ eval_env: Optional[VecEnv] = None,
+ eval_freq: int = 10000,
+ n_eval_episodes: int = 5,
+ log_path: Optional[str] = None,
+ ) -> BaseCallback:
+ """
+ :param callback: Callback(s) called at every step with state of the algorithm.
+ :param eval_freq: How many steps between evaluations; if None, do not evaluate.
+ :param n_eval_episodes: Number of episodes to rollout during evaluation.
+ :param log_path: Path to a folder where the evaluations will be saved
+ :return: A hybrid callback calling `callback` and performing evaluation.
+ """
+ # Convert a list of callbacks into a callback
+ if isinstance(callback, list):
+ callback = CallbackList(callback)
+
+ # Convert functional callback to object
+ if not isinstance(callback, BaseCallback):
+ callback = ConvertCallback(callback)
+
+ # Create eval callback in charge of the evaluation
+ if eval_env is not None:
+ eval_callback = EvalCallback(
+ eval_env,
+ best_model_save_path=log_path,
+ log_path=log_path,
+ eval_freq=eval_freq,
+ n_eval_episodes=n_eval_episodes,
+ )
+ callback = CallbackList([callback, eval_callback])
+
+ callback.init_callback(self)
+ return callback
+
+ def _setup_learn(
+ self,
+ total_timesteps: int,
+ eval_env: Optional[GymEnv],
+ callback: MaybeCallback = None,
+ eval_freq: int = 10000,
+ n_eval_episodes: int = 5,
+ log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+ tb_log_name: str = "run",
+ ) -> Tuple[int, BaseCallback]:
+ """
+ Initialize different variables needed for training.
+
+ :param total_timesteps: The total number of samples (env steps) to train on
+ :param eval_env: Environment to use for evaluation.
+ :param callback: Callback(s) called at every step with state of the algorithm.
+ :param eval_freq: How many steps between evaluations
+ :param n_eval_episodes: How many episodes to play per evaluation
+ :param log_path: Path to a folder where the evaluations will be saved
+ :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
+ :param tb_log_name: the name of the run for tensorboard log
+ :return:
+ """
+ self.start_time = time.time()
+
+ if self.ep_info_buffer is None or reset_num_timesteps:
+ # Initialize buffers if they don't exist, or reinitialize if resetting counters
+ self.ep_info_buffer = deque(maxlen=100)
+ self.ep_success_buffer = deque(maxlen=100)
+
+ if self.action_noise is not None:
+ self.action_noise.reset()
+
+ if reset_num_timesteps:
+ self.num_timesteps = 0
+ self._episode_num = 0
+ else:
+ # Make sure training timesteps are ahead of the internal counter
+ total_timesteps += self.num_timesteps
+ self._total_timesteps = total_timesteps
+ self._num_timesteps_at_start = self.num_timesteps
+
+ # Avoid resetting the environment when calling ``.learn()`` consecutive times
+ if reset_num_timesteps or self._last_obs is None:
+ self._last_obs = self.env.reset() # pytype: disable=annotation-type-mismatch
+ self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
+ # Retrieve unnormalized observation for saving into the buffer
+ if self._vec_normalize_env is not None:
+ self._last_original_obs = self._vec_normalize_env.get_original_obs()
+
+ if eval_env is not None and self.seed is not None:
+ eval_env.seed(self.seed)
+
+ eval_env = self._get_eval_env(eval_env)
+
+ # Configure logger's outputs if no logger was passed
+ if not self._custom_logger:
+ self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
+
+ # Create eval callback if needed
+ callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path)
+
+ return total_timesteps, callback
+
+ def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
+ """
+ Retrieve reward, episode length, episode success and update the buffer
+ if using Monitor wrapper or a GoalEnv.
+
+ :param infos: List of additional information about the transition.
+ :param dones: Termination signals
+ """
+ if dones is None:
+ dones = np.array([False] * len(infos))
+ for idx, info in enumerate(infos):
+ maybe_ep_info = info.get("episode")
+ maybe_is_success = info.get("is_success")
+ if maybe_ep_info is not None:
+ self.ep_info_buffer.extend([maybe_ep_info])
+ if maybe_is_success is not None and dones[idx]:
+ self.ep_success_buffer.append(maybe_is_success)
+
+ def get_env(self) -> Optional[VecEnv]:
+ """
+ Returns the current environment (can be None if not defined).
+
+ :return: The current environment
+ """
+ return self.env
+
+ def get_vec_normalize_env(self) -> Optional[VecNormalize]:
+ """
+ Return the ``VecNormalize`` wrapper of the training env
+ if it exists.
+
+ :return: The ``VecNormalize`` env.
+ """
+ return self._vec_normalize_env
+
+ def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
+ """
+ Checks the validity of the environment, and if it is coherent, set it as the current environment.
+ Furthermore, wrap any non-vectorized env into a vectorized one.
+ Checked parameters:
+ - observation_space
+ - action_space
+
+ :param env: The environment for learning a policy
+ :param force_reset: Force call to ``reset()`` before training
+ to avoid unexpected behavior.
+ See issue https://github.com/DLR-RM/stable-baselines3/issues/597
+ """
+ # if it is not a VecEnv, make it a VecEnv
+ # and do other transformations (dict obs, image transpose) if needed
+ env = self._wrap_env(env, self.verbose)
+ # Check that the observation spaces match
+ check_for_correct_spaces(env, self.observation_space, self.action_space)
+ # Update VecNormalize object
+ # otherwise the wrong env may be used, see https://github.com/DLR-RM/stable-baselines3/issues/637
+ self._vec_normalize_env = unwrap_vec_normalize(env)
+
+ # Discard `_last_obs`, this will force the env to reset before training
+ # See issue https://github.com/DLR-RM/stable-baselines3/issues/597
+ if force_reset:
+ self._last_obs = None
+
+ self.n_envs = env.num_envs
+ self.env = env
+
+ @abstractmethod
+ def learn(
+ self,
+ total_timesteps: int,
+ callback: MaybeCallback = None,
+ log_interval: int = 100,
+ tb_log_name: str = "run",
+ eval_env: Optional[GymEnv] = None,
+ eval_freq: int = -1,
+ n_eval_episodes: int = 5,
+ eval_log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+ ) -> "BaseAlgorithm":
+ """
+ Return a trained model.
+
+ :param total_timesteps: The total number of samples (env steps) to train on
+ :param callback: callback(s) called at every step with state of the algorithm.
+ :param log_interval: The number of timesteps before logging.
+ :param tb_log_name: the name of the run for TensorBoard logging
+ :param eval_env: Environment that will be used to evaluate the agent
+ :param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
+ :param n_eval_episodes: Number of episode to evaluate the agent
+ :param eval_log_path: Path to a folder where the evaluations will be saved
+ :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
+ :return: the trained model
+ """
+
+ def predict(
+ self,
+ observation: np.ndarray,
+ state: Optional[Tuple[np.ndarray, ...]] = None,
+ episode_start: Optional[np.ndarray] = None,
+ deterministic: bool = False,
+ ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+ """
+ Get the policy action from an observation (and optional hidden state).
+ Includes sugar-coating to handle different observations (e.g. normalizing images).
+
+ :param observation: the input observation
+ :param state: The last hidden states (can be None, used in recurrent policies)
+ :param episode_start: The last masks (can be None, used in recurrent policies)
+ this corresponds to the beginning of episodes,
+ where the hidden states of the RNN must be reset.
+ :param deterministic: Whether or not to return deterministic actions.
+ :return: the model's action and the next hidden state
+ (used in recurrent policies)
+ """
+ return self.policy.predict(observation, state, episode_start, deterministic)
+
+ def set_random_seed(self, seed: Optional[int] = None) -> None:
+ """
+ Set the seed of the pseudo-random generators
+ (python, numpy, pytorch, gym, action_space)
+
+ :param seed:
+ """
+ if seed is None:
+ return
+ set_random_seed(seed, using_cuda=self.device.type == th.device("cuda").type)
+ self.action_space.seed(seed)
+ if self.env is not None:
+ self.env.seed(seed)
+ if self.eval_env is not None:
+ self.eval_env.seed(seed)
+
+ def set_parameters(
+ self,
+ load_path_or_dict: Union[str, Dict[str, Dict]],
+ exact_match: bool = True,
+ device: Union[th.device, str] = "auto",
+ ) -> None:
+ """
+ Load parameters from a given zip-file or a nested dictionary containing parameters for
+ different modules (see ``get_parameters``).
+
+ :param load_path_or_dict: Location of the saved data (path or file-like, see ``save``), or a nested
+ dictionary containing nn.Module parameters used by the policy. The dictionary maps
+ object names to a state-dictionary returned by ``torch.nn.Module.state_dict()``.
+ :param exact_match: If True, the given parameters should include parameters for each
+ module and each of their parameters, otherwise raises an Exception. If set to False, this
+ can be used to update only specific parameters.
+ :param device: Device on which the code should run.
+ """
+ params = None
+ if isinstance(load_path_or_dict, dict):
+ params = load_path_or_dict
+ else:
+ _, params, _ = load_from_zip_file(load_path_or_dict, device=device)
+
+ # Keep track which objects were updated.
+ # `_get_torch_save_params` returns [params, other_pytorch_variables].
+ # We are only interested in former here.
+ objects_needing_update = set(self._get_torch_save_params()[0])
+ updated_objects = set()
+
+ for name in params:
+ attr = None
+ try:
+ attr = recursive_getattr(self, name)
+ except Exception:
+ # What errors recursive_getattr could throw? KeyError, but
+ # possible something else too (e.g. if key is an int?).
+ # Catch anything for now.
+ raise ValueError(f"Key {name} is an invalid object name.")
+
+ if isinstance(attr, th.optim.Optimizer):
+ # Optimizers do not support "strict" keyword...
+ # Seems like they will just replace the whole
+ # optimizer state with the given one.
+ # On top of this, optimizer state-dict
+ # seems to change (e.g. first ``optim.step()``),
+ # which makes comparing state dictionary keys
+ # invalid (there is also a nesting of dictionaries
+ # with lists with dictionaries with ...), adding to the
+ # mess.
+ #
+ # TL;DR: We might not be able to reliably say
+ # if given state-dict is missing keys.
+ #
+ # Solution: Just load the state-dict as is, and trust
+ # the user has provided a sensible state dictionary.
+ attr.load_state_dict(params[name])
+ else:
+ # Assume attr is th.nn.Module
+ attr.load_state_dict(params[name], strict=exact_match)
+ updated_objects.add(name)
+
+ if exact_match and updated_objects != objects_needing_update:
+ raise ValueError(
+ "Names of parameters do not match agents' parameters: "
+ f"expected {objects_needing_update}, got {updated_objects}"
+ )
+
+ @classmethod
+ def load(
+ cls,
+ path: Union[str, pathlib.Path, io.BufferedIOBase],
+ env: Optional[GymEnv] = None,
+ device: Union[th.device, str] = "auto",
+ custom_objects: Optional[Dict[str, Any]] = None,
+ print_system_info: bool = False,
+ force_reset: bool = True,
+ check_obs_space: bool = True,
+ force_load: bool = False,
+ **kwargs,
+ ) -> "BaseAlgorithm":
+ """
+ Load the model from a zip-file.
+ Warning: ``load`` re-creates the model from scratch; it does not update it in-place!
+ For an in-place load use ``set_parameters`` instead.
+
+ :param path: path to the file (or a file-like) where to
+ load the agent from
+ :param env: the new environment to run the loaded model on
+ (can be None if you only need prediction from a trained model) has priority over any saved environment
+ :param device: Device on which the code should run.
+ :param custom_objects: Dictionary of objects to replace
+ upon loading. If a variable is present in this dictionary as a
+ key, it will not be deserialized and the corresponding item
+ will be used instead. Similar to custom_objects in
+ ``keras.models.load_model``. Useful when you have an object in
+ file that can not be deserialized.
+ :param print_system_info: Whether to print system info from the saved model
+ and the current system info (useful to debug loading issues)
+ :param force_reset: Force call to ``reset()`` before training
+ to avoid unexpected behavior.
+ See https://github.com/DLR-RM/stable-baselines3/issues/597
+ :param kwargs: extra arguments to change the model when loading
+ :return: new model instance with loaded parameters
+ """
+ if print_system_info:
+ print("== CURRENT SYSTEM INFO ==")
+ get_system_info()
+
+ data, params, pytorch_variables = load_from_zip_file(
+ path, device=device, custom_objects=custom_objects, print_system_info=print_system_info
+ )
+
+ # Remove stored device information and replace with ours
+ if "policy_kwargs" in data:
+ if "device" in data["policy_kwargs"]:
+ del data["policy_kwargs"]["device"]
+
+ if not force_load:
+ if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
+ raise ValueError(
+ f"The specified policy kwargs do not equal the stored policy kwargs."
+ f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
+ )
+
+ if "observation_space" not in data or "action_space" not in data:
+ raise KeyError("The observation_space and action_space were not given, can't verify new environments")
+
+ if env is not None:
+ # Wrap first if needed
+ env = cls._wrap_env(env, data["verbose"])
+ # Check if given env is valid
+ if check_obs_space:
+ check_for_correct_spaces(env, data["observation_space"], data["action_space"])
+ # Discard `_last_obs`, this will force the env to reset before training
+ # See issue https://github.com/DLR-RM/stable-baselines3/issues/597
+ if force_reset and data is not None:
+ data["_last_obs"] = None
+ else:
+ # Use stored env, if one exists. If not, continue as is (can be used for predict)
+ if "env" in data:
+ env = data["env"]
+
+ # noinspection PyArgumentList
+ model = cls( # pytype: disable=not-instantiable,wrong-keyword-args
+ policy=data["policy_class"],
+ env=env,
+ device=device,
+ _init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args
+ )
+
+ # load parameters
+ if not force_load:
+ model.__dict__.update(data) # TODO: Helin cancelled this
+ model.__dict__.update(kwargs)
+ model._setup_model()
+
+ # put state_dicts back in place
+ model.set_parameters(params, exact_match=True, device=device)
+
+ # put other pytorch variables back in place
+ if pytorch_variables is not None:
+ for name in pytorch_variables:
+ # Skip if PyTorch variable was not defined (to ensure backward compatibility).
+ # This happens when using SAC/TQC.
+ # SAC has an entropy coefficient which can be fixed or optimized.
+ # If it is optimized, an additional PyTorch variable `log_ent_coef` is defined,
+ # otherwise it is initialized to `None`.
+ if pytorch_variables[name] is None:
+ continue
+ # Set the data attribute directly to avoid issue when using optimizers
+ # See https://github.com/DLR-RM/stable-baselines3/issues/391
+ recursive_setattr(model, name + ".data", pytorch_variables[name].data)
+
+ # Sample gSDE exploration matrix, so it uses the right device
+ # see issue #44
+ if model.use_sde:
+ model.policy.reset_noise() # pytype: disable=attribute-error
+ return model
+
+ def get_parameters(self) -> Dict[str, Dict]:
+ """
+ Return the parameters of the agent. This includes parameters from different networks, e.g.
+ critics (value functions) and policies (pi functions).
+
+ :return: Mapping from names of the objects to PyTorch state-dicts.
+ """
+ state_dicts_names, _ = self._get_torch_save_params()
+ params = {}
+ for name in state_dicts_names:
+ attr = recursive_getattr(self, name)
+ # Retrieve state dict
+ params[name] = attr.state_dict()
+ return params
+
+ def save(
+ self,
+ path: Union[str, pathlib.Path, io.BufferedIOBase],
+ exclude: Optional[Iterable[str]] = None,
+ include: Optional[Iterable[str]] = None,
+ ) -> None:
+ """
+ Save all the attributes of the object and the model parameters in a zip-file.
+
+ :param path: path to the file where the rl agent should be saved
+ :param exclude: name of parameters that should be excluded in addition to the default ones
+ :param include: name of parameters that might be excluded but should be included anyway
+ """
+ # Copy parameter list so we don't mutate the original dict
+ data = self.__dict__.copy()
+
+ # Exclude is union of specified parameters (if any) and standard exclusions
+ if exclude is None:
+ exclude = []
+ exclude = set(exclude).union(self._excluded_save_params())
+
+ # Do not exclude params if they are specifically included
+ if include is not None:
+ exclude = exclude.difference(include)
+
+ state_dicts_names, torch_variable_names = self._get_torch_save_params()
+ all_pytorch_variables = state_dicts_names + torch_variable_names
+ for torch_var in all_pytorch_variables:
+ # We need to get only the name of the top most module as we'll remove that
+ var_name = torch_var.split(".")[0]
+ # Any params that are in the save vars must not be saved by data
+ exclude.add(var_name)
+
+ # Remove parameter entries of parameters which are to be excluded
+ for param_name in exclude:
+ data.pop(param_name, None)
+
+ # Build dict of torch variables
+ pytorch_variables = None
+ if torch_variable_names is not None:
+ pytorch_variables = {}
+ for name in torch_variable_names:
+ attr = recursive_getattr(self, name)
+ pytorch_variables[name] = attr
+
+ # Build dict of state_dicts
+ params_to_save = self.get_parameters()
+
+ save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)
diff --git a/dexart-release/stable_baselines3/common/buffers.py b/dexart-release/stable_baselines3/common/buffers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff420f04a986cf7178d38d3b493125db9a3764e6
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/buffers.py
@@ -0,0 +1,1010 @@
+import warnings
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Generator, List, Optional, Union
+import stable_baselines3.pickle_utils as pickle_utils
+import numpy as np
+import torch as th
+from gym import spaces
+
+from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
+from stable_baselines3.common.type_aliases import (
+ DictReplayBufferSamples,
+ DictRolloutBufferSamples,
+ DictSSLRolloutBufferSamples,
+ ReplayBufferSamples,
+ RolloutBufferSamples,
+)
+
+from stable_baselines3.common.vec_env import VecNormalize
+
+try:
+ # Check memory used by replay buffer when possible
+ import psutil
+except ImportError:
+ psutil = None
+
+
+class BaseBuffer(ABC):
+ """
+ Base class that represents a buffer (rollout or replay)
+
+ :param buffer_size: Max number of elements in the buffer
+ :param observation_space: Observation space
+ :param action_space: Action space
+ :param device: PyTorch device
+ to which the values will be converted
+ :param n_envs: Number of parallel environments
+ """
+
+ def __init__(
+ self,
+ buffer_size: int,
+ observation_space: spaces.Space,
+ action_space: spaces.Space,
+ device: Union[th.device, str] = "cpu",
+ n_envs: int = 1,
+ ):
+ super().__init__()
+ self.buffer_size = buffer_size
+ self.observation_space = observation_space
+ self.action_space = action_space
+ self.obs_shape = get_obs_shape(observation_space)
+
+ self.action_dim = get_action_dim(action_space)
+ self.pos = 0
+ self.full = False
+ self.device = device
+ self.n_envs = n_envs
+
+ @staticmethod
+ def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
+ """
+ Swap and then flatten axes 0 (buffer_size) and 1 (n_envs)
+ to convert shape from [n_steps, n_envs, ...] (where ... is the shape of the features)
+ to [n_steps * n_envs, ...] (which maintains the order)
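+ Example: an array of shape (5, 2, 3) becomes (10, 3).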
+
+ :param arr:
+ :return:
+ """
+ shape = arr.shape
+ if len(shape) < 3:
+ shape = shape + (1,)
+ return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])
+
+ def size(self) -> int:
+ """
+ :return: The current size of the buffer
+ """
+ if self.full:
+ return self.buffer_size
+ return self.pos
+
+ def add(self, *args, **kwargs) -> None:
+ """
+ Add elements to the buffer.
+ """
+ raise NotImplementedError()
+
+ def extend(self, *args, **kwargs) -> None:
+ """
+ Add a new batch of transitions to the buffer
+ """
+ # Do a for loop along the batch axis
+ for data in zip(*args):
+ self.add(*data)
+
+ def reset(self) -> None:
+ """
+ Reset the buffer.
+ """
+ self.pos = 0
+ self.full = False
+
+ def sample(self, batch_size: int, env: Optional[VecNormalize] = None):
+ """
+ :param batch_size: Number of elements to sample
+ :param env: associated gym VecEnv
+ to normalize the observations/rewards when sampling
+ :return:
+ """
+ upper_bound = self.buffer_size if self.full else self.pos
+ batch_inds = np.random.randint(0, upper_bound, size=batch_size)
+ return self._get_samples(batch_inds, env=env)
+
+ @abstractmethod
+ def _get_samples(
+ self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
+ ) -> Union[ReplayBufferSamples, RolloutBufferSamples]:
+ """
+ :param batch_inds:
+ :param env:
+ :return:
+ """
+ raise NotImplementedError()
+
+ def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
+ """
+ Convert a numpy array to a PyTorch tensor.
+ Note: it copies the data by default
+
+ :param array:
+ :param copy: Whether to copy or not the data
+ (may be useful to avoid changing things by reference)
+ :return:
+ """
+ if copy:
+ return th.tensor(array).to(self.device)
+ return th.as_tensor(array).to(self.device)
+
+ @staticmethod
+ def _normalize_obs(
+ obs: Union[np.ndarray, Dict[str, np.ndarray]],
+ env: Optional[VecNormalize] = None,
+ ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+ if env is not None:
+ return env.normalize_obs(obs)
+ return obs
+
+ @staticmethod
+ def _normalize_reward(reward: np.ndarray, env: Optional[VecNormalize] = None) -> np.ndarray:
+ if env is not None:
+ return env.normalize_reward(reward).astype(np.float32)
+ return reward
+
+
+class ExpertBuffer(BaseBuffer):
+ def __init__(self,
+ buffer_size: int,
+ observation_space: spaces.Space,
+ action_space: spaces.Space,
+ device: Union[th.device, str] = "cpu",
+ n_envs: int = 1,
+ dataset_path=''
+ ):
+ super(ExpertBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
+ data = pickle_utils.load_data(dataset_path)
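+ # `data` is expected to be a list of trajectory dicts, each holding
+ # 'observations' and 'actions' arrays of equal length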
+
+ data_obs = []
+ data_action = []
+
+ self.optimize_memory_usage = False
+
+ for trajectory in data:
+ print(trajectory.keys())
+ data_obs.append(trajectory['observations'])
+ data_action.append(trajectory['actions'])
+ self.observations = np.concatenate(data_obs, axis=0)
+ self.actions = np.concatenate(data_action, axis=0)
+
+ assert len(self.observations) == len(self.actions), "Demo Dataset Error: Obs num does not match Action num."
+ print('Expert buffer info:', self.observations.shape, self.actions.shape)
+ self.buffer_size = len(self.observations)
+ self.full = True
+
+ def add(
+ self,
+ obs: np.ndarray,
+ next_obs: np.ndarray,
+ action: np.ndarray,
+ reward: np.ndarray,
+ done: np.ndarray,
+ infos: List[Dict[str, Any]],
+ ) -> None:
+ assert False, "We do not expect user to use this method."
+
+ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
+ # Sample randomly the env idx
+ data = (
+ self._normalize_obs(self.observations[batch_inds, :], env),
+ self.actions[batch_inds, :],
+ )
+ return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
+
+ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
+ """
+ Sample elements from the replay buffer.
+ Custom sampling when using memory efficient variant,
+ as we should not sample the element with index `self.pos`
+ See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
+
+ :param batch_size: Number of elements to sample
+ :param env: associated gym VecEnv
+ to normalize the observations/rewards when sampling
+ :return:
+ """
+ if not self.optimize_memory_usage:
+ return super().sample(batch_size=batch_size, env=env)
+ # Do not sample the element with index `self.pos` as the transition is invalid
+ # (we use only one array to store `obs` and `next_obs`)
+ if self.full:
+ batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
+ else:
+ batch_inds = np.random.randint(0, self.pos, size=batch_size)
+ return self._get_samples(batch_inds, env=env)
+
+ def get_all_samples(self, env=None):
+ data = (
+ self._normalize_obs(self.observations, env),
+ self.actions,
+ )
+ return ReplayBufferSamples(*tuple(map(self.to_torch, data)), None, None, None)
+
+
+class ReplayBuffer(BaseBuffer):
+ """
+ Replay buffer used in off-policy algorithms like SAC/TD3.
+
+ :param buffer_size: Max number of elements in the buffer
+ :param observation_space: Observation space
+ :param action_space: Action space
+ :param device:
+ :param n_envs: Number of parallel environments
+ :param optimize_memory_usage: Enable a memory efficient variant
+ of the replay buffer which reduces the memory used by almost a factor of two,
+ at a cost of more complexity.
+ See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
+ and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
+ :param handle_timeout_termination: Handle timeout termination (due to timelimit)
+ separately and treat the task as an infinite-horizon task.
+ https://github.com/DLR-RM/stable-baselines3/issues/284
+ """
+
+ def __init__(
+ self,
+ buffer_size: int,
+ observation_space: spaces.Space,
+ action_space: spaces.Space,
+ device: Union[th.device, str] = "cpu",
+ n_envs: int = 1,
+ optimize_memory_usage: bool = False,
+ handle_timeout_termination: bool = True,
+ ):
+ super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
+
+ # Adjust buffer size
+ self.buffer_size = max(buffer_size // n_envs, 1)
+
+ # Check that the replay buffer can fit into the memory
+ if psutil is not None:
+ mem_available = psutil.virtual_memory().available
+
+ self.optimize_memory_usage = optimize_memory_usage
+
+ self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)
+
+ if optimize_memory_usage:
+ # `observations` contains also the next observation
+ self.next_observations = None
+ else:
+ self.next_observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)
+
+ self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
+
+ self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ # Handle timeouts termination properly if needed
+ # see https://github.com/DLR-RM/stable-baselines3/issues/284
+ self.handle_timeout_termination = handle_timeout_termination
+ self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+
+ if psutil is not None:
+ total_memory_usage = self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
+
+ if self.next_observations is not None:
+ total_memory_usage += self.next_observations.nbytes
+
+ if total_memory_usage > mem_available:
+ # Convert to GB
+ total_memory_usage /= 1e9
+ mem_available /= 1e9
+ warnings.warn(
+ "This system does not have apparently enough memory to store the complete "
+ f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
+ )
+
+ def add(
+ self,
+ obs: np.ndarray,
+ next_obs: np.ndarray,
+ action: np.ndarray,
+ reward: np.ndarray,
+ done: np.ndarray,
+ infos: List[Dict[str, Any]],
+ ) -> None:
+
+ # Reshape needed when using multiple envs with discrete observations
+ # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
+ if isinstance(self.observation_space, spaces.Discrete):
+ obs = obs.reshape((self.n_envs,) + self.obs_shape)
+ next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape)
+
+ # Same, for actions
+ if isinstance(self.action_space, spaces.Discrete):
+ action = action.reshape((self.n_envs, self.action_dim))
+
+ # Copy to avoid modification by reference
+ self.observations[self.pos] = np.array(obs).copy()
+
+ if self.optimize_memory_usage:
+ self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs).copy()
+ else:
+ self.next_observations[self.pos] = np.array(next_obs).copy()
+
+ self.actions[self.pos] = np.array(action).copy()
+ self.rewards[self.pos] = np.array(reward).copy()
+ self.dones[self.pos] = np.array(done).copy()
+
+ if self.handle_timeout_termination:
+ self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])
+
+ self.pos += 1
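+ # ring-buffer semantics: once full, wrap around and overwrite the oldest entries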
+ if self.pos == self.buffer_size:
+ self.full = True
+ self.pos = 0
+
+ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
+ """
+ Sample elements from the replay buffer.
+ Custom sampling when using memory efficient variant,
+ as we should not sample the element with index `self.pos`
+ See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
+
+ :param batch_size: Number of elements to sample
+ :param env: associated gym VecEnv
+ to normalize the observations/rewards when sampling
+ :return:
+ """
+ if not self.optimize_memory_usage:
+ return super().sample(batch_size=batch_size, env=env)
+ # Do not sample the element with index `self.pos` as the transition is invalid
+ # (we use only one array to store `obs` and `next_obs`)
+ if self.full:
+ batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
+ else:
+ batch_inds = np.random.randint(0, self.pos, size=batch_size)
+ return self._get_samples(batch_inds, env=env)
+
+ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
+ # Sample randomly the env idx
+ env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))
+
+ if self.optimize_memory_usage:
+ next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :], env)
+ else:
+ next_obs = self._normalize_obs(self.next_observations[batch_inds, env_indices, :], env)
+
+ data = (
+ self._normalize_obs(self.observations[batch_inds, env_indices, :], env),
+ self.actions[batch_inds, env_indices, :],
+ next_obs,
+ # Only use dones that are not due to timeouts
+ # deactivated by default (timeouts is initialized as an array of False)
+ (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1),
+ self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env),
+ )
+ return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
+
+
+class RolloutBuffer(BaseBuffer):
+ """
+ Rollout buffer used in on-policy algorithms like A2C/PPO.
+ It corresponds to ``buffer_size`` transitions collected
+ using the current policy.
+ This experience will be discarded after the policy update.
+ In order to use PPO objective, we also store the current value of each state
+ and the log probability of each taken action.
+
+    The term rollout here refers to the model-free notion and should not
+    be confused with the concept of rollout as used in model-based RL or planning.
+    Hence, it is only involved in policy and value function training but not action selection.
+
+    :param buffer_size: Max number of elements in the buffer
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param device: PyTorch device
+    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
+        Equivalent to Monte-Carlo advantage estimate when set to 1.
+ :param gamma: Discount factor
+ :param n_envs: Number of parallel environments
+ """
+
+ def __init__(
+ self,
+ buffer_size: int,
+ observation_space: spaces.Space,
+ action_space: spaces.Space,
+ device: Union[th.device, str] = "cpu",
+ gae_lambda: float = 1,
+ gamma: float = 0.99,
+ n_envs: int = 1,
+ ):
+
+ super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
+ self.gae_lambda = gae_lambda
+ self.gamma = gamma
+ self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
+ self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
+ self.generator_ready = False
+ self.reset()
+
+ def reset(self) -> None:
+
+ self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
+ self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
+ self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.generator_ready = False
+ super().reset()
+
+ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
+ """
+ Post-processing step: compute the lambda-return (TD(lambda) estimate)
+ and GAE(lambda) advantage.
+
+ Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
+        to compute the advantage. To obtain the Monte-Carlo advantage estimate (A(s) = R - V(s)),
+        where R is the sum of discounted rewards with value bootstrap
+        (because we don't always have a full episode), set ``gae_lambda=1.0`` during initialization.
+
+ The TD(lambda) estimator has also two special cases:
+ - TD(1) is Monte-Carlo estimate (sum of discounted rewards)
+ - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))
+
+ For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.
+
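+        The recursion implemented below, written out (a restatement of the code, where
+        :math:`d_{t+1}` is the episode-start/done flag for the next step):
+
+        .. math::
+
+            \delta_t = r_t + \gamma V(s_{t+1})(1 - d_{t+1}) - V(s_t)
+
+            \hat{A}_t = \delta_t + \gamma \lambda (1 - d_{t+1}) \hat{A}_{t+1},
+            \qquad R_t = \hat{A}_t + V(s_t)
+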
+ :param last_values: state value estimation for the last step (one for each env)
+ :param dones: if the last step was a terminal step (one bool for each env).
+ """
+ # Convert to numpy
+ last_values = last_values.clone().cpu().numpy().flatten()
+
+ last_gae_lam = 0
+ for step in reversed(range(self.buffer_size)):
+ if step == self.buffer_size - 1:
+ next_non_terminal = 1.0 - dones
+ next_values = last_values
+ else:
+ next_non_terminal = 1.0 - self.episode_starts[step + 1]
+ next_values = self.values[step + 1]
+ delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
+ last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
+ self.advantages[step] = last_gae_lam
+ # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
+ # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
+ self.returns = self.advantages + self.values
+
+ def add(
+ self,
+ obs: np.ndarray,
+ action: np.ndarray,
+ reward: np.ndarray,
+ episode_start: np.ndarray,
+ value: th.Tensor,
+ log_prob: th.Tensor,
+ ) -> None:
+ """
+ :param obs: Observation
+ :param action: Action
+        :param reward: Reward received at this step
+ :param episode_start: Start of episode signal.
+ :param value: estimated value of the current state
+ following the current policy.
+ :param log_prob: log probability of the action
+ following the current policy.
+ """
+ if len(log_prob.shape) == 0:
+ # Reshape 0-d tensor to avoid error
+ log_prob = log_prob.reshape(-1, 1)
+
+ # Reshape needed when using multiple envs with discrete observations
+ # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
+ if isinstance(self.observation_space, spaces.Discrete):
+ obs = obs.reshape((self.n_envs,) + self.obs_shape)
+
+ self.observations[self.pos] = np.array(obs).copy()
+ self.actions[self.pos] = np.array(action).copy()
+ self.rewards[self.pos] = np.array(reward).copy()
+ self.episode_starts[self.pos] = np.array(episode_start).copy()
+ self.values[self.pos] = value.clone().cpu().numpy().flatten()
+ self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
+ self.pos += 1
+ if self.pos == self.buffer_size:
+ self.full = True
+
+ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
+ assert self.full, ""
+ indices = np.random.permutation(self.buffer_size * self.n_envs)
+ # Prepare the data
+ if not self.generator_ready:
+
+ _tensor_names = [
+ "observations",
+ "actions",
+ "values",
+ "log_probs",
+ "advantages",
+ "returns",
+ ]
+
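+            # swap_and_flatten: (buffer_size, n_envs, ...) -> (buffer_size * n_envs, ...)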
+ for tensor in _tensor_names:
+ self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
+ self.generator_ready = True
+
+ # Return everything, don't create minibatches
+ if batch_size is None:
+ batch_size = self.buffer_size * self.n_envs
+
+ start_idx = 0
+ while start_idx < self.buffer_size * self.n_envs:
+ yield self._get_samples(indices[start_idx : start_idx + batch_size])
+ start_idx += batch_size
+
+ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
+ data = (
+ self.observations[batch_inds],
+ self.actions[batch_inds],
+ self.values[batch_inds].flatten(),
+ self.log_probs[batch_inds].flatten(),
+ self.advantages[batch_inds].flatten(),
+ self.returns[batch_inds].flatten(),
+ )
+ return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
+
+
+class DictReplayBuffer(ReplayBuffer):
+ """
+ Dict Replay buffer used in off-policy algorithms like SAC/TD3.
+ Extends the ReplayBuffer to use dictionary observations
+
+    :param buffer_size: Max number of elements in the buffer
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param device: PyTorch device
+ :param n_envs: Number of parallel environments
+ :param optimize_memory_usage: Enable a memory efficient variant
+ Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
+ :param handle_timeout_termination: Handle timeout termination (due to timelimit)
+ separately and treat the task as infinite horizon task.
+ https://github.com/DLR-RM/stable-baselines3/issues/284
+ """
+
+ def __init__(
+ self,
+ buffer_size: int,
+ observation_space: spaces.Space,
+ action_space: spaces.Space,
+ device: Union[th.device, str] = "cpu",
+ n_envs: int = 1,
+ optimize_memory_usage: bool = False,
+ handle_timeout_termination: bool = True,
+ ):
+ super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
+
+ assert isinstance(self.obs_shape, dict), "DictReplayBuffer must be used with Dict obs space only"
+ self.buffer_size = max(buffer_size // n_envs, 1)
+
+ # Check that the replay buffer can fit into the memory
+ if psutil is not None:
+ mem_available = psutil.virtual_memory().available
+
+ assert optimize_memory_usage is False, "DictReplayBuffer does not support optimize_memory_usage"
+ # disabling as this adds quite a bit of complexity
+ # https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702
+ self.optimize_memory_usage = optimize_memory_usage
+
+ self.observations = {
+ key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape, dtype=observation_space[key].dtype)
+ for key, _obs_shape in self.obs_shape.items()
+ }
+ self.next_observations = {
+ key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape, dtype=observation_space[key].dtype)
+ for key, _obs_shape in self.obs_shape.items()
+ }
+
+ self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
+ self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+
+ # Handle timeouts termination properly if needed
+ # see https://github.com/DLR-RM/stable-baselines3/issues/284
+ self.handle_timeout_termination = handle_timeout_termination
+ self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+
+ if psutil is not None:
+ obs_nbytes = 0
+ for _, obs in self.observations.items():
+ obs_nbytes += obs.nbytes
+
+ total_memory_usage = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
+ if self.next_observations is not None:
+ next_obs_nbytes = 0
+                for _, obs in self.next_observations.items():
+ next_obs_nbytes += obs.nbytes
+ total_memory_usage += next_obs_nbytes
+
+ if total_memory_usage > mem_available:
+ # Convert to GB
+ total_memory_usage /= 1e9
+ mem_available /= 1e9
+ warnings.warn(
+                    "This system does not appear to have enough memory to store the complete "
+                    f"replay buffer: {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
+ )
+
+ def add(
+ self,
+ obs: Dict[str, np.ndarray],
+ next_obs: Dict[str, np.ndarray],
+ action: np.ndarray,
+ reward: np.ndarray,
+ done: np.ndarray,
+ infos: List[Dict[str, Any]],
+ ) -> None:
+ # Copy to avoid modification by reference
+ for key in self.observations.keys():
+ # Reshape needed when using multiple envs with discrete observations
+ # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
+ if isinstance(self.observation_space.spaces[key], spaces.Discrete):
+ obs[key] = obs[key].reshape((self.n_envs,) + self.obs_shape[key])
+ self.observations[key][self.pos] = np.array(obs[key])
+
+ for key in self.next_observations.keys():
+ if isinstance(self.observation_space.spaces[key], spaces.Discrete):
+ next_obs[key] = next_obs[key].reshape((self.n_envs,) + self.obs_shape[key])
+ self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()
+
+ # Same reshape, for actions
+ if isinstance(self.action_space, spaces.Discrete):
+ action = action.reshape((self.n_envs, self.action_dim))
+
+ self.actions[self.pos] = np.array(action).copy()
+ self.rewards[self.pos] = np.array(reward).copy()
+ self.dones[self.pos] = np.array(done).copy()
+
+ if self.handle_timeout_termination:
+ self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])
+
+ self.pos += 1
+ if self.pos == self.buffer_size:
+ self.full = True
+ self.pos = 0
+
+ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
+ """
+ Sample elements from the replay buffer.
+
+        :param batch_size: Number of elements to sample
+ :param env: associated gym VecEnv
+ to normalize the observations/rewards when sampling
+ :return:
+ """
+ return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env)
+
+ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples:
+ # Sample randomly the env idx
+ env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))
+
+ # Normalize if needed and remove extra dimension (we are using only one env for now)
+ obs_ = self._normalize_obs({key: obs[batch_inds, env_indices, :] for key, obs in self.observations.items()}, env)
+ next_obs_ = self._normalize_obs(
+ {key: obs[batch_inds, env_indices, :] for key, obs in self.next_observations.items()}, env
+ )
+
+ # Convert to torch tensor
+ observations = {key: self.to_torch(obs) for key, obs in obs_.items()}
+ next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()}
+
+ return DictReplayBufferSamples(
+ observations=observations,
+ actions=self.to_torch(self.actions[batch_inds, env_indices]),
+ next_observations=next_observations,
+ # Only use dones that are not due to timeouts
+ # deactivated by default (timeouts is initialized as an array of False)
+ dones=self.to_torch(self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(
+ -1, 1
+ ),
+ rewards=self.to_torch(self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env)),
+ )
+
+
+class DictRolloutBuffer(RolloutBuffer):
+ """
+ Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
+ Extends the RolloutBuffer to use dictionary observations
+
+ It corresponds to ``buffer_size`` transitions collected
+ using the current policy.
+ This experience will be discarded after the policy update.
+ In order to use PPO objective, we also store the current value of each state
+ and the log probability of each taken action.
+
+    The term rollout here refers to the model-free notion and should not
+    be confused with the concept of rollout as used in model-based RL or planning.
+    Hence, it is only involved in policy and value function training but not action selection.
+
+    :param buffer_size: Max number of elements in the buffer
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param device: PyTorch device
+ :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
+ Equivalent to Monte-Carlo advantage estimate when set to 1.
+ :param gamma: Discount factor
+ :param n_envs: Number of parallel environments
+ """
+
+ def __init__(
+ self,
+ buffer_size: int,
+ observation_space: spaces.Space,
+ action_space: spaces.Space,
+ device: Union[th.device, str] = "cpu",
+ gae_lambda: float = 1,
+ gamma: float = 0.99,
+ n_envs: int = 1,
+ ):
+
+ super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
+
+ assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
+
+ self.gae_lambda = gae_lambda
+ self.gamma = gamma
+ self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
+ self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
+ self.generator_ready = False
+ self.reset()
+
+ def reset(self) -> None:
+ assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
+ self.observations = {}
+ for key, obs_input_shape in self.obs_shape.items():
+ self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32)
+ self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
+ self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.generator_ready = False
+ super(RolloutBuffer, self).reset()
+
+ def add(
+ self,
+ obs: Dict[str, np.ndarray],
+ action: np.ndarray,
+ reward: np.ndarray,
+ episode_start: np.ndarray,
+ value: th.Tensor,
+ log_prob: th.Tensor,
+ ) -> None:
+ """
+ :param obs: Observation
+ :param action: Action
+        :param reward: Reward received at this step
+ :param episode_start: Start of episode signal.
+ :param value: estimated value of the current state
+ following the current policy.
+ :param log_prob: log probability of the action
+ following the current policy.
+ """
+ if len(log_prob.shape) == 0:
+ # Reshape 0-d tensor to avoid error
+ log_prob = log_prob.reshape(-1, 1)
+
+ for key in self.observations.keys():
+ obs_ = np.array(obs[key]).copy()
+ # Reshape needed when using multiple envs with discrete observations
+ # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
+ if isinstance(self.observation_space.spaces[key], spaces.Discrete):
+ obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
+ self.observations[key][self.pos] = obs_
+
+ self.actions[self.pos] = np.array(action).copy()
+ self.rewards[self.pos] = np.array(reward).copy()
+ self.episode_starts[self.pos] = np.array(episode_start).copy()
+ self.values[self.pos] = value.clone().cpu().numpy().flatten()
+ self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
+ self.pos += 1
+ if self.pos == self.buffer_size:
+ self.full = True
+
+ def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]:
+ assert self.full, ""
+ indices = np.random.permutation(self.buffer_size * self.n_envs)
+ # Prepare the data
+ if not self.generator_ready:
+
+ for key, obs in self.observations.items():
+ self.observations[key] = self.swap_and_flatten(obs)
+
+ _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]
+
+ for tensor in _tensor_names:
+ self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
+ self.generator_ready = True
+
+ # Return everything, don't create minibatches
+ if batch_size is None:
+ batch_size = self.buffer_size * self.n_envs
+
+ start_idx = 0
+ while start_idx < self.buffer_size * self.n_envs:
+ yield self._get_samples(indices[start_idx : start_idx + batch_size])
+ start_idx += batch_size
+
+ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictRolloutBufferSamples:
+
+ return DictRolloutBufferSamples(
+ observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
+ actions=self.to_torch(self.actions[batch_inds]),
+ old_values=self.to_torch(self.values[batch_inds].flatten()),
+ old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
+ advantages=self.to_torch(self.advantages[batch_inds].flatten()),
+ returns=self.to_torch(self.returns[batch_inds].flatten()),
+ )
+
+
+class DictSSLRolloutBuffer(RolloutBuffer):
+ """
+    Dict rollout buffer with self-supervised learning (SSL) targets,
+    used in on-policy algorithms like A2C/PPO.
+    Extends the RolloutBuffer to use dictionary observations; in addition to the usual
+    rollout data, sampling also returns the observations and actions of the next few
+    steps, which can serve as prediction targets for an SSL auxiliary objective.
+
+ It corresponds to ``buffer_size`` transitions collected
+ using the current policy.
+ This experience will be discarded after the policy update.
+ In order to use PPO objective, we also store the current value of each state
+ and the log probability of each taken action.
+
+    The term rollout here refers to the model-free notion and should not
+    be confused with the concept of rollout as used in model-based RL or planning.
+    Hence, it is only involved in policy and value function training but not action selection.
+
+    :param buffer_size: Max number of elements in the buffer
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param device: PyTorch device
+ :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
+ Equivalent to Monte-Carlo advantage estimate when set to 1.
+ :param gamma: Discount factor
+ :param n_envs: Number of parallel environments
+ """
+
+ def __init__(
+ self,
+ buffer_size: int,
+ observation_space: spaces.Space,
+ action_space: spaces.Space,
+ device: Union[th.device, str] = "cpu",
+ gae_lambda: float = 1,
+ gamma: float = 0.99,
+ n_envs: int = 1,
+ ):
+
+ super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
+
+        assert isinstance(self.obs_shape, dict), "DictSSLRolloutBuffer must be used with Dict obs space only"
+
+ self.gae_lambda = gae_lambda
+ self.gamma = gamma
+ self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
+ self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
+ self.generator_ready = False
+ self.reset()
+
+ def reset(self) -> None:
+        assert isinstance(self.obs_shape, dict), "DictSSLRolloutBuffer must be used with Dict obs space only"
+ self.observations = {}
+ for key, obs_input_shape in self.obs_shape.items():
+ self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32)
+ self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
+ self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.generator_ready = False
+ super(RolloutBuffer, self).reset()
+
+ def add(
+ self,
+ obs: Dict[str, np.ndarray],
+ action: np.ndarray,
+ reward: np.ndarray,
+ episode_start: np.ndarray,
+ value: th.Tensor,
+ log_prob: th.Tensor,
+ ) -> None:
+ """
+ :param obs: Observation
+ :param action: Action
+        :param reward: Reward received at this step
+ :param episode_start: Start of episode signal.
+ :param value: estimated value of the current state
+ following the current policy.
+ :param log_prob: log probability of the action
+ following the current policy.
+ """
+ if len(log_prob.shape) == 0:
+ # Reshape 0-d tensor to avoid error
+ log_prob = log_prob.reshape(-1, 1)
+
+ for key in self.observations.keys():
+ obs_ = np.array(obs[key]).copy()
+ # Reshape needed when using multiple envs with discrete observations
+ # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
+ if isinstance(self.observation_space.spaces[key], spaces.Discrete):
+ obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
+ self.observations[key][self.pos] = obs_
+
+ self.actions[self.pos] = np.array(action).copy()
+ self.rewards[self.pos] = np.array(reward).copy()
+ self.episode_starts[self.pos] = np.array(episode_start).copy()
+ self.values[self.pos] = value.clone().cpu().numpy().flatten()
+ self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
+ self.pos += 1
+ if self.pos == self.buffer_size:
+ self.full = True
+
+ def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]:
+ assert self.full, ""
+ indices = np.random.permutation(self.buffer_size * self.n_envs)
+ # Prepare the data
+ if not self.generator_ready:
+
+ for key, obs in self.observations.items():
+ self.observations[key] = self.swap_and_flatten(obs)
+
+ _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]
+
+ for tensor in _tensor_names:
+ self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
+ self.generator_ready = True
+
+ # Return everything, don't create minibatches
+ if batch_size is None:
+ batch_size = self.buffer_size * self.n_envs
+
+ start_idx = 0
+ while start_idx < self.buffer_size * self.n_envs:
+ yield self._get_samples(indices[start_idx : start_idx + batch_size])
+ start_idx += batch_size
+
+ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictSSLRolloutBufferSamples:
+
+        def get_next_obs(obs_dict, ind, i):
+            # For each key, gather the observation `i` steps ahead of `ind`,
+            # clamping at the end of the flattened buffer
+            result = {}
+            for key, obs in obs_dict.items():
+                future_batch_inds = np.clip(ind + i, 0, len(obs) - 1)
+                result[key] = self.to_torch(obs[future_batch_inds])
+            return result
+
+        def get_next_action(actions, ind, i):
+            # Same look-ahead, for actions
+            future_batch_inds = np.clip(ind + i, 0, len(actions) - 1)
+            return self.to_torch(actions[future_batch_inds])
+
+        # SSL targets: observations and actions 0 to 3 steps ahead (hard-coded horizon of 4)
+        next_observations = [get_next_obs(self.observations, batch_inds, i) for i in range(4)]
+        next_actions = [get_next_action(self.actions, batch_inds, i) for i in range(4)]
+
+ return DictSSLRolloutBufferSamples(
+ observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
+ next_observations=next_observations,
+ actions=self.to_torch(self.actions[batch_inds]),
+ next_actions=next_actions,
+ old_values=self.to_torch(self.values[batch_inds].flatten()),
+ old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
+ advantages=self.to_torch(self.advantages[batch_inds].flatten()),
+ returns=self.to_torch(self.returns[batch_inds].flatten()),
+ )
+
diff --git a/dexart-release/stable_baselines3/common/callbacks.py b/dexart-release/stable_baselines3/common/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5f297c7c7b11e00e4be1e03ac7ae4ede0b6ee45
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/callbacks.py
@@ -0,0 +1,602 @@
+import os
+import warnings
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import gym
+import numpy as np
+
+from stable_baselines3.common import base_class # pytype: disable=pyi-error
+from stable_baselines3.common.evaluation import evaluate_policy
+from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
+
+
+class BaseCallback(ABC):
+ """
+ Base class for callback.
+
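+    A minimal sketch of a custom callback (only ``_on_step`` must be implemented):
+
+    .. code-block:: python
+
+        class ProgressCallback(BaseCallback):
+            def _on_step(self) -> bool:
+                # Print progress every 1000 calls to ``env.step()``
+                if self.n_calls % 1000 == 0:
+                    print(f"{self.num_timesteps} timesteps so far")
+                return True
+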
+    :param verbose: Verbosity level.
+ """
+
+ def __init__(self, verbose: int = 0):
+ super().__init__()
+ # The RL model
+ self.model = None # type: Optional[base_class.BaseAlgorithm]
+ # An alias for self.model.get_env(), the environment used for training
+ self.training_env = None # type: Union[gym.Env, VecEnv, None]
+        # Number of times the callback was called
+ self.n_calls = 0 # type: int
+ # n_envs * n times env.step() was called
+ self.num_timesteps = 0 # type: int
+ self.verbose = verbose
+ self.locals: Dict[str, Any] = {}
+ self.globals: Dict[str, Any] = {}
+ self.logger = None
+ # Sometimes, for event callback, it is useful
+ # to have access to the parent object
+ self.parent = None # type: Optional[BaseCallback]
+
+ # Type hint as string to avoid circular import
+ def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
+ """
+ Initialize the callback by saving references to the
+ RL model and the training environment for convenience.
+ """
+ self.model = model
+ self.training_env = model.get_env()
+ self.logger = model.logger
+ self._init_callback()
+
+ def _init_callback(self) -> None:
+ pass
+
+ def on_training_start(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
+        # These are references and will be updated automatically
+ self.locals = locals_
+ self.globals = globals_
+ self._on_training_start()
+
+ def _on_training_start(self) -> None:
+ pass
+
+ def on_rollout_start(self) -> None:
+ self._on_rollout_start()
+
+ def _on_rollout_start(self) -> None:
+ pass
+
+ @abstractmethod
+ def _on_step(self) -> bool:
+ """
+ :return: If the callback returns False, training is aborted early.
+ """
+ return True
+
+ def on_step(self) -> bool:
+ """
+ This method will be called by the model after each call to ``env.step()``.
+
+        For a child callback (of an ``EventCallback``), this will be called
+ when the event is triggered.
+
+ :return: If the callback returns False, training is aborted early.
+ """
+ self.n_calls += 1
+ # timesteps start at zero
+ self.num_timesteps = self.model.num_timesteps
+
+ return self._on_step()
+
+ def on_training_end(self) -> None:
+ self._on_training_end()
+
+ def _on_training_end(self) -> None:
+ pass
+
+ def on_rollout_end(self) -> None:
+ self._on_rollout_end()
+
+ def _on_rollout_end(self) -> None:
+ pass
+
+ def update_locals(self, locals_: Dict[str, Any]) -> None:
+ """
+ Update the references to the local variables.
+
+ :param locals_: the local variables during rollout collection
+ """
+ self.locals.update(locals_)
+ self.update_child_locals(locals_)
+
+ def update_child_locals(self, locals_: Dict[str, Any]) -> None:
+ """
+ Update the references to the local variables on sub callbacks.
+
+ :param locals_: the local variables during rollout collection
+ """
+ pass
+
+
+class EventCallback(BaseCallback):
+ """
+ Base class for triggering callback on event.
+
+ :param callback: Callback that will be called
+ when an event is triggered.
+    :param verbose: Verbosity level.
+ """
+
+ def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
+ super().__init__(verbose=verbose)
+ self.callback = callback
+ # Give access to the parent
+ if callback is not None:
+ self.callback.parent = self
+
+ def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
+ super().init_callback(model)
+ if self.callback is not None:
+ self.callback.init_callback(self.model)
+
+ def _on_training_start(self) -> None:
+ if self.callback is not None:
+ self.callback.on_training_start(self.locals, self.globals)
+
+ def _on_event(self) -> bool:
+ if self.callback is not None:
+ return self.callback.on_step()
+ return True
+
+ def _on_step(self) -> bool:
+ return True
+
+ def update_child_locals(self, locals_: Dict[str, Any]) -> None:
+ """
+ Update the references to the local variables.
+
+ :param locals_: the local variables during rollout collection
+ """
+ if self.callback is not None:
+ self.callback.update_locals(locals_)
+
+
+class CallbackList(BaseCallback):
+ """
+ Class for chaining callbacks.
+
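+    A usage sketch, assuming ``checkpoint_callback`` and ``eval_callback`` are
+    constructed as in the examples of this module:
+
+    .. code-block:: python
+
+        callback = CallbackList([checkpoint_callback, eval_callback])
+        model.learn(total_timesteps=10_000, callback=callback)
+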
+ :param callbacks: A list of callbacks that will be called
+ sequentially.
+ """
+
+ def __init__(self, callbacks: List[BaseCallback]):
+ super().__init__()
+ assert isinstance(callbacks, list)
+ self.callbacks = callbacks
+
+ def _init_callback(self) -> None:
+ for callback in self.callbacks:
+ callback.init_callback(self.model)
+
+ def _on_training_start(self) -> None:
+ for callback in self.callbacks:
+ callback.on_training_start(self.locals, self.globals)
+
+ def _on_rollout_start(self) -> None:
+ for callback in self.callbacks:
+ callback.on_rollout_start()
+
+ def _on_step(self) -> bool:
+ continue_training = True
+ for callback in self.callbacks:
+ # Return False (stop training) if at least one callback returns False
+ continue_training = callback.on_step() and continue_training
+ return continue_training
+
+ def _on_rollout_end(self) -> None:
+ for callback in self.callbacks:
+ callback.on_rollout_end()
+
+ def _on_training_end(self) -> None:
+ for callback in self.callbacks:
+ callback.on_training_end()
+
+ def update_child_locals(self, locals_: Dict[str, Any]) -> None:
+ """
+ Update the references to the local variables.
+
+ :param locals_: the local variables during rollout collection
+ """
+ for callback in self.callbacks:
+ callback.update_locals(locals_)
+
+
+class CheckpointCallback(BaseCallback):
+ """
+ Callback for saving a model every ``save_freq`` calls
+ to ``env.step()``.
+
+ .. warning::
+
+ When using multiple environments, each call to ``env.step()``
+ will effectively correspond to ``n_envs`` steps.
+ To account for that, you can use ``save_freq = max(save_freq // n_envs, 1)``
+
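+    A usage sketch (``model`` is any algorithm instance from this library):
+
+    .. code-block:: python
+
+        checkpoint_callback = CheckpointCallback(save_freq=1000, save_path="./logs/")
+        model.learn(total_timesteps=10_000, callback=checkpoint_callback)
+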
+    :param save_freq: Save checkpoints every ``save_freq`` calls of the callback
+    :param save_path: Path to the folder where the model will be saved.
+    :param name_prefix: Common prefix to the saved models
+    :param verbose: Verbosity level.
+ """
+
+ def __init__(self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0):
+ super().__init__(verbose)
+ self.save_freq = save_freq
+ self.save_path = save_path
+ self.name_prefix = name_prefix
+
+ def _init_callback(self) -> None:
+ # Create folder if needed
+ if self.save_path is not None:
+ os.makedirs(self.save_path, exist_ok=True)
+
+ def _on_step(self) -> bool:
+ if self.n_calls % self.save_freq == 0:
+ path = os.path.join(self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps")
+ self.model.save(path)
+ if self.verbose > 1:
+ print(f"Saving model checkpoint to {path}")
+ return True
+
+
+class ConvertCallback(BaseCallback):
+ """
+ Convert functional callback (old-style) to object.
+
+    :param callback: Functional callback taking the ``locals_`` and ``globals_``
+        dictionaries and returning a bool (False stops training)
+    :param verbose: Verbosity level.
+ """
+
+ def __init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0):
+ super().__init__(verbose)
+ self.callback = callback
+
+ def _on_step(self) -> bool:
+ if self.callback is not None:
+ return self.callback(self.locals, self.globals)
+ return True
+
+
+class EvalCallback(EventCallback):
+ """
+ Callback for evaluating an agent.
+
+ .. warning::
+
+ When using multiple environments, each call to ``env.step()``
+ will effectively correspond to ``n_envs`` steps.
+ To account for that, you can use ``eval_freq = max(eval_freq // n_envs, 1)``
+
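+    A usage sketch (``eval_env`` should mirror the training environment):
+
+    .. code-block:: python
+
+        eval_callback = EvalCallback(eval_env, best_model_save_path="./logs/",
+                                     log_path="./logs/", eval_freq=500)
+        model.learn(total_timesteps=10_000, callback=eval_callback)
+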
+    :param eval_env: The environment used for evaluation
+ :param callback_on_new_best: Callback to trigger
+ when there is a new best model according to the ``mean_reward``
+ :param callback_after_eval: Callback to trigger after every evaluation
+ :param n_eval_episodes: The number of episodes to test the agent
+    :param eval_freq: Evaluate the agent every ``eval_freq`` calls of the callback.
+ :param log_path: Path to a folder where the evaluations (``evaluations.npz``)
+ will be saved. It will be updated at each evaluation.
+ :param best_model_save_path: Path to a folder where the best model
+ according to performance on the eval env will be saved.
+    :param deterministic: Whether the evaluation should
+        use stochastic or deterministic actions.
+    :param render: Whether to render the environment during evaluation
+    :param verbose: Verbosity level.
+ :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
+ wrapped with a Monitor wrapper)
+ """
+
+ def __init__(
+ self,
+ eval_env: Union[gym.Env, VecEnv],
+ callback_on_new_best: Optional[BaseCallback] = None,
+ callback_after_eval: Optional[BaseCallback] = None,
+ n_eval_episodes: int = 5,
+ eval_freq: int = 10000,
+ log_path: Optional[str] = None,
+ best_model_save_path: Optional[str] = None,
+ deterministic: bool = True,
+ render: bool = False,
+ verbose: int = 1,
+ warn: bool = True,
+ ):
+ super().__init__(callback_after_eval, verbose=verbose)
+
+ self.callback_on_new_best = callback_on_new_best
+ if self.callback_on_new_best is not None:
+ # Give access to the parent
+ self.callback_on_new_best.parent = self
+
+ self.n_eval_episodes = n_eval_episodes
+ self.eval_freq = eval_freq
+ self.best_mean_reward = -np.inf
+ self.last_mean_reward = -np.inf
+ self.deterministic = deterministic
+ self.render = render
+ self.warn = warn
+
+ # Convert to VecEnv for consistency
+ if not isinstance(eval_env, VecEnv):
+ eval_env = DummyVecEnv([lambda: eval_env])
+
+ self.eval_env = eval_env
+ self.best_model_save_path = best_model_save_path
+ # Logs will be written in ``evaluations.npz``
+ if log_path is not None:
+ log_path = os.path.join(log_path, "evaluations")
+ self.log_path = log_path
+ self.evaluations_results = []
+ self.evaluations_timesteps = []
+ self.evaluations_length = []
+ # For computing success rate
+ self._is_success_buffer = []
+ self.evaluations_successes = []
+
+ def _init_callback(self) -> None:
+ # Does not work in some corner cases, where the wrapper is not the same
+ if not isinstance(self.training_env, type(self.eval_env)):
+ warnings.warn("Training and eval env are not of the same type" f"{self.training_env} != {self.eval_env}")
+
+ # Create folders if needed
+ if self.best_model_save_path is not None:
+ os.makedirs(self.best_model_save_path, exist_ok=True)
+ if self.log_path is not None:
+ os.makedirs(os.path.dirname(self.log_path), exist_ok=True)
+
+ # Init callback called on new best model
+ if self.callback_on_new_best is not None:
+ self.callback_on_new_best.init_callback(self.model)
+
+ def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
+ """
+ Callback passed to the ``evaluate_policy`` function
+ in order to log the success rate (when applicable),
+ for instance when using HER.
+
+ :param locals_:
+ :param globals_:
+ """
+ info = locals_["info"]
+
+ if locals_["done"]:
+ maybe_is_success = info.get("is_success")
+ if maybe_is_success is not None:
+ self._is_success_buffer.append(maybe_is_success)
+
+ def _on_step(self) -> bool:
+
+ continue_training = True
+
+ if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
+
+ # Sync training and eval env if there is VecNormalize
+ if self.model.get_vec_normalize_env() is not None:
+ try:
+ sync_envs_normalization(self.training_env, self.eval_env)
+ except AttributeError:
+ raise AssertionError(
+ "Training and eval env are not wrapped the same way, "
+ "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
+ "and warning above."
+ )
+
+ # Reset success rate buffer
+ self._is_success_buffer = []
+
+ episode_rewards, episode_lengths = evaluate_policy(
+ self.model,
+ self.eval_env,
+ n_eval_episodes=self.n_eval_episodes,
+ render=self.render,
+ deterministic=self.deterministic,
+ return_episode_rewards=True,
+ warn=self.warn,
+ callback=self._log_success_callback,
+ )
+
+ if self.log_path is not None:
+ self.evaluations_timesteps.append(self.num_timesteps)
+ self.evaluations_results.append(episode_rewards)
+ self.evaluations_length.append(episode_lengths)
+
+ kwargs = {}
+ # Save success log if present
+ if len(self._is_success_buffer) > 0:
+ self.evaluations_successes.append(self._is_success_buffer)
+ kwargs = dict(successes=self.evaluations_successes)
+
+ np.savez(
+ self.log_path,
+ timesteps=self.evaluations_timesteps,
+ results=self.evaluations_results,
+ ep_lengths=self.evaluations_length,
+ **kwargs,
+ )
+
+ mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
+ mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
+ self.last_mean_reward = mean_reward
+
+ if self.verbose > 0:
+ print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
+ print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
+ # Add to current Logger
+ self.logger.record("eval/mean_reward", float(mean_reward))
+ self.logger.record("eval/mean_ep_length", mean_ep_length)
+
+ if len(self._is_success_buffer) > 0:
+ success_rate = np.mean(self._is_success_buffer)
+ if self.verbose > 0:
+ print(f"Success rate: {100 * success_rate:.2f}%")
+ self.logger.record("eval/success_rate", success_rate)
+
+ # Dump log so the evaluation results are printed with the correct timestep
+ self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
+ self.logger.dump(self.num_timesteps)
+
+ if mean_reward > self.best_mean_reward:
+ if self.verbose > 0:
+ print("New best mean reward!")
+ if self.best_model_save_path is not None:
+ self.model.save(os.path.join(self.best_model_save_path, "best_model"))
+ self.best_mean_reward = mean_reward
+ # Trigger callback on new best model, if needed
+ if self.callback_on_new_best is not None:
+ continue_training = self.callback_on_new_best.on_step()
+
+ # Trigger callback after every evaluation, if needed
+ if self.callback is not None:
+ continue_training = continue_training and self._on_event()
+
+ return continue_training
+
+ def update_child_locals(self, locals_: Dict[str, Any]) -> None:
+ """
+ Update the references to the local variables.
+
+ :param locals_: the local variables during rollout collection
+ """
+ if self.callback:
+ self.callback.update_locals(locals_)
+
+
+class StopTrainingOnRewardThreshold(BaseCallback):
+ """
+ Stop the training once a threshold in episodic reward
+ has been reached (i.e. when the model is good enough).
+
+ It must be used with the ``EvalCallback``.
+
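+    A usage sketch: stop as soon as the best mean evaluation reward exceeds 200.
+
+    .. code-block:: python
+
+        callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=200, verbose=1)
+        eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best)
+        model.learn(total_timesteps=int(1e5), callback=eval_callback)
+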
+ :param reward_threshold: Minimum expected reward per episode
+ to stop training.
+    :param verbose: Verbosity level.
+ """
+
+ def __init__(self, reward_threshold: float, verbose: int = 0):
+ super().__init__(verbose=verbose)
+ self.reward_threshold = reward_threshold
+
+ def _on_step(self) -> bool:
+        assert self.parent is not None, "``StopTrainingOnRewardThreshold`` callback must be used with an ``EvalCallback``"
+        # Convert np.bool_ to bool, otherwise the ``is False`` check used by the training loop won't work
+        continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
+ if self.verbose > 0 and not continue_training:
+ print(
+ f"Stopping training because the mean reward {self.parent.best_mean_reward:.2f} "
+ f" is above the threshold {self.reward_threshold}"
+ )
+ return continue_training
+
+
+class EveryNTimesteps(EventCallback):
+ """
+    Trigger a callback every ``n_steps`` timesteps
+
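+    A usage sketch: save a checkpoint every 500 timesteps, independently of ``n_envs``
+    (``save_freq=1`` so the child callback fires on every triggered event):
+
+    .. code-block:: python
+
+        checkpoint_on_event = CheckpointCallback(save_freq=1, save_path="./logs/")
+        event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
+        model.learn(total_timesteps=10_000, callback=event_callback)
+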
+    :param n_steps: Number of timesteps between two triggers.
+ :param callback: Callback that will be called
+ when the event is triggered.
+ """
+
+ def __init__(self, n_steps: int, callback: BaseCallback):
+ super().__init__(callback)
+ self.n_steps = n_steps
+ self.last_time_trigger = 0
+
+ def _on_step(self) -> bool:
+ if (self.num_timesteps - self.last_time_trigger) >= self.n_steps:
+ self.last_time_trigger = self.num_timesteps
+ return self._on_event()
+ return True
+
+
+class StopTrainingOnMaxEpisodes(BaseCallback):
+ """
+ Stop the training once a maximum number of episodes are played.
+
+    For multiple environments, this presumes that the desired behavior is for the agent to train
+    on each env for ``max_episodes`` and in total for ``max_episodes * n_envs`` episodes.
+
+    :param max_episodes: Maximum number of episodes after which training stops.
+ :param verbose: Select whether to print information about when training ended by reaching ``max_episodes``
+ """
+
+ def __init__(self, max_episodes: int, verbose: int = 0):
+ super().__init__(verbose=verbose)
+ self.max_episodes = max_episodes
+ self._total_max_episodes = max_episodes
+ self.n_episodes = 0
+
+ def _init_callback(self) -> None:
+        # At start, set the total max according to the number of environments
+ self._total_max_episodes = self.max_episodes * self.training_env.num_envs
+
+ def _on_step(self) -> bool:
+ # Check that the `dones` local variable is defined
+ assert "dones" in self.locals, "`dones` variable is not defined, please check your code next to `callback.on_step()`"
+ self.n_episodes += np.sum(self.locals["dones"]).item()
+
+ continue_training = self.n_episodes < self._total_max_episodes
+
+ if self.verbose > 0 and not continue_training:
+ mean_episodes_per_env = self.n_episodes / self.training_env.num_envs
+ mean_ep_str = (
+ f"with an average of {mean_episodes_per_env:.2f} episodes per env" if self.training_env.num_envs > 1 else ""
+ )
+
+ print(
+ f"Stopping training with a total of {self.num_timesteps} steps because the "
+ f"{self.locals.get('tb_log_name')} model reached max_episodes={self.max_episodes}, "
+ f"by playing for {self.n_episodes} episodes "
+ f"{mean_ep_str}"
+ )
+ return continue_training
+
+
+class StopTrainingOnNoModelImprovement(BaseCallback):
+ """
+ Stop the training early if there is no new best model (new best mean reward) after more than N consecutive evaluations.
+
+    It is possible to define a minimum number of evaluations before starting to count evaluations without improvement.
+
+ It must be used with the ``EvalCallback``.
+
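+    A usage sketch: stop after 3 consecutive evaluations without a new best model,
+    counting only from the 5th evaluation onwards:
+
+    .. code-block:: python
+
+        stop_callback = StopTrainingOnNoModelImprovement(max_no_improvement_evals=3, min_evals=5)
+        eval_callback = EvalCallback(eval_env, callback_after_eval=stop_callback)
+        model.learn(total_timesteps=int(1e5), callback=eval_callback)
+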
+ :param max_no_improvement_evals: Maximum number of consecutive evaluations without a new best model.
+    :param min_evals: Number of evaluations before starting to count evaluations without improvement.
+ :param verbose: Verbosity of the output (set to 1 for info messages)
+ """
+
+ def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0):
+ super().__init__(verbose=verbose)
+ self.max_no_improvement_evals = max_no_improvement_evals
+ self.min_evals = min_evals
+ self.last_best_mean_reward = -np.inf
+ self.no_improvement_evals = 0
+
+ def _on_step(self) -> bool:
+ assert self.parent is not None, "``StopTrainingOnNoModelImprovement`` callback must be used with an ``EvalCallback``"
+
+ continue_training = True
+
+ if self.n_calls > self.min_evals:
+ if self.parent.best_mean_reward > self.last_best_mean_reward:
+ self.no_improvement_evals = 0
+ else:
+ self.no_improvement_evals += 1
+ if self.no_improvement_evals > self.max_no_improvement_evals:
+ continue_training = False
+
+ self.last_best_mean_reward = self.parent.best_mean_reward
+
+ if self.verbose > 0 and not continue_training:
+ print(
+ f"Stopping training because there was no new best model in the last {self.no_improvement_evals:d} evaluations"
+ )
+
+ return continue_training
diff --git a/dexart-release/stable_baselines3/common/distributions.py b/dexart-release/stable_baselines3/common/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d1ff5aa0bb1730b24f460f4c479d30058365448
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/distributions.py
@@ -0,0 +1,699 @@
+"""Probability distributions."""
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import gym
+import torch as th
+from gym import spaces
+from torch import nn
+from torch.distributions import Bernoulli, Categorical, Normal
+
+from stable_baselines3.common.preprocessing import get_action_dim
+
+
+class Distribution(ABC):
+ """Abstract base class for distributions."""
+
+ def __init__(self):
+ super().__init__()
+ self.distribution = None
+
+ @abstractmethod
+ def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]:
+ """Create the layers and parameters that represent the distribution.
+
+ Subclasses must define this, but the arguments and return type vary between
+ concrete classes."""
+
+ @abstractmethod
+ def proba_distribution(self, *args, **kwargs) -> "Distribution":
+ """Set parameters of the distribution.
+
+ :return: self
+ """
+
+ @abstractmethod
+ def log_prob(self, x: th.Tensor) -> th.Tensor:
+ """
+ Returns the log likelihood
+
+ :param x: the taken action
+ :return: The log likelihood of the distribution
+ """
+
+ @abstractmethod
+ def entropy(self) -> Optional[th.Tensor]:
+ """
+        Returns Shannon's entropy of the probability distribution
+
+ :return: the entropy, or None if no analytical form is known
+ """
+
+ @abstractmethod
+ def sample(self) -> th.Tensor:
+ """
+ Returns a sample from the probability distribution
+
+ :return: the stochastic action
+ """
+
+ @abstractmethod
+ def mode(self) -> th.Tensor:
+ """
+ Returns the most likely action (deterministic output)
+ from the probability distribution
+
+        :return: the deterministic action
+ """
+
+ def get_actions(self, deterministic: bool = False) -> th.Tensor:
+ """
+ Return actions according to the probability distribution.
+
+ :param deterministic:
+ :return:
+ """
+ if deterministic:
+ return self.mode()
+ return self.sample()
+
+ @abstractmethod
+ def actions_from_params(self, *args, **kwargs) -> th.Tensor:
+ """
+ Returns samples from the probability distribution
+ given its parameters.
+
+ :return: actions
+ """
+
+ @abstractmethod
+ def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]:
+ """
+ Returns samples and the associated log probabilities
+ from the probability distribution given its parameters.
+
+ :return: actions and log prob
+ """
+
+
+def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
+ """
+ Continuous actions are usually considered to be independent,
+ so we can sum components of the ``log_prob`` or the entropy.
+
+ :param tensor: shape: (n_batch, n_actions) or (n_batch,)
+ :return: shape: (n_batch,)
+ """
+ if len(tensor.shape) > 1:
+ tensor = tensor.sum(dim=1)
+ else:
+ tensor = tensor.sum()
+ return tensor
+
+
+class DiagGaussianDistribution(Distribution):
+ """
+ Gaussian distribution with diagonal covariance matrix, for continuous actions.
+
+ :param action_dim: Dimension of the action space.
+ """
+
+ def __init__(self, action_dim: int):
+ super().__init__()
+ self.action_dim = action_dim
+ self.mean_actions = None
+ self.log_std = None
+
+ def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
+ """
+ Create the layers and parameter that represent the distribution:
+ one output will be the mean of the Gaussian, the other parameter will be the
+        standard deviation (in fact the log of it, to allow negative values)
+
+ :param latent_dim: Dimension of the last layer of the policy (before the action layer)
+ :param log_std_init: Initial value for the log standard deviation
+ :return:
+ """
+ mean_actions = nn.Linear(latent_dim, self.action_dim)
+ # TODO: allow action dependent std
+ log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True)
+ return mean_actions, log_std
+
+ def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution":
+ """
+ Create the distribution given its parameters (mean, std)
+
+ :param mean_actions:
+ :param log_std:
+ :return:
+ """
+ action_std = th.ones_like(mean_actions) * log_std.exp()
+ self.distribution = Normal(mean_actions, action_std)
+ return self
+
+ def log_prob(self, actions: th.Tensor) -> th.Tensor:
+ """
+ Get the log probabilities of actions according to the distribution.
+ Note that you must first call the ``proba_distribution()`` method.
+
+ :param actions:
+ :return:
+ """
+ log_prob = self.distribution.log_prob(actions)
+ return sum_independent_dims(log_prob)
+
+ def entropy(self) -> th.Tensor:
+ return sum_independent_dims(self.distribution.entropy())
+
+ def sample(self) -> th.Tensor:
+ # Reparametrization trick to pass gradients
+ return self.distribution.rsample()
+
+ def mode(self) -> th.Tensor:
+ return self.distribution.mean
+
+ def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
+ # Update the proba distribution
+ self.proba_distribution(mean_actions, log_std)
+ return self.get_actions(deterministic=deterministic)
+
+ def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+ """
+ Compute the log probability of taking an action
+ given the distribution parameters.
+
+ :param mean_actions:
+ :param log_std:
+ :return:
+ """
+ actions = self.actions_from_params(mean_actions, log_std)
+ log_prob = self.log_prob(actions)
+ return actions, log_prob
+
+
+class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
+ """
+ Gaussian distribution with diagonal covariance matrix, followed by a squashing function (tanh) to ensure bounds.
+
+ :param action_dim: Dimension of the action space.
+ :param epsilon: small value to avoid NaN due to numerical imprecision.
+ """
+
+ def __init__(self, action_dim: int, epsilon: float = 1e-6):
+ super().__init__(action_dim)
+ # Avoid NaN (prevents division by zero or log of zero)
+ self.epsilon = epsilon
+ self.gaussian_actions = None
+
+ def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution":
+ super().proba_distribution(mean_actions, log_std)
+ return self
+
+ def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
+ # Inverse tanh
+ # Naive implementation (not stable): 0.5 * torch.log((1 + x) / (1 - x))
+        # A numerically stable implementation is used instead (see ``TanhBijector.inverse``)
+ if gaussian_actions is None:
+ # It will be clipped to avoid NaN when inversing tanh
+ gaussian_actions = TanhBijector.inverse(actions)
+
+ # Log likelihood for a Gaussian distribution
+ log_prob = super().log_prob(gaussian_actions)
+ # Squash correction (from original SAC implementation)
+ # this comes from the fact that tanh is bijective and differentiable
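+        #   log p(a) = log p(u) - sum_i log(1 - tanh(u_i)^2 + eps),  with a = tanh(u)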
+ log_prob -= th.sum(th.log(1 - actions**2 + self.epsilon), dim=1)
+ return log_prob
+
+ def entropy(self) -> Optional[th.Tensor]:
+ # No analytical form,
+ # entropy needs to be estimated using -log_prob.mean()
+ return None
+
+ def sample(self) -> th.Tensor:
+ # Reparametrization trick to pass gradients
+ self.gaussian_actions = super().sample()
+ return th.tanh(self.gaussian_actions)
+
+ def mode(self) -> th.Tensor:
+ self.gaussian_actions = super().mode()
+ # Squash the output
+ return th.tanh(self.gaussian_actions)
+
+ def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+ action = self.actions_from_params(mean_actions, log_std)
+ log_prob = self.log_prob(action, self.gaussian_actions)
+ return action, log_prob
+
+
+class CategoricalDistribution(Distribution):
+ """
+ Categorical distribution for discrete actions.
+
+ :param action_dim: Number of discrete actions
+ """
+
+ def __init__(self, action_dim: int):
+ super().__init__()
+ self.action_dim = action_dim
+
+ def proba_distribution_net(self, latent_dim: int) -> nn.Module:
+ """
+ Create the layer that represents the distribution:
+ it will be the logits of the Categorical distribution.
+ You can then get probabilities using a softmax.
+
+ :param latent_dim: Dimension of the last layer
+ of the policy network (before the action layer)
+ :return:
+ """
+ action_logits = nn.Linear(latent_dim, self.action_dim)
+ return action_logits
+
+ def proba_distribution(self, action_logits: th.Tensor) -> "CategoricalDistribution":
+ self.distribution = Categorical(logits=action_logits)
+ return self
+
+ def log_prob(self, actions: th.Tensor) -> th.Tensor:
+ return self.distribution.log_prob(actions)
+
+ def entropy(self) -> th.Tensor:
+ return self.distribution.entropy()
+
+ def sample(self) -> th.Tensor:
+ return self.distribution.sample()
+
+ def mode(self) -> th.Tensor:
+ return th.argmax(self.distribution.probs, dim=1)
+
+ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
+ # Update the proba distribution
+ self.proba_distribution(action_logits)
+ return self.get_actions(deterministic=deterministic)
+
+ def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+ actions = self.actions_from_params(action_logits)
+ log_prob = self.log_prob(actions)
+ return actions, log_prob
+
+
+class MultiCategoricalDistribution(Distribution):
+ """
+ MultiCategorical distribution for multi discrete actions.
+
+ :param action_dims: List of sizes of discrete action spaces
+ """
+
+ def __init__(self, action_dims: List[int]):
+ super().__init__()
+ self.action_dims = action_dims
+
+ def proba_distribution_net(self, latent_dim: int) -> nn.Module:
+ """
+ Create the layer that represents the distribution:
+ it will be the logits (flattened) of the MultiCategorical distribution.
+ You can then get probabilities using a softmax on each sub-space.
+
+ :param latent_dim: Dimension of the last layer
+ of the policy network (before the action layer)
+ :return:
+ """
+
+ action_logits = nn.Linear(latent_dim, sum(self.action_dims))
+ return action_logits
+
+ def proba_distribution(self, action_logits: th.Tensor) -> "MultiCategoricalDistribution":
+ self.distribution = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
+ return self
+
+ def log_prob(self, actions: th.Tensor) -> th.Tensor:
+ # Extract each discrete action and compute log prob for their respective distributions
+ return th.stack(
+ [dist.log_prob(action) for dist, action in zip(self.distribution, th.unbind(actions, dim=1))], dim=1
+ ).sum(dim=1)
+
+ def entropy(self) -> th.Tensor:
+ return th.stack([dist.entropy() for dist in self.distribution], dim=1).sum(dim=1)
+
+ def sample(self) -> th.Tensor:
+ return th.stack([dist.sample() for dist in self.distribution], dim=1)
+
+ def mode(self) -> th.Tensor:
+ return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distribution], dim=1)
+
+ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
+ # Update the proba distribution
+ self.proba_distribution(action_logits)
+ return self.get_actions(deterministic=deterministic)
+
+ def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+ actions = self.actions_from_params(action_logits)
+ log_prob = self.log_prob(actions)
+ return actions, log_prob
+
+
+class BernoulliDistribution(Distribution):
+ """
+ Bernoulli distribution for MultiBinary action spaces.
+
+    :param action_dims: Number of binary actions
+ """
+
+ def __init__(self, action_dims: int):
+ super().__init__()
+ self.action_dims = action_dims
+
+ def proba_distribution_net(self, latent_dim: int) -> nn.Module:
+ """
+ Create the layer that represents the distribution:
+ it will be the logits of the Bernoulli distribution.
+
+ :param latent_dim: Dimension of the last layer
+ of the policy network (before the action layer)
+ :return:
+ """
+ action_logits = nn.Linear(latent_dim, self.action_dims)
+ return action_logits
+
+ def proba_distribution(self, action_logits: th.Tensor) -> "BernoulliDistribution":
+ self.distribution = Bernoulli(logits=action_logits)
+ return self
+
+ def log_prob(self, actions: th.Tensor) -> th.Tensor:
+ return self.distribution.log_prob(actions).sum(dim=1)
+
+ def entropy(self) -> th.Tensor:
+ return self.distribution.entropy().sum(dim=1)
+
+ def sample(self) -> th.Tensor:
+ return self.distribution.sample()
+
+ def mode(self) -> th.Tensor:
+ return th.round(self.distribution.probs)
+
+ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
+ # Update the proba distribution
+ self.proba_distribution(action_logits)
+ return self.get_actions(deterministic=deterministic)
+
+ def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+ actions = self.actions_from_params(action_logits)
+ log_prob = self.log_prob(actions)
+ return actions, log_prob
+
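+# NOTE (editor): a hedged sketch for the MultiBinary case; ``deterministic=True``
+# rounds the Bernoulli probabilities, as ``mode()`` above does. Dimensions are
+# illustrative.
+def _example_bernoulli_usage() -> None:
+    dist = BernoulliDistribution(action_dims=3)
+    action_net = dist.proba_distribution_net(latent_dim=8)
+    actions = dist.actions_from_params(action_net(th.randn(2, 8)), deterministic=True)
+    assert actions.shape == (2, 3)
+    assert set(actions.unique().tolist()) <= {0.0, 1.0}
+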
+
+class StateDependentNoiseDistribution(Distribution):
+ """
+ Distribution class for using generalized State Dependent Exploration (gSDE).
+ Paper: https://arxiv.org/abs/2005.05719
+
+ It is used to create the noise exploration matrix and
+ compute the log probability of an action with that noise.
+
+ :param action_dim: Dimension of the action space.
+ :param full_std: Whether to use (n_features x n_actions) parameters
+ for the std instead of only (n_features,)
+ :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
+        a positive standard deviation (cf paper). It keeps the variance
+        above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
+ :param squash_output: Whether to squash the output using a tanh function,
+ this ensures bounds are satisfied.
+ :param learn_features: Whether to learn features for gSDE or not.
+ This will enable gradients to be backpropagated through the features
+ ``latent_sde`` in the code.
+ :param epsilon: small value to avoid NaN due to numerical imprecision.
+ """
+
+ def __init__(
+ self,
+ action_dim: int,
+ full_std: bool = True,
+ use_expln: bool = False,
+ squash_output: bool = False,
+ learn_features: bool = False,
+ epsilon: float = 1e-6,
+ ):
+ super().__init__()
+ self.action_dim = action_dim
+ self.latent_sde_dim = None
+ self.mean_actions = None
+ self.log_std = None
+ self.weights_dist = None
+ self.exploration_mat = None
+ self.exploration_matrices = None
+ self._latent_sde = None
+ self.use_expln = use_expln
+ self.full_std = full_std
+ self.epsilon = epsilon
+ self.learn_features = learn_features
+ if squash_output:
+ self.bijector = TanhBijector(epsilon)
+ else:
+ self.bijector = None
+
+ def get_std(self, log_std: th.Tensor) -> th.Tensor:
+ """
+ Get the standard deviation from the learned parameter
+ (log of it by default). This ensures that the std is positive.
+
+ :param log_std:
+ :return:
+ """
+ if self.use_expln:
+            # From the gSDE paper: keeps the variance above zero
+            # and prevents it from growing too fast
+ below_threshold = th.exp(log_std) * (log_std <= 0)
+            # Avoid NaN: zero out the values below the threshold before applying log1p
+ safe_log_std = log_std * (log_std > 0) + self.epsilon
+ above_threshold = (th.log1p(safe_log_std) + 1.0) * (log_std > 0)
+ std = below_threshold + above_threshold
+ else:
+ # Use normal exponential
+ std = th.exp(log_std)
+
+ if self.full_std:
+ return std
+ # Reduce the number of parameters:
+ return th.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std
+
+ def sample_weights(self, log_std: th.Tensor, batch_size: int = 1) -> None:
+ """
+ Sample weights for the noise exploration matrix,
+ using a centered Gaussian distribution.
+
+ :param log_std:
+ :param batch_size:
+ """
+ std = self.get_std(log_std)
+ self.weights_dist = Normal(th.zeros_like(std), std)
+ # Reparametrization trick to pass gradients
+ self.exploration_mat = self.weights_dist.rsample()
+ # Pre-compute matrices in case of parallel exploration
+ self.exploration_matrices = self.weights_dist.rsample((batch_size,))
+
+ def proba_distribution_net(
+ self, latent_dim: int, log_std_init: float = -2.0, latent_sde_dim: Optional[int] = None
+ ) -> Tuple[nn.Module, nn.Parameter]:
+ """
+ Create the layers and parameter that represent the distribution:
+ one output will be the deterministic action, the other parameter will be the
+        standard deviation of the distribution that controls the weights of the noise matrix.
+
+ :param latent_dim: Dimension of the last layer of the policy (before the action layer)
+ :param log_std_init: Initial value for the log standard deviation
+ :param latent_sde_dim: Dimension of the last layer of the features extractor
+ for gSDE. By default, it is shared with the policy network.
+ :return:
+ """
+ # Network for the deterministic action, it represents the mean of the distribution
+ mean_actions_net = nn.Linear(latent_dim, self.action_dim)
+ # When we learn features for the noise, the feature dimension
+ # can be different between the policy and the noise network
+ self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim
+ # Reduce the number of parameters if needed
+ log_std = th.ones(self.latent_sde_dim, self.action_dim) if self.full_std else th.ones(self.latent_sde_dim, 1)
+ # Transform it to a parameter so it can be optimized
+ log_std = nn.Parameter(log_std * log_std_init, requires_grad=True)
+ # Sample an exploration matrix
+ self.sample_weights(log_std)
+ return mean_actions_net, log_std
+
+ def proba_distribution(
+ self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
+ ) -> "StateDependentNoiseDistribution":
+ """
+ Create the distribution given its parameters (mean, std)
+
+ :param mean_actions:
+ :param log_std:
+ :param latent_sde:
+ :return:
+ """
+ # Stop gradient if we don't want to influence the features
+ self._latent_sde = latent_sde if self.learn_features else latent_sde.detach()
+ variance = th.mm(self._latent_sde**2, self.get_std(log_std) ** 2)
+ self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon))
+ return self
+
+ def log_prob(self, actions: th.Tensor) -> th.Tensor:
+ if self.bijector is not None:
+ gaussian_actions = self.bijector.inverse(actions)
+ else:
+ gaussian_actions = actions
+ # log likelihood for a gaussian
+ log_prob = self.distribution.log_prob(gaussian_actions)
+ # Sum along action dim
+ log_prob = sum_independent_dims(log_prob)
+
+ if self.bijector is not None:
+ # Squash correction (from original SAC implementation)
+ log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
+ return log_prob
+
+ def entropy(self) -> Optional[th.Tensor]:
+ if self.bijector is not None:
+ # No analytical form,
+ # entropy needs to be estimated using -log_prob.mean()
+ return None
+ return sum_independent_dims(self.distribution.entropy())
+
+ def sample(self) -> th.Tensor:
+ noise = self.get_noise(self._latent_sde)
+ actions = self.distribution.mean + noise
+ if self.bijector is not None:
+ return self.bijector.forward(actions)
+ return actions
+
+ def mode(self) -> th.Tensor:
+ actions = self.distribution.mean
+ if self.bijector is not None:
+ return self.bijector.forward(actions)
+ return actions
+
+ def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
+ latent_sde = latent_sde if self.learn_features else latent_sde.detach()
+ # Default case: only one exploration matrix
+ if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
+ return th.mm(latent_sde, self.exploration_mat)
+ # Use batch matrix multiplication for efficient computation
+ # (batch_size, n_features) -> (batch_size, 1, n_features)
+ latent_sde = latent_sde.unsqueeze(1)
+ # (batch_size, 1, n_actions)
+ noise = th.bmm(latent_sde, self.exploration_matrices)
+ return noise.squeeze(1)
+
+ def actions_from_params(
+ self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False
+ ) -> th.Tensor:
+ # Update the proba distribution
+ self.proba_distribution(mean_actions, log_std, latent_sde)
+ return self.get_actions(deterministic=deterministic)
+
+ def log_prob_from_params(
+ self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
+ ) -> Tuple[th.Tensor, th.Tensor]:
+ actions = self.actions_from_params(mean_actions, log_std, latent_sde)
+ log_prob = self.log_prob(actions)
+ return actions, log_prob
+
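+# NOTE (editor): a hedged end-to-end sketch of the gSDE distribution above. The
+# exploration matrix is already sampled inside ``proba_distribution_net``, so
+# sampling works immediately; dimensions are illustrative.
+def _example_gsde_usage() -> None:
+    dist = StateDependentNoiseDistribution(action_dim=2)
+    mean_net, log_std = dist.proba_distribution_net(latent_dim=8)
+    latent = th.randn(1, 8)
+    actions, log_prob = dist.log_prob_from_params(mean_net(latent), log_std, latent)
+    assert actions.shape == (1, 2) and log_prob.shape == (1,)
+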
+
+class TanhBijector:
+ """
+ Bijective transformation of a probability distribution
+ using a squashing function (tanh)
+ TODO: use Pyro instead (https://pyro.ai/)
+
+ :param epsilon: small value to avoid NaN due to numerical imprecision.
+ """
+
+ def __init__(self, epsilon: float = 1e-6):
+ super().__init__()
+ self.epsilon = epsilon
+
+ @staticmethod
+ def forward(x: th.Tensor) -> th.Tensor:
+ return th.tanh(x)
+
+ @staticmethod
+ def atanh(x: th.Tensor) -> th.Tensor:
+ """
+ Inverse of Tanh
+
+ Taken from Pyro: https://github.com/pyro-ppl/pyro
+ 0.5 * torch.log((1 + x ) / (1 - x))
+ """
+ return 0.5 * (x.log1p() - (-x).log1p())
+
+ @staticmethod
+ def inverse(y: th.Tensor) -> th.Tensor:
+ """
+ Inverse tanh.
+
+ :param y:
+ :return:
+ """
+ eps = th.finfo(y.dtype).eps
+ # Clip the action to avoid NaN
+ return TanhBijector.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))
+
+ def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
+ # Squash correction (from original SAC implementation)
+ return th.log(1.0 - th.tanh(x) ** 2 + self.epsilon)
+
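+# NOTE (editor): a small sketch checking that ``inverse`` undoes ``forward``
+# up to the clipping used to avoid NaNs; the test values are illustrative.
+def _example_tanh_bijector_usage() -> None:
+    bijector = TanhBijector()
+    x = th.linspace(-2.0, 2.0, 5)
+    y = bijector.forward(x)  # squashed into (-1, 1)
+    assert th.allclose(bijector.inverse(y), x, atol=1e-4)
+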
+
+def make_proba_distribution(
+ action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
+) -> Distribution:
+ """
+ Return an instance of Distribution for the correct type of action space
+
+ :param action_space: the input action space
+ :param use_sde: Force the use of StateDependentNoiseDistribution
+ instead of DiagGaussianDistribution
+ :param dist_kwargs: Keyword arguments to pass to the probability distribution
+ :return: the appropriate Distribution object
+ """
+ if dist_kwargs is None:
+ dist_kwargs = {}
+
+ if isinstance(action_space, spaces.Box):
+ assert len(action_space.shape) == 1, "Error: the action space must be a vector"
+ cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
+ return cls(get_action_dim(action_space), **dist_kwargs)
+ elif isinstance(action_space, spaces.Discrete):
+ return CategoricalDistribution(action_space.n, **dist_kwargs)
+ elif isinstance(action_space, spaces.MultiDiscrete):
+ return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
+ elif isinstance(action_space, spaces.MultiBinary):
+ return BernoulliDistribution(action_space.n, **dist_kwargs)
+ else:
+ raise NotImplementedError(
+ "Error: probability distribution, not implemented for action space"
+ f"of type {type(action_space)}."
+ " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
+ )
+
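+# NOTE (editor): a hedged sketch of the dispatch implemented above; each Gym
+# space type maps to its matching distribution class.
+def _example_make_proba_distribution() -> None:
+    assert isinstance(make_proba_distribution(spaces.Discrete(4)), CategoricalDistribution)
+    assert isinstance(make_proba_distribution(spaces.MultiBinary(3)), BernoulliDistribution)
+    assert isinstance(make_proba_distribution(spaces.MultiDiscrete([3, 2])), MultiCategoricalDistribution)
+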
+
+def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor:
+ """
+ Wrapper for the PyTorch implementation of the full form KL Divergence
+
+ :param dist_true: the p distribution
+ :param dist_pred: the q distribution
+ :return: KL(dist_true||dist_pred)
+ """
+ # KL Divergence for different distribution types is out of scope
+ assert dist_true.__class__ == dist_pred.__class__, "Error: input distributions should be the same type"
+
+ # MultiCategoricalDistribution is not a PyTorch Distribution subclass
+ # so we need to implement it ourselves!
+ if isinstance(dist_pred, MultiCategoricalDistribution):
+ assert dist_pred.action_dims == dist_true.action_dims, "Error: distributions must have the same input space"
+ return th.stack(
+ [th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)],
+ dim=1,
+ ).sum(dim=1)
+
+ # Use the PyTorch kl_divergence implementation
+ else:
+ return th.distributions.kl_divergence(dist_true.distribution, dist_pred.distribution)
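+
+
+# NOTE (editor): a hedged sketch of ``kl_divergence`` for two categorical
+# distributions; the KL is zero only when the logits describe the same law.
+def _example_kl_divergence_usage() -> None:
+    dist_true = CategoricalDistribution(3).proba_distribution(th.zeros(1, 3))
+    dist_pred = CategoricalDistribution(3).proba_distribution(th.tensor([[2.0, 0.0, 0.0]]))
+    assert kl_divergence(dist_true, dist_true).item() == 0.0
+    assert kl_divergence(dist_true, dist_pred).item() > 0.0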
diff --git a/dexart-release/stable_baselines3/common/env_util.py b/dexart-release/stable_baselines3/common/env_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..297cce30404fc8811ca3371a01a7ad6824c5e274
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/env_util.py
@@ -0,0 +1,104 @@
+import os
+from typing import Any, Callable, Dict, Optional, Type, Union
+
+import gym
+
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
+
+
+def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]:
+ """
+ Retrieve a ``VecEnvWrapper`` object by recursively searching.
+
+ :param env: Environment to unwrap
+ :param wrapper_class: Wrapper to look for
+ :return: Environment unwrapped till ``wrapper_class`` if it has been wrapped with it
+ """
+ env_tmp = env
+ while isinstance(env_tmp, gym.Wrapper):
+ if isinstance(env_tmp, wrapper_class):
+ return env_tmp
+ env_tmp = env_tmp.env
+ return None
+
+
+def is_wrapped(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> bool:
+ """
+ Check if a given environment has been wrapped with a given wrapper.
+
+ :param env: Environment to check
+ :param wrapper_class: Wrapper class to look for
+ :return: True if environment has been wrapped with ``wrapper_class``.
+ """
+ return unwrap_wrapper(env, wrapper_class) is not None
+
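+# NOTE (editor): a hedged sketch of the two helpers above; "CartPole-v1" is a
+# standard Gym id used purely for illustration.
+def _example_wrapper_helpers() -> None:
+    env = Monitor(gym.make("CartPole-v1"))
+    assert is_wrapped(env, Monitor)
+    assert unwrap_wrapper(env, Monitor) is env
+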
+
+def make_vec_env(
+ env_id: Union[str, Type[gym.Env]],
+ n_envs: int = 1,
+ seed: Optional[int] = None,
+ start_index: int = 0,
+ monitor_dir: Optional[str] = None,
+ wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
+ env_kwargs: Optional[Dict[str, Any]] = None,
+ vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
+ vec_env_kwargs: Optional[Dict[str, Any]] = None,
+ monitor_kwargs: Optional[Dict[str, Any]] = None,
+ wrapper_kwargs: Optional[Dict[str, Any]] = None,
+) -> VecEnv:
+ """
+ Create a wrapped, monitored ``VecEnv``.
+ By default it uses a ``DummyVecEnv`` which is usually faster
+ than a ``SubprocVecEnv``.
+
+ :param env_id: the environment ID or the environment class
+ :param n_envs: the number of environments you wish to have in parallel
+ :param seed: the initial seed for the random number generator
+ :param start_index: start rank index
+ :param monitor_dir: Path to a folder where the monitor files will be saved.
+ If None, no file will be written, however, the env will still be wrapped
+ in a Monitor wrapper to provide additional information about training.
+ :param wrapper_class: Additional wrapper to use on the environment.
+        This can also be a function that takes the environment and returns it wrapped in one or more wrappers.
+ :param env_kwargs: Optional keyword argument to pass to the env constructor
+ :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
+ :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
+ :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
+ :param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor.
+ :return: The wrapped environment
+ """
+ env_kwargs = {} if env_kwargs is None else env_kwargs
+ vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs
+ monitor_kwargs = {} if monitor_kwargs is None else monitor_kwargs
+ wrapper_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs
+
+ def make_env(rank):
+ def _init():
+ if isinstance(env_id, str):
+ env = gym.make(env_id, **env_kwargs)
+ else:
+ env = env_id(**env_kwargs)
+ if seed is not None:
+ env.seed(seed + rank)
+ env.action_space.seed(seed + rank)
+ # Wrap the env in a Monitor wrapper
+ # to have additional training information
+ monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None
+ # Create the monitor folder if needed
+ if monitor_path is not None:
+ os.makedirs(monitor_dir, exist_ok=True)
+ env = Monitor(env, filename=monitor_path, **monitor_kwargs)
+ # Optionally, wrap the environment with the provided wrapper
+ if wrapper_class is not None:
+ env = wrapper_class(env, **wrapper_kwargs)
+ return env
+
+ return _init
+
+ # No custom VecEnv is passed
+ if vec_env_cls is None:
+ # Default: use a DummyVecEnv
+ vec_env_cls = DummyVecEnv
+
+ return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs)
\ No newline at end of file
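+
+
+# NOTE (editor): a hedged usage sketch for ``make_vec_env``; the env id is a
+# standard Gym one and the shapes are illustrative.
+def _example_make_vec_env() -> None:
+    vec_env = make_vec_env("CartPole-v1", n_envs=4, seed=0)
+    observations = vec_env.reset()  # batched reset: one row per sub-environment
+    assert observations.shape[0] == 4
+    vec_env.close()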
diff --git a/dexart-release/stable_baselines3/common/evaluation.py b/dexart-release/stable_baselines3/common/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3f14d3f8083c0e4b98efb784a629c0b75141470
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/evaluation.py
@@ -0,0 +1,131 @@
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import gym
+import numpy as np
+
+from stable_baselines3.common import base_class
+from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
+
+
+def evaluate_policy(
+ model: "base_class.BaseAlgorithm",
+ env: Union[gym.Env, VecEnv],
+ n_eval_episodes: int = 10,
+ deterministic: bool = True,
+ render: bool = False,
+ callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
+ reward_threshold: Optional[float] = None,
+ return_episode_rewards: bool = False,
+ warn: bool = True,
+) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
+ """
+ Runs policy for ``n_eval_episodes`` episodes and returns average reward.
+ If a vector env is passed in, this divides the episodes to evaluate onto the
+ different elements of the vector env. This static division of work is done to
+ remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more
+ details and discussion.
+
+ .. note::
+        If the environment has not been wrapped with a ``Monitor`` wrapper, rewards and
+        episode lengths are counted as they appear from the ``env.step`` calls. If
+ the environment contains wrappers that modify rewards or episode lengths
+ (e.g. reward scaling, early episode reset), these will affect the evaluation
+ results as well. You can avoid this by wrapping environment with ``Monitor``
+ wrapper before anything else.
+
+ :param model: The RL agent you want to evaluate.
+ :param env: The gym environment or ``VecEnv`` environment.
+ :param n_eval_episodes: Number of episode to evaluate the agent
+ :param deterministic: Whether to use deterministic or stochastic actions
+ :param render: Whether to render the environment or not
+ :param callback: callback function to do additional checks,
+ called after each step. Gets locals() and globals() passed as parameters.
+ :param reward_threshold: Minimum expected reward per episode,
+ this will raise an error if the performance is not met
+ :param return_episode_rewards: If True, a list of rewards and episode lengths
+ per episode will be returned instead of the mean.
+ :param warn: If True (default), warns user about lack of a Monitor wrapper in the
+ evaluation environment.
+ :return: Mean reward per episode, std of reward per episode.
+ Returns ([float], [int]) when ``return_episode_rewards`` is True, first
+ list containing per-episode rewards and second containing per-episode lengths
+ (in number of steps).
+ """
+ is_monitor_wrapped = False
+ # Avoid circular import
+ from stable_baselines3.common.monitor import Monitor
+
+ if not isinstance(env, VecEnv):
+ env = DummyVecEnv([lambda: env])
+
+ is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]
+
+ if not is_monitor_wrapped and warn:
+ warnings.warn(
+ "Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
+ "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
+ "Consider wrapping environment first with ``Monitor`` wrapper.",
+ UserWarning,
+ )
+
+ n_envs = env.num_envs
+ episode_rewards = []
+ episode_lengths = []
+
+ episode_counts = np.zeros(n_envs, dtype="int")
+ # Divides episodes among different sub environments in the vector as evenly as possible
+ episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int")
+
+ current_rewards = np.zeros(n_envs)
+ current_lengths = np.zeros(n_envs, dtype="int")
+ observations = env.reset()
+ states = None
+ episode_starts = np.ones((env.num_envs,), dtype=bool)
+ while (episode_counts < episode_count_targets).any():
+ actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic)
+ observations, rewards, dones, infos = env.step(actions)
+ current_rewards += rewards
+ current_lengths += 1
+ for i in range(n_envs):
+ if episode_counts[i] < episode_count_targets[i]:
+
+ # unpack values so that the callback can access the local variables
+ reward = rewards[i]
+ done = dones[i]
+ info = infos[i]
+ episode_starts[i] = done
+
+ if callback is not None:
+ callback(locals(), globals())
+
+ if dones[i]:
+ if is_monitor_wrapped:
+ # Atari wrapper can send a "done" signal when
+ # the agent loses a life, but it does not correspond
+ # to the true end of episode
+ if "episode" in info.keys():
+ # Do not trust "done" with episode endings.
+ # Monitor wrapper includes "episode" key in info if environment
+ # has been wrapped with it. Use those rewards instead.
+ episode_rewards.append(info["episode"]["r"])
+ episode_lengths.append(info["episode"]["l"])
+ # Only increment at the real end of an episode
+ episode_counts[i] += 1
+ else:
+ episode_rewards.append(current_rewards[i])
+ episode_lengths.append(current_lengths[i])
+ episode_counts[i] += 1
+ current_rewards[i] = 0
+ current_lengths[i] = 0
+
+ if render:
+ env.render()
+
+ mean_reward = np.mean(episode_rewards)
+ std_reward = np.std(episode_rewards)
+ if reward_threshold is not None:
+ assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"
+ if return_episode_rewards:
+ return episode_rewards, episode_lengths
+ return mean_reward, std_reward
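+
+
+# NOTE (editor): a hedged usage sketch; ``model`` stands for any trained SB3
+# algorithm and is an assumption of this example.
+def _example_evaluate_policy(model: "base_class.BaseAlgorithm", env: gym.Env) -> None:
+    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
+    print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")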
diff --git a/dexart-release/stable_baselines3/common/logger.py b/dexart-release/stable_baselines3/common/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..1be6945c5bbfc1ef5905c5cdec1851182faea7bb
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/logger.py
@@ -0,0 +1,644 @@
+import datetime
+import json
+import os
+import sys
+import tempfile
+import warnings
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Sequence, TextIO, Tuple, Union
+
+import numpy as np
+import pandas
+import torch as th
+from matplotlib import pyplot as plt
+
+try:
+ from torch.utils.tensorboard import SummaryWriter
+except ImportError:
+ SummaryWriter = None
+
+DEBUG = 10
+INFO = 20
+WARN = 30
+ERROR = 40
+DISABLED = 50
+
+
+class Video:
+ """
+    Video data class storing the video frames and the frames per second
+
+ :param frames: frames to create the video from
+ :param fps: frames per second
+ """
+
+ def __init__(self, frames: th.Tensor, fps: Union[float, int]):
+ self.frames = frames
+ self.fps = fps
+
+
+class Figure:
+ """
+ Figure data class storing a matplotlib figure and whether to close the figure after logging it
+
+ :param figure: figure to log
+ :param close: if true, close the figure after logging it
+ """
+
+    def __init__(self, figure: plt.Figure, close: bool):
+ self.figure = figure
+ self.close = close
+
+
+class Image:
+ """
+ Image data class storing an image and data format
+
+ :param image: image to log
+ :param dataformats: Image data format specification of the form NCHW, NHWC, CHW, HWC, HW, WH, etc.
+ More info in add_image method doc at https://pytorch.org/docs/stable/tensorboard.html
+ Gym envs normally use 'HWC' (channel last)
+ """
+
+ def __init__(self, image: Union[th.Tensor, np.ndarray, str], dataformats: str):
+ self.image = image
+ self.dataformats = dataformats
+
+
+class FormatUnsupportedError(NotImplementedError):
+ """
+ Custom error to display informative message when
+ a value is not supported by some formats.
+
+ :param unsupported_formats: A sequence of unsupported formats,
+ for instance ``["stdout"]``.
+ :param value_description: Description of the value that cannot be logged by this format.
+ """
+
+ def __init__(self, unsupported_formats: Sequence[str], value_description: str):
+ if len(unsupported_formats) > 1:
+ format_str = f"formats {', '.join(unsupported_formats)} are"
+ else:
+ format_str = f"format {unsupported_formats[0]} is"
+ super().__init__(
+ f"The {format_str} not supported for the {value_description} value logged.\n"
+ f"You can exclude formats via the `exclude` parameter of the logger's `record` function."
+ )
+
+
+class KVWriter:
+ """
+ Key Value writer
+ """
+
+ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
+ step: int = 0) -> None:
+ """
+ Write a dictionary to file
+
+ :param key_values:
+ :param key_excluded:
+ :param step:
+ """
+ raise NotImplementedError
+
+ def close(self) -> None:
+ """
+ Close owned resources
+ """
+ raise NotImplementedError
+
+
+class SeqWriter:
+ """
+ sequence writer
+ """
+
+ def write_sequence(self, sequence: List) -> None:
+ """
+        Write a sequence of strings to the output file.
+
+ :param sequence:
+ """
+ raise NotImplementedError
+
+
+class HumanOutputFormat(KVWriter, SeqWriter):
+ """A human-readable output format producing ASCII tables of key-value pairs.
+
+ Set attribute ``max_length`` to change the maximum length of keys and values
+ to write to output (or specify it when calling ``__init__``).
+
+ :param filename_or_file: the file to write the log to
+ :param max_length: the maximum length of keys and values to write to output.
+ Outputs longer than this will be truncated. An error will be raised
+ if multiple keys are truncated to the same value. The maximum output
+ width will be ``2*max_length + 7``. The default of 36 produces output
+ no longer than 79 characters wide.
+ """
+
+ def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36):
+ self.max_length = max_length
+ if isinstance(filename_or_file, str):
+ self.file = open(filename_or_file, "wt")
+ self.own_file = True
+ else:
+ assert hasattr(filename_or_file, "write"), f"Expected file or str, got {filename_or_file}"
+ self.file = filename_or_file
+ self.own_file = False
+
+ def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None:
+ # Create strings for printing
+ key2str = {}
+ tag = None
+ for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
+
+ if excluded is not None and ("stdout" in excluded or "log" in excluded):
+ continue
+
+ elif isinstance(value, Video):
+ raise FormatUnsupportedError(["stdout", "log"], "video")
+
+ elif isinstance(value, Figure):
+ raise FormatUnsupportedError(["stdout", "log"], "figure")
+
+ elif isinstance(value, Image):
+ raise FormatUnsupportedError(["stdout", "log"], "image")
+
+ elif isinstance(value, float):
+ # Align left
+ value_str = f"{value:<8.3g}"
+ else:
+ value_str = str(value)
+
+ if key.find("/") > 0: # Find tag and add it to the dict
+ tag = key[: key.find("/") + 1]
+ key2str[self._truncate(tag)] = ""
+ # Remove tag from key
+ if tag is not None and tag in key:
+ key = str(" " + key[len(tag):])
+
+ truncated_key = self._truncate(key)
+ if truncated_key in key2str:
+ raise ValueError(
+ f"Key '{key}' truncated to '{truncated_key}' that already exists. Consider increasing `max_length`."
+ )
+ key2str[truncated_key] = self._truncate(value_str)
+
+ # Find max widths
+ if len(key2str) == 0:
+ warnings.warn("Tried to write empty key-value dict")
+ return
+ else:
+ key_width = max(map(len, key2str.keys()))
+ val_width = max(map(len, key2str.values()))
+
+ # Write out the data
+ dashes = "-" * (key_width + val_width + 7)
+ lines = [dashes]
+ for key, value in key2str.items():
+ key_space = " " * (key_width - len(key))
+ val_space = " " * (val_width - len(value))
+ lines.append(f"| {key}{key_space} | {value}{val_space} |")
+ lines.append(dashes)
+ self.file.write("\n".join(lines) + "\n")
+
+ # Flush the output to the file
+ self.file.flush()
+
+ def _truncate(self, string: str) -> str:
+ if len(string) > self.max_length:
+ string = string[: self.max_length - 3] + "..."
+ return string
+
+ def write_sequence(self, sequence: List) -> None:
+ sequence = list(sequence)
+ for i, elem in enumerate(sequence):
+ self.file.write(elem)
+ if i < len(sequence) - 1: # add space unless this is the last one
+ self.file.write(" ")
+ self.file.write("\n")
+ self.file.flush()
+
+ def close(self) -> None:
+ """
+ closes the file
+ """
+ if self.own_file:
+ self.file.close()
+
+
+def filter_excluded_keys(
+ key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], _format: str
+) -> Dict[str, Any]:
+ """
+ Filters the keys specified by ``key_exclude`` for the specified format
+
+ :param key_values: log dictionary to be filtered
+ :param key_excluded: keys to be excluded per format
+ :param _format: format for which this filter is run
+ :return: dict without the excluded keys
+ """
+
+ def is_excluded(key: str) -> bool:
+ return key in key_excluded and key_excluded[key] is not None and _format in key_excluded[key]
+
+ return {key: value for key, value in key_values.items() if not is_excluded(key)}
+
+
+class JSONOutputFormat(KVWriter):
+ """
+ Log to a file, in the JSON format
+
+ :param filename: the file to write the log to
+ """
+
+ def __init__(self, filename: str):
+ self.file = open(filename, "wt")
+
+ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
+ step: int = 0) -> None:
+ def cast_to_json_serializable(value: Any):
+ if isinstance(value, Video):
+ raise FormatUnsupportedError(["json"], "video")
+ if isinstance(value, Figure):
+ raise FormatUnsupportedError(["json"], "figure")
+ if isinstance(value, Image):
+ raise FormatUnsupportedError(["json"], "image")
+ if hasattr(value, "dtype"):
+ if value.shape == () or len(value) == 1:
+ # if value is a dimensionless numpy array or of length 1, serialize as a float
+ return float(value)
+ else:
+ # otherwise, a value is a numpy array, serialize as a list or nested lists
+ return value.tolist()
+ return value
+
+ key_values = {
+ key: cast_to_json_serializable(value)
+ for key, value in filter_excluded_keys(key_values, key_excluded, "json").items()
+ }
+ self.file.write(json.dumps(key_values) + "\n")
+ self.file.flush()
+
+ def close(self) -> None:
+ """
+ closes the file
+ """
+
+ self.file.close()
+
+
+class CSVOutputFormat(KVWriter):
+ """
+ Log to a file, in a CSV format
+
+ :param filename: the file to write the log to
+ """
+
+ def __init__(self, filename: str):
+ self.file = open(filename, "w+t")
+ self.keys = []
+ self.separator = ","
+ self.quotechar = '"'
+
+ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
+ step: int = 0) -> None:
+ # Add our current row to the history
+ key_values = filter_excluded_keys(key_values, key_excluded, "csv")
+ extra_keys = key_values.keys() - self.keys
+ if extra_keys:
+ self.keys.extend(extra_keys)
+ self.file.seek(0)
+ lines = self.file.readlines()
+ self.file.seek(0)
+ for (i, key) in enumerate(self.keys):
+ if i > 0:
+ self.file.write(",")
+ self.file.write(key)
+ self.file.write("\n")
+ for line in lines[1:]:
+ self.file.write(line[:-1])
+ self.file.write(self.separator * len(extra_keys))
+ self.file.write("\n")
+ for i, key in enumerate(self.keys):
+ if i > 0:
+ self.file.write(",")
+ value = key_values.get(key)
+
+ if isinstance(value, Video):
+ raise FormatUnsupportedError(["csv"], "video")
+
+ elif isinstance(value, Figure):
+ raise FormatUnsupportedError(["csv"], "figure")
+
+ elif isinstance(value, Image):
+ raise FormatUnsupportedError(["csv"], "image")
+
+ elif isinstance(value, str):
+ # escape quotechars by prepending them with another quotechar
+ value = value.replace(self.quotechar, self.quotechar + self.quotechar)
+
+ # additionally wrap text with quotechars so that any delimiters in the text are ignored by csv readers
+ self.file.write(self.quotechar + value + self.quotechar)
+
+ elif value is not None:
+ self.file.write(str(value))
+ self.file.write("\n")
+ self.file.flush()
+
+ def close(self) -> None:
+ """
+ closes the file
+ """
+ self.file.close()
+
+
+class TensorBoardOutputFormat(KVWriter):
+ """
+ Dumps key/value pairs into TensorBoard's numeric format.
+
+ :param folder: the folder to write the log to
+ """
+
+ def __init__(self, folder: str):
+        assert SummaryWriter is not None, "tensorboard is not installed, you can use `pip install tensorboard` to do so"
+ self.writer = SummaryWriter(log_dir=folder)
+
+ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
+ step: int = 0) -> None:
+ for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
+
+ if excluded is not None and "tensorboard" in excluded:
+ continue
+
+ if isinstance(value, np.ScalarType):
+ if isinstance(value, str):
+ # str is considered a np.ScalarType
+ self.writer.add_text(key, value, step)
+ else:
+ self.writer.add_scalar(key, value, step)
+
+ if isinstance(value, th.Tensor):
+ self.writer.add_histogram(key, value, step)
+
+ if isinstance(value, Video):
+ self.writer.add_video(key, value.frames, step, value.fps)
+
+ if isinstance(value, Figure):
+ self.writer.add_figure(key, value.figure, step, close=value.close)
+
+ if isinstance(value, Image):
+ self.writer.add_image(key, value.image, step, dataformats=value.dataformats)
+
+ # Flush the output to the file
+ self.writer.flush()
+
+ def close(self) -> None:
+ """
+ closes the file
+ """
+ if self.writer:
+ self.writer.close()
+ self.writer = None
+
+
+def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWriter:
+ """
+ return a logger for the requested format
+
+ :param _format: the requested format to log to ('stdout', 'log', 'json' or 'csv' or 'tensorboard')
+ :param log_dir: the logging directory
+ :param log_suffix: the suffix for the log file
+ :return: the logger
+ """
+ os.makedirs(log_dir, exist_ok=True)
+ if _format == "stdout":
+ return HumanOutputFormat(sys.stdout)
+ elif _format == "log":
+ return HumanOutputFormat(os.path.join(log_dir, f"log{log_suffix}.txt"))
+ elif _format == "json":
+ return JSONOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.json"))
+ elif _format == "csv":
+ return CSVOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.csv"))
+ elif _format == "tensorboard":
+ return TensorBoardOutputFormat(log_dir)
+ else:
+ raise ValueError(f"Unknown format specified: {_format}")
+
+
+# ================================================================
+# Backend
+# ================================================================
+
+
+class Logger:
+ """
+ The logger class.
+
+ :param folder: the logging location
+ :param output_formats: the list of output formats
+ """
+
+ def __init__(self, folder: Optional[str], output_formats: List[KVWriter]):
+ self.name_to_value = defaultdict(float) # values this iteration
+ self.name_to_count = defaultdict(int)
+ self.name_to_excluded = defaultdict(str)
+ self.level = INFO
+ self.dir = folder
+ self.output_formats = output_formats
+
+ def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
+ """
+ Log a value of some diagnostic
+ Call this once for each diagnostic quantity, each iteration
+ If called many times, last value will be used.
+
+ :param key: save to log this key
+ :param value: save to log this value
+ :param exclude: outputs to be excluded
+ """
+ self.name_to_value[key] = value
+ self.name_to_excluded[key] = exclude
+
+ def record_mean(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
+ """
+        The same as record(), but if called many times, the values will be averaged.
+
+ :param key: save to log this key
+ :param value: save to log this value
+ :param exclude: outputs to be excluded
+ """
+ if value is None:
+ self.name_to_value[key] = None
+ return
+ old_val, count = self.name_to_value[key], self.name_to_count[key]
+ self.name_to_value[key] = old_val * count / (count + 1) + value / (count + 1)
+ self.name_to_count[key] = count + 1
+ self.name_to_excluded[key] = exclude
+
+ def dump(self, step: int = 0) -> None:
+ """
+ Write all of the diagnostics from the current iteration
+ """
+ if self.level == DISABLED:
+ return
+ for _format in self.output_formats:
+ if isinstance(_format, KVWriter):
+ _format.write(self.name_to_value, self.name_to_excluded, step)
+
+ self.name_to_value.clear()
+ self.name_to_count.clear()
+ self.name_to_excluded.clear()
+
+ def log(self, *args, level: int = INFO) -> None:
+ """
+ Write the sequence of args, with no separators,
+ to the console and output files (if you've configured an output file).
+
+ level: int. (see logger.py docs) If the global logger level is higher than
+ the level argument here, don't print to stdout.
+
+ :param args: log the arguments
+ :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
+ """
+ if self.level <= level:
+ self._do_log(args)
+
+ def debug(self, *args) -> None:
+ """
+ Write the sequence of args, with no separators,
+ to the console and output files (if you've configured an output file).
+ Using the DEBUG level.
+
+ :param args: log the arguments
+ """
+ self.log(*args, level=DEBUG)
+
+ def info(self, *args) -> None:
+ """
+ Write the sequence of args, with no separators,
+ to the console and output files (if you've configured an output file).
+ Using the INFO level.
+
+ :param args: log the arguments
+ """
+ self.log(*args, level=INFO)
+
+ def warn(self, *args) -> None:
+ """
+ Write the sequence of args, with no separators,
+ to the console and output files (if you've configured an output file).
+ Using the WARN level.
+
+ :param args: log the arguments
+ """
+ self.log(*args, level=WARN)
+
+ def error(self, *args) -> None:
+ """
+ Write the sequence of args, with no separators,
+ to the console and output files (if you've configured an output file).
+ Using the ERROR level.
+
+ :param args: log the arguments
+ """
+ self.log(*args, level=ERROR)
+
+ # Configuration
+ # ----------------------------------------
+ def set_level(self, level: int) -> None:
+ """
+ Set logging threshold on current logger.
+
+ :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
+ """
+ self.level = level
+
+ def get_dir(self) -> str:
+ """
+ Get directory that log files are being written to.
+ will be None if there is no output directory (i.e., if you didn't call start)
+
+ :return: the logging directory
+ """
+ return self.dir
+
+ def close(self) -> None:
+ """
+ closes the file
+ """
+ for _format in self.output_formats:
+ _format.close()
+
+ # Misc
+ # ----------------------------------------
+ def _do_log(self, args) -> None:
+ """
+ log to the requested format outputs
+
+ :param args: the arguments to log
+ """
+ for _format in self.output_formats:
+ if isinstance(_format, SeqWriter):
+ _format.write_sequence(map(str, args))
+
+
+def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] = None) -> Logger:
+ """
+ Configure the current logger.
+
+ :param folder: the save location
+ (if None, $SB3_LOGDIR, if still None, tempdir/SB3-[date & time])
+ :param format_strings: the output logging format
+ (if None, $SB3_LOG_FORMAT, if still None, ['stdout', 'log', 'csv'])
+ :return: The logger object.
+ """
+ if folder is None:
+ folder = os.getenv("SB3_LOGDIR")
+ if folder is None:
+ folder = os.path.join(tempfile.gettempdir(), datetime.datetime.now().strftime("SB3-%Y-%m-%d-%H-%M-%S-%f"))
+ assert isinstance(folder, str)
+ os.makedirs(folder, exist_ok=True)
+
+ log_suffix = ""
+ if format_strings is None:
+ format_strings = os.getenv("SB3_LOG_FORMAT", "stdout,log,csv").split(",")
+
+ format_strings = list(filter(None, format_strings))
+ output_formats = [make_output_format(f, folder, log_suffix) for f in format_strings]
+
+ logger = Logger(folder=folder, output_formats=output_formats)
+ # Only print when some files will be saved
+ if len(format_strings) > 0 and format_strings != ["stdout"]:
+ logger.log(f"Logging to {folder}")
+ return logger
+
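+# NOTE (editor): a hedged usage sketch of the configuration entry point above;
+# the folder path and keys are illustrative.
+def _example_configure_logger() -> None:
+    logger = configure("/tmp/sb3_log", format_strings=["stdout", "csv"])
+    logger.record("train/loss", 0.5)
+    logger.record_mean("rollout/ep_rew_mean", 10.0)  # running mean across calls
+    logger.dump(step=1)
+    logger.close()
+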
+
+# ================================================================
+# Readers
+# ================================================================
+
+
+def read_json(filename: str) -> pandas.DataFrame:
+ """
+ read a json file using pandas
+
+ :param filename: the file path to read
+ :return: the data in the json
+ """
+ data = []
+ with open(filename) as file_handler:
+ for line in file_handler:
+ data.append(json.loads(line))
+ return pandas.DataFrame(data)
+
+
+def read_csv(filename: str) -> pandas.DataFrame:
+ """
+ read a csv file using pandas
+
+ :param filename: the file path to read
+ :return: the data in the csv
+ """
+ return pandas.read_csv(filename, index_col=None, comment="#")
diff --git a/dexart-release/stable_baselines3/common/monitor.py b/dexart-release/stable_baselines3/common/monitor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a482b72be66f58f5f9776b74d7ce595b400ad6eb
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/monitor.py
@@ -0,0 +1,239 @@
+__all__ = ["Monitor", "ResultsWriter", "get_monitor_files", "load_results"]
+
+import csv
+import json
+import os
+import time
+from glob import glob
+from typing import Dict, List, Optional, Tuple, Union
+
+import gym
+import numpy as np
+import pandas
+
+from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
+
+
+class Monitor(gym.Wrapper):
+ """
+ A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data.
+
+ :param env: The environment
+ :param filename: the location to save a log file, can be None for no log
+ :param allow_early_resets: allows the reset of the environment before it is done
+ :param reset_keywords: extra keywords for the reset call,
+ if extra parameters are needed at reset
+ :param info_keywords: extra information to log, from the information return of env.step()
+ """
+
+ EXT = "monitor.csv"
+
+ def __init__(
+ self,
+ env: gym.Env,
+ filename: Optional[str] = None,
+ allow_early_resets: bool = True,
+ reset_keywords: Tuple[str, ...] = (),
+ info_keywords: Tuple[str, ...] = (),
+ ):
+ super().__init__(env=env)
+ self.t_start = time.time()
+ if filename is not None:
+ self.results_writer = ResultsWriter(
+ filename,
+ header={"t_start": self.t_start, "env_id": env.spec and env.spec.id},
+ extra_keys=reset_keywords + info_keywords,
+ )
+ else:
+ self.results_writer = None
+ self.reset_keywords = reset_keywords
+ self.info_keywords = info_keywords
+ self.allow_early_resets = allow_early_resets
+ self.rewards = None
+ self.needs_reset = True
+ self.episode_returns = []
+ self.episode_lengths = []
+ self.episode_times = []
+ self.total_steps = 0
+ self.current_reset_info = {} # extra info about the current episode, that was passed in during reset()
+
+ def reset(self, **kwargs) -> GymObs:
+ """
+        Calls the Gym environment reset. Can only be called once the episode is over, or if allow_early_resets is True
+
+        :param kwargs: Extra keywords saved for the next episode. Only if defined by reset_keywords
+ :return: the first observation of the environment
+ """
+ if not self.allow_early_resets and not self.needs_reset:
+ raise RuntimeError(
+ "Tried to reset an environment before done. If you want to allow early resets, "
+ "wrap your env with Monitor(env, path, allow_early_resets=True)"
+ )
+ self.rewards = []
+ self.needs_reset = False
+ for key in self.reset_keywords:
+ value = kwargs.get(key)
+ if value is None:
+ raise ValueError(f"Expected you to pass keyword argument {key} into reset")
+ self.current_reset_info[key] = value
+ return self.env.reset(**kwargs)
+
+ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
+ """
+ Step the environment with the given action
+
+ :param action: the action
+ :return: observation, reward, done, information
+ """
+ if self.needs_reset:
+ raise RuntimeError("Tried to step environment that needs reset")
+ observation, reward, done, info = self.env.step(action)
+ self.rewards.append(reward)
+ if done:
+ self.needs_reset = True
+ ep_rew = sum(self.rewards)
+ ep_len = len(self.rewards)
+ ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(time.time() - self.t_start, 6)}
+ for key in self.info_keywords:
+ ep_info[key] = info[key]
+ self.episode_returns.append(ep_rew)
+ self.episode_lengths.append(ep_len)
+ self.episode_times.append(time.time() - self.t_start)
+ ep_info.update(self.current_reset_info)
+ if self.results_writer:
+ self.results_writer.write_row(ep_info)
+ info["episode"] = ep_info
+ self.total_steps += 1
+ return observation, reward, done, info
+
+ def close(self) -> None:
+ """
+ Closes the environment
+ """
+ super().close()
+ if self.results_writer is not None:
+ self.results_writer.close()
+
+ def get_total_steps(self) -> int:
+ """
+ Returns the total number of timesteps
+
+ :return:
+ """
+ return self.total_steps
+
+ def get_episode_rewards(self) -> List[float]:
+ """
+ Returns the rewards of all the episodes
+
+ :return:
+ """
+ return self.episode_returns
+
+ def get_episode_lengths(self) -> List[int]:
+ """
+ Returns the number of timesteps of all the episodes
+
+ :return:
+ """
+ return self.episode_lengths
+
+ def get_episode_times(self) -> List[float]:
+ """
+ Returns the runtime in seconds of all the episodes
+
+ :return:
+ """
+ return self.episode_times
+
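+# NOTE (editor): a hedged usage sketch of the wrapper above; the env id and
+# file path are illustrative.
+def _example_monitor_usage() -> None:
+    env = Monitor(gym.make("CartPole-v1"), filename="/tmp/run")
+    obs = env.reset()
+    done = False
+    while not done:
+        obs, reward, done, info = env.step(env.action_space.sample())
+    # At episode end, ``info["episode"]`` holds the keys "r", "l" and "t".
+    print(info["episode"])
+    env.close()
+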
+
+class LoadMonitorResultsError(Exception):
+ """
+ Raised when loading the monitor log fails.
+ """
+
+ pass
+
+
+class ResultsWriter:
+ """
+ A result writer that saves the data from the `Monitor` class
+
+ :param filename: the location to save a log file, can be None for no log
+ :param header: the header dictionary object of the saved csv
+    :param extra_keys: the extra information to log, typically composed of
+        ``reset_keywords`` and ``info_keywords``
+ """
+
+ def __init__(
+ self,
+ filename: str = "",
+ header: Optional[Dict[str, Union[float, str]]] = None,
+ extra_keys: Tuple[str, ...] = (),
+ ):
+ if header is None:
+ header = {}
+ if not filename.endswith(Monitor.EXT):
+ if os.path.isdir(filename):
+ filename = os.path.join(filename, Monitor.EXT)
+ else:
+ filename = filename + "." + Monitor.EXT
+ # Prevent newline issue on Windows, see GH issue #692
+ self.file_handler = open(filename, "wt", newline="\n")
+ self.file_handler.write("#%s\n" % json.dumps(header))
+ self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + extra_keys)
+ self.logger.writeheader()
+ self.file_handler.flush()
+
+ def write_row(self, epinfo: Dict[str, Union[float, int]]) -> None:
+ """
+        Write one episode's information as a row of the CSV log
+
+ :param epinfo: the information on episodic return, length, and time
+ """
+ if self.logger:
+ self.logger.writerow(epinfo)
+ self.file_handler.flush()
+
+ def close(self) -> None:
+ """
+ Close the file handler
+ """
+ self.file_handler.close()
+
+
+def get_monitor_files(path: str) -> List[str]:
+ """
+ get all the monitor files in the given path
+
+ :param path: the logging folder
+ :return: the log files
+ """
+ return glob(os.path.join(path, "*" + Monitor.EXT))
+
+
+def load_results(path: str) -> pandas.DataFrame:
+ """
+ Load all Monitor logs from a given directory path matching ``*monitor.csv``
+
+ :param path: the directory path containing the log file(s)
+ :return: the logged data
+ """
+ monitor_files = get_monitor_files(path)
+ if len(monitor_files) == 0:
+ raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}")
+ data_frames, headers = [], []
+ for file_name in monitor_files:
+ with open(file_name) as file_handler:
+ first_line = file_handler.readline()
+ assert first_line[0] == "#"
+ header = json.loads(first_line[1:])
+ data_frame = pandas.read_csv(file_handler, index_col=None)
+ headers.append(header)
+ data_frame["t"] += header["t_start"]
+ data_frames.append(data_frame)
+ data_frame = pandas.concat(data_frames)
+ data_frame.sort_values("t", inplace=True)
+ data_frame.reset_index(inplace=True)
+ data_frame["t"] -= min(header["t_start"] for header in headers)
+ return data_frame
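+
+
+# NOTE (editor): a hedged sketch tying the helpers above together; the folder
+# is assumed to contain files produced by the ``Monitor`` wrapper.
+def _example_load_results(path: str) -> None:
+    data_frame = load_results(path)  # concatenated episodes, sorted by time
+    print(data_frame[["r", "l", "t"]].tail())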
diff --git a/dexart-release/stable_baselines3/common/noise.py b/dexart-release/stable_baselines3/common/noise.py
new file mode 100644
index 0000000000000000000000000000000000000000..119ed362ec47363fcfe91343c8d823581b925e17
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/noise.py
@@ -0,0 +1,167 @@
+import copy
+from abc import ABC, abstractmethod
+from typing import Iterable, List, Optional
+
+import numpy as np
+
+
+class ActionNoise(ABC):
+ """
+ The action noise base class
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def reset(self) -> None:
+ """
+ call end of episode reset for the noise
+ """
+ pass
+
+ @abstractmethod
+ def __call__(self) -> np.ndarray:
+ raise NotImplementedError()
+
+
+class NormalActionNoise(ActionNoise):
+ """
+ A Gaussian action noise
+
+ :param mean: the mean value of the noise
+ :param sigma: the scale of the noise (std here)
+ """
+
+ def __init__(self, mean: np.ndarray, sigma: np.ndarray):
+ self._mu = mean
+ self._sigma = sigma
+ super().__init__()
+
+ def __call__(self) -> np.ndarray:
+ return np.random.normal(self._mu, self._sigma)
+
+ def __repr__(self) -> str:
+ return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"
+
+
+class OrnsteinUhlenbeckActionNoise(ActionNoise):
+ """
+    An Ornstein-Uhlenbeck action noise, designed to approximate Brownian motion with friction.
+
+ Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
+
+ :param mean: the mean of the noise
+ :param sigma: the scale of the noise
+ :param theta: the rate of mean reversion
+ :param dt: the timestep for the noise
+ :param initial_noise: the initial value for the noise output, (if None: 0)
+ """
+
+ def __init__(
+ self,
+ mean: np.ndarray,
+ sigma: np.ndarray,
+ theta: float = 0.15,
+ dt: float = 1e-2,
+ initial_noise: Optional[np.ndarray] = None,
+ ):
+ self._theta = theta
+ self._mu = mean
+ self._sigma = sigma
+ self._dt = dt
+ self.initial_noise = initial_noise
+ self.noise_prev = np.zeros_like(self._mu)
+ self.reset()
+ super().__init__()
+
+ def __call__(self) -> np.ndarray:
+ noise = (
+ self.noise_prev
+ + self._theta * (self._mu - self.noise_prev) * self._dt
+ + self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
+ )
+ self.noise_prev = noise
+ return noise
+
+ def reset(self) -> None:
+ """
+ reset the Ornstein Uhlenbeck noise, to the initial position
+ """
+ self.noise_prev = self.initial_noise if self.initial_noise is not None else np.zeros_like(self._mu)
+
+ def __repr__(self) -> str:
+ return f"OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})"
+
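+# NOTE (editor): a hedged sketch of the Euler-Maruyama update implemented in
+# ``__call__`` above: x_{t+1} = x_t + theta * (mu - x_t) * dt + sigma * sqrt(dt) * eps.
+def _example_ou_noise() -> None:
+    ou_noise = OrnsteinUhlenbeckActionNoise(mean=np.zeros(2), sigma=0.2 * np.ones(2))
+    samples = np.stack([ou_noise() for _ in range(5)])  # temporally correlated noise
+    assert samples.shape == (5, 2)
+    ou_noise.reset()  # back to the initial position at episode end
+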
+
+class VectorizedActionNoise(ActionNoise):
+ """
+ A Vectorized action noise for parallel environments.
+
+ :param base_noise: ActionNoise The noise generator to use
+ :param n_envs: The number of parallel environments
+ """
+
+ def __init__(self, base_noise: ActionNoise, n_envs: int):
+ try:
+ self.n_envs = int(n_envs)
+ assert self.n_envs > 0
+ except (TypeError, AssertionError):
+ raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0")
+
+ self.base_noise = base_noise
+ self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)]
+
+ def reset(self, indices: Optional[Iterable[int]] = None) -> None:
+ """
+ Reset all the noise processes, or those listed in indices
+
+ :param indices: Optional[Iterable[int]] The indices to reset. Default: None.
+ If the parameter is None, then all processes are reset to their initial position.
+ """
+ if indices is None:
+ indices = range(len(self.noises))
+
+ for index in indices:
+ self.noises[index].reset()
+
+ def __repr__(self) -> str:
+ return f"VecNoise(BaseNoise={repr(self.base_noise)}), n_envs={len(self.noises)})"
+
+ def __call__(self) -> np.ndarray:
+ """
+ Generate and stack the action noise from each noise object
+ """
+ noise = np.stack([noise() for noise in self.noises])
+ return noise
+
+ @property
+ def base_noise(self) -> ActionNoise:
+ return self._base_noise
+
+ @base_noise.setter
+ def base_noise(self, base_noise: ActionNoise) -> None:
+        if base_noise is None:
+            raise ValueError("Expected base_noise to be an instance of ActionNoise, not None")
+        if not isinstance(base_noise, ActionNoise):
+            raise TypeError("Expected base_noise to be an instance of type ActionNoise")
+ self._base_noise = base_noise
+
+ @property
+ def noises(self) -> List[ActionNoise]:
+ return self._noises
+
+ @noises.setter
+ def noises(self, noises: List[ActionNoise]) -> None:
+ noises = list(noises) # raises TypeError if not iterable
+ assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}."
+
+ different_types = [i for i, noise in enumerate(noises) if not isinstance(noise, type(self.base_noise))]
+
+ if len(different_types):
+            raise ValueError(
+                f"Noise instances at indices {different_types} don't match the type of base_noise {type(self.base_noise)}"
+            )
+
+ self._noises = noises
+ for noise in noises:
+ noise.reset()
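+
+
+# NOTE (editor): a hedged usage sketch; one independent copy of the base
+# process is kept per parallel environment.
+def _example_vectorized_noise() -> None:
+    base_noise = NormalActionNoise(mean=np.zeros(2), sigma=np.ones(2))
+    vec_noise = VectorizedActionNoise(base_noise, n_envs=4)
+    assert vec_noise().shape == (4, 2)
+    vec_noise.reset(indices=[0, 1])  # reset only the first two processes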
diff --git a/dexart-release/stable_baselines3/common/on_policy_algorithm.py b/dexart-release/stable_baselines3/common/on_policy_algorithm.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b870fc2af6b607af01c04b6dad067fc0d67fbe3
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/on_policy_algorithm.py
@@ -0,0 +1,320 @@
+import time
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+import gym
+import numpy as np
+import torch as th
+
+from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
+from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common.policies import ActorCriticPolicy
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
+from stable_baselines3.common.utils import obs_as_tensor, safe_mean
+from stable_baselines3.common.vec_env import VecEnv
+
+
+class OnPolicyAlgorithm(BaseAlgorithm):
+ """
+ The base for On-Policy algorithms (ex: A2C/PPO).
+
+ :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
+ :param env: The environment to learn from (if registered in Gym, can be str)
+ :param learning_rate: The learning rate, it can be a function
+ of the current progress remaining (from 1 to 0)
+ :param n_steps: The number of steps to run for each environment per update
+ (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
+ :param gamma: Discount factor
+ :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
+ Equivalent to classic advantage when set to 1.
+ :param ent_coef: Entropy coefficient for the loss calculation
+ :param vf_coef: Value function coefficient for the loss calculation
+ :param max_grad_norm: The maximum value for the gradient clipping
+ :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
+ instead of action noise exploration (default: False)
+ :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
+ Default: -1 (only sample at the beginning of the rollout)
+ :param tensorboard_log: the log location for tensorboard (if None, no logging)
+ :param create_eval_env: Whether to create a second environment that will be
+ used for evaluating the agent periodically. (Only available when passing string for the environment)
+ :param monitor_wrapper: When creating an environment, whether to wrap it
+ or not in a Monitor wrapper.
+ :param policy_kwargs: additional arguments to be passed to the policy on creation
+ :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
+ :param seed: Seed for the pseudo random generators
+ :param device: Device (cpu, cuda, ...) on which the code should be run.
+ Setting it to auto, the code will be run on the GPU if possible.
+ :param _init_setup_model: Whether or not to build the network at the creation of the instance
+ :param supported_action_spaces: The action spaces supported by the algorithm.
+ """
+
+ def __init__(
+ self,
+ policy: Union[str, Type[ActorCriticPolicy]],
+ env: Union[GymEnv, str],
+ learning_rate: Union[float, Schedule],
+ n_steps: int,
+ gamma: float,
+ gae_lambda: float,
+ ent_coef: float,
+ vf_coef: float,
+ max_grad_norm: float,
+ use_sde: bool,
+ sde_sample_freq: int,
+ tensorboard_log: Optional[str] = None,
+ create_eval_env: bool = False,
+ monitor_wrapper: bool = True,
+ policy_kwargs: Optional[Dict[str, Any]] = None,
+ verbose: int = 0,
+ seed: Optional[int] = None,
+ device: Union[th.device, str] = "auto",
+ _init_setup_model: bool = True,
+ supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
+ ):
+
+ super().__init__(
+ policy=policy,
+ env=env,
+ learning_rate=learning_rate,
+ policy_kwargs=policy_kwargs,
+ verbose=verbose,
+ device=device,
+ use_sde=use_sde,
+ sde_sample_freq=sde_sample_freq,
+ create_eval_env=create_eval_env,
+ support_multi_env=True,
+ seed=seed,
+ tensorboard_log=tensorboard_log,
+ supported_action_spaces=supported_action_spaces,
+ )
+
+ self.n_steps = n_steps
+ self.gamma = gamma
+ self.gae_lambda = gae_lambda
+ self.ent_coef = ent_coef
+ self.vf_coef = vf_coef
+ self.max_grad_norm = max_grad_norm
+ self.rollout_buffer = None
+
+        # Bookkeeping for the policy-restore heuristic: track the mean episode
+        # reward of the last rollout and the two most recent parameter
+        # snapshots so that a large performance drop can be rolled back.
+        self.last_rollout_reward = -np.inf
+        self.need_restore = False
+        self.last_policy_saved: List[Dict] = [{}, {}]
+        self.current_restore_step = 0
+
+ if _init_setup_model:
+ self._setup_model()
+
+ def _setup_model(self) -> None:
+ self._setup_lr_schedule()
+ self.set_random_seed(self.seed)
+
+ buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RolloutBuffer
+
+ self.rollout_buffer = buffer_cls(
+ self.n_steps,
+ self.observation_space,
+ self.action_space,
+ device=self.device,
+ gamma=self.gamma,
+ gae_lambda=self.gae_lambda,
+ n_envs=self.n_envs,
+ )
+ self.policy = self.policy_class( # pytype:disable=not-instantiable
+ self.observation_space,
+ self.action_space,
+ self.lr_schedule,
+ use_sde=self.use_sde,
+ **self.policy_kwargs # pytype:disable=not-instantiable
+ )
+ self.policy = self.policy.to(self.device)
+
+ def collect_rollouts(
+ self,
+ env: VecEnv,
+ callback: BaseCallback,
+ rollout_buffer: RolloutBuffer,
+ n_rollout_steps: int,
+ ) -> bool:
+ """
+ Collect experiences using the current policy and fill a ``RolloutBuffer``.
+ The term rollout here refers to the model-free notion and should not
+ be used with the concept of rollout used in model-based RL or planning.
+
+ :param env: The training environment
+ :param callback: Callback that will be called at each step
+ (and at the beginning and end of the rollout)
+ :param rollout_buffer: Buffer to fill with rollouts
+        :param n_rollout_steps: Number of experiences to collect per environment
+ :return: True if function returned with at least `n_rollout_steps`
+ collected, False if callback terminated rollout prematurely.
+ """
+ assert self._last_obs is not None, "No previous observation was provided"
+ # Switch to eval mode (this affects batch norm / dropout)
+ self.policy.set_training_mode(False)
+ last_episode_reward = self.last_rollout_reward
+ self.last_rollout_reward = 0
+ num_rollouts = 0
+ n_steps = 0
+ rollout_buffer.reset()
+ # Sample new weights for the state dependent exploration
+ if self.use_sde:
+ self.policy.reset_noise(env.num_envs)
+ callback.on_rollout_start()
+ while n_steps < n_rollout_steps:
+ if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
+ # Sample a new noise matrix
+ self.policy.reset_noise(env.num_envs)
+
+ with th.no_grad():
+ # Convert to pytorch tensor or to TensorDict
+ obs_tensor = obs_as_tensor(self._last_obs, self.device)
+ actions, values, log_probs = self.policy(obs_tensor)
+ actions = actions.cpu().numpy()
+
+ # Rescale and perform action
+ clipped_actions = actions
+ # Clip the actions to avoid out of bound error
+ if isinstance(self.action_space, gym.spaces.Box):
+ clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
+
+ new_obs, rewards, dones, infos = env.step(clipped_actions)
+
+ self.num_timesteps += env.num_envs
+
+ # Give access to local variables
+ callback.update_locals(locals())
+ if callback.on_step() is False:
+ return False
+
+ self._update_info_buffer(infos)
+ n_steps += 1
+
+ if isinstance(self.action_space, gym.spaces.Discrete):
+ # Reshape in case of discrete action
+ actions = actions.reshape(-1, 1)
+
+ # Handle timeout by bootstraping with value function
+ # see GitHub issue #633
+ for idx, done in enumerate(dones):
+ if (
+ done
+ and infos[idx].get("terminal_observation") is not None
+ and infos[idx].get("TimeLimit.truncated", False)
+ ):
+ terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
+ with th.no_grad():
+ terminal_value = self.policy.predict_values(terminal_obs)[0]
+ rewards[idx] += self.gamma * terminal_value
+
+ if done:
+ num_rollouts += 1
+
+ self.last_rollout_reward += rewards.sum()
+
+ rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
+ self._last_obs = new_obs
+ self._last_episode_starts = dones
+ with th.no_grad():
+ # Compute value for the last timestep
+ values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))
+
+ rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
+
+        # Mean episode reward over this rollout (guard against rollouts in
+        # which no episode terminated)
+        if num_rollouts > 0:
+            self.last_rollout_reward /= num_rollouts
+        reward_gap = last_episode_reward - self.last_rollout_reward
+        # Note: the restore heuristic is effectively disabled here, since
+        # need_restore is always reset to False (reward_gap is currently unused)
+        self.need_restore = False
+        self.current_restore_step = 0
+
+ callback.on_rollout_end()
+
+ return True
+
+ def train(self) -> None:
+ """
+ Consume current rollout data and update policy parameters.
+ Implemented by individual algorithms.
+ """
+ raise NotImplementedError
+
+ def learn(
+ self,
+ total_timesteps: int,
+ callback: MaybeCallback = None,
+ log_interval: int = 1,
+ eval_env: Optional[GymEnv] = None,
+ eval_freq: int = -1,
+ n_eval_episodes: int = 5,
+ tb_log_name: str = "OnPolicyAlgorithm",
+ eval_log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+        iter_start: int = 0,
+ ) -> "OnPolicyAlgorithm":
+ iteration = iter_start
+
+ total_timesteps, callback = self._setup_learn(
+ total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps,
+ tb_log_name
+ )
+
+ callback.on_training_start(locals(), globals())
+
+ while self.num_timesteps < total_timesteps:
+
+ x = time.time()
+ continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer,
+ n_rollout_steps=self.n_steps)
+ print("Rollout time:", time.time() - x)
+
+ if continue_training is False:
+ break
+
+            if self.need_restore and self.current_restore_step < 5:
+                print("Large performance drop detected. Restoring previous model.")
+                self.set_parameters(self.last_policy_saved[0], exact_match=True, device=self.device)
+                # Cap the number of consecutive restores at 5
+                self.current_restore_step += 1
+                continue
+            else:
+                self.last_policy_saved[0] = self.last_policy_saved[1]
+                self.last_policy_saved[1] = self.get_parameters()
+
+ iteration += 1
+ self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
+
+            # Display training info
+
+ if log_interval is not None and iteration % log_interval == 0:
+ fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time))
+ self.logger.record("time/iterations", iteration, exclude="wandb")
+ self.logger.record("rollout/rollout_rew_mean", self.last_rollout_reward)
+ if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
+ self.logger.record("rollout/ep_rew_mean",
+ safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
+ self.logger.record("rollout/ep_len_mean",
+ safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
+ self.logger.record("time/fps", fps)
+ self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
+ self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
+ self.logger.dump(step=iteration)
+
+ x = time.time()
+ self.train()
+ print("Train time:", time.time() - x)
+
+ callback.on_training_end()
+
+ return self
+
+ def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+ state_dicts = ["policy", "policy.optimizer"]
+
+ return state_dicts, []
+
+ def _excluded_save_params(self) -> List[str]:
+ """
+ Returns the names of the parameters that should be excluded from being
+ saved by pickling. E.g. replay buffers are skipped by default
+ as they take up a lot of space. PyTorch variables should be excluded
+ with this so they can be stored with ``th.save``.
+
+ :return: List of parameters that should be excluded from being saved with pickle.
+ """
+ return super()._excluded_save_params() + ["last_policy_saved"]
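+
+
+# ---------------------------------------------------------------------------
+# Usage sketch (illustrative, not executed on import). OnPolicyAlgorithm is
+# abstract: a concrete subclass such as PPO supplies train(), and learn()
+# then alternates collect_rollouts() and train() until total_timesteps is
+# reached. Assuming the standard SB3-style entry points, a minimal driver
+# would look like:
+#
+#     from stable_baselines3 import PPO
+#
+#     model = PPO("MlpPolicy", "CartPole-v1", n_steps=2048, verbose=1)
+#     model.learn(total_timesteps=100_000)  # rollout -> train -> log, repeated
+#     model.save("ppo_cartpole")
+# ---------------------------------------------------------------------------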
diff --git a/dexart-release/stable_baselines3/common/policies.py b/dexart-release/stable_baselines3/common/policies.py
new file mode 100644
index 0000000000000000000000000000000000000000..051adc2ed19eab9e91a3fac2c3162e2f6699d16e
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/policies.py
@@ -0,0 +1,898 @@
+"""Policies: abstract base class and concrete implementations."""
+
+import collections
+import copy
+import warnings
+from abc import ABC, abstractmethod
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+import gym
+import numpy as np
+import torch as th
+from torch import nn
+
+from stable_baselines3.common.distributions import (
+ BernoulliDistribution,
+ CategoricalDistribution,
+ DiagGaussianDistribution,
+ Distribution,
+ MultiCategoricalDistribution,
+ StateDependentNoiseDistribution,
+ make_proba_distribution,
+)
+from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, maybe_transpose, preprocess_obs
+from stable_baselines3.common.torch_layers import (
+ BaseFeaturesExtractor,
+ CombinedExtractor,
+ FlattenExtractor,
+ MlpExtractor,
+ NatureCNN,
+ create_mlp,
+)
+from stable_baselines3.common.type_aliases import Schedule
+from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
+
+
+class BaseModel(nn.Module, ABC):
+ """
+ The base model object: makes predictions in response to observations.
+
+ In the case of policies, the prediction is an action. In the case of critics, it is the
+ estimated value of the observation.
+
+ :param observation_space: The observation space of the environment
+ :param action_space: The action space of the environment
+ :param features_extractor_class: Features extractor to use.
+ :param features_extractor_kwargs: Keyword arguments
+ to pass to the features extractor.
+ :param features_extractor: Network to extract features
+ (a CNN when using images, a nn.Flatten() layer otherwise)
+ :param normalize_images: Whether to normalize images or not,
+ dividing by 255.0 (True by default)
+ :param optimizer_class: The optimizer to use,
+ ``th.optim.Adam`` by default
+ :param optimizer_kwargs: Additional keyword arguments,
+ excluding the learning rate, to pass to the optimizer
+ """
+
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
+ features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+ features_extractor: Optional[nn.Module] = None,
+ normalize_images: bool = True,
+ optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ super().__init__()
+
+ if optimizer_kwargs is None:
+ optimizer_kwargs = {}
+
+ if features_extractor_kwargs is None:
+ features_extractor_kwargs = {}
+
+ self.observation_space = observation_space
+ self.action_space = action_space
+ self.features_extractor = features_extractor
+ self.normalize_images = normalize_images
+
+ self.optimizer_class = optimizer_class
+ self.optimizer_kwargs = optimizer_kwargs
+ self.optimizer = None # type: Optional[th.optim.Optimizer]
+
+ self.features_extractor_class = features_extractor_class
+ self.features_extractor_kwargs = features_extractor_kwargs
+
+ @abstractmethod
+ def forward(self, *args, **kwargs):
+ pass
+
+ def _update_features_extractor(
+ self,
+ net_kwargs: Dict[str, Any],
+ features_extractor: Optional[BaseFeaturesExtractor] = None,
+ ) -> Dict[str, Any]:
+ """
+ Update the network keyword arguments and create a new features extractor object if needed.
+ If a ``features_extractor`` object is passed, then it will be shared.
+
+ :param net_kwargs: the base network keyword arguments, without the ones
+ related to features extractor
+ :param features_extractor: a features extractor object.
+ If None, a new object will be created.
+ :return: The updated keyword arguments
+ """
+ net_kwargs = net_kwargs.copy()
+ if features_extractor is None:
+ # The features extractor is not shared, create a new one
+ features_extractor = self.make_features_extractor()
+ net_kwargs.update(dict(features_extractor=features_extractor, features_dim=features_extractor.features_dim))
+ return net_kwargs
+
+ def make_features_extractor(self) -> BaseFeaturesExtractor:
+ """Helper method to create a features extractor."""
+ return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
+
+ def extract_features(self, obs: th.Tensor) -> th.Tensor:
+ """
+ Preprocess the observation if needed and extract features.
+
+ :param obs:
+ :return:
+ """
+ assert self.features_extractor is not None, "No features extractor was set"
+ preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
+ return self.features_extractor(preprocessed_obs)
+
+ def _get_constructor_parameters(self) -> Dict[str, Any]:
+ """
+ Get data that need to be saved in order to re-create the model when loading it from disk.
+
+        :return: The dictionary to pass as kwargs to the constructor when reconstructing this model.
+ """
+ return dict(
+ observation_space=self.observation_space,
+ action_space=self.action_space,
+ # Passed to the constructor by child class
+ # squash_output=self.squash_output,
+ # features_extractor=self.features_extractor
+ normalize_images=self.normalize_images,
+ )
+
+ @property
+ def device(self) -> th.device:
+ """Infer which device this policy lives on by inspecting its parameters.
+ If it has no parameters, the 'cpu' device is used as a fallback.
+
+        :return:
+        """
+ for param in self.parameters():
+ return param.device
+ return get_device("cpu")
+
+ def save(self, path: str) -> None:
+ """
+ Save model to a given location.
+
+ :param path:
+ """
+ th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)
+
+ @classmethod
+ def load(cls, path: str, device: Union[th.device, str] = "auto") -> "BaseModel":
+ """
+ Load model from path.
+
+ :param path:
+ :param device: Device on which the policy should be loaded.
+ :return:
+ """
+ device = get_device(device)
+ saved_variables = th.load(path, map_location=device)
+
+ # Allow to load policy saved with older version of SB3
+ if "sde_net_arch" in saved_variables["data"]:
+ warnings.warn(
+ "sde_net_arch is deprecated, please downgrade to SB3 v1.2.0 if you need such parameter.",
+ DeprecationWarning,
+ )
+ del saved_variables["data"]["sde_net_arch"]
+
+ # Create policy object
+ model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable
+ # Load weights
+ model.load_state_dict(saved_variables["state_dict"])
+ model.to(device)
+ return model
+
+ def load_from_vector(self, vector: np.ndarray) -> None:
+ """
+ Load parameters from a 1D vector.
+
+ :param vector:
+ """
+ th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters())
+
+ def parameters_to_vector(self) -> np.ndarray:
+ """
+ Convert the parameters to a 1D vector.
+
+ :return:
+ """
+ return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()
+
+ def set_training_mode(self, mode: bool) -> None:
+ """
+ Put the policy in either training or evaluation mode.
+
+ This affects certain modules, such as batch normalisation and dropout.
+
+ :param mode: if true, set to training mode, else set to evaluation mode
+ """
+ self.train(mode)
+
+ def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[th.Tensor, bool]:
+ """
+ Convert an input observation to a PyTorch tensor that can be fed to a model.
+ Includes sugar-coating to handle different observations (e.g. normalizing images).
+
+ :param observation: the input observation
+ :return: The observation as PyTorch tensor
+ and whether the observation is vectorized or not
+ """
+ vectorized_env = False
+ if isinstance(observation, dict):
+ # need to copy the dict as the dict in VecFrameStack will become a torch tensor
+ observation = copy.deepcopy(observation)
+ for key, obs in observation.items():
+                if key not in self.observation_space.spaces:
+                    # Skip keys that are not part of the observation space
+                    continue
+ obs_space = self.observation_space.spaces[key]
+ if is_image_space(obs_space):
+ obs_ = maybe_transpose(obs, obs_space)
+ else:
+ obs_ = np.array(obs)
+ vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space)
+ # Add batch dimension if needed
+ observation[key] = obs_.reshape((-1,) + self.observation_space[key].shape)
+
+ elif is_image_space(self.observation_space):
+ # Handle the different cases for images
+ # as PyTorch use channel first format
+ observation = maybe_transpose(observation, self.observation_space)
+
+ else:
+ observation = np.array(observation)
+
+ if not isinstance(observation, dict):
+ # Dict obs need to be handled separately
+ vectorized_env = is_vectorized_observation(observation, self.observation_space)
+ # Add batch dimension if needed
+ observation = observation.reshape((-1,) + self.observation_space.shape)
+
+ observation = obs_as_tensor(observation, self.device)
+ return observation, vectorized_env
+
+
+class BasePolicy(BaseModel):
+ """The base policy object.
+
+ Parameters are mostly the same as `BaseModel`; additions are documented below.
+
+ :param args: positional arguments passed through to `BaseModel`.
+ :param kwargs: keyword arguments passed through to `BaseModel`.
+ :param squash_output: For continuous actions, whether the output is squashed
+ or not using a ``tanh()`` function.
+ """
+
+ def __init__(self, *args, squash_output: bool = False, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._squash_output = squash_output
+
+ @staticmethod
+ def _dummy_schedule(progress_remaining: float) -> float:
+ """(float) Useful for pickling policy."""
+ del progress_remaining
+ return 0.0
+
+ @property
+ def squash_output(self) -> bool:
+ """(bool) Getter for squash_output."""
+ return self._squash_output
+
+ @staticmethod
+ def init_weights(module: nn.Module, gain: float = 1) -> None:
+ """
+ Orthogonal initialization (used in PPO and A2C)
+ """
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ nn.init.orthogonal_(module.weight, gain=gain)
+ if module.bias is not None:
+ module.bias.data.fill_(0.0)
+
+ @abstractmethod
+ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+ """
+ Get the action according to the policy for a given observation.
+
+        This is an abstract method: every concrete policy must implement it.
+
+ :param observation:
+ :param deterministic: Whether to use stochastic or deterministic actions
+ :return: Taken action according to the policy
+ """
+
+ def predict(
+ self,
+ observation: Union[np.ndarray, Dict[str, np.ndarray]],
+ state: Optional[Tuple[np.ndarray, ...]] = None,
+ episode_start: Optional[np.ndarray] = None,
+ deterministic: bool = False,
+ ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+ """
+ Get the policy action from an observation (and optional hidden state).
+ Includes sugar-coating to handle different observations (e.g. normalizing images).
+
+ :param observation: the input observation
+ :param state: The last hidden states (can be None, used in recurrent policies)
+ :param episode_start: The last masks (can be None, used in recurrent policies)
+            this corresponds to the beginning of episodes,
+ where the hidden states of the RNN must be reset.
+ :param deterministic: Whether or not to return deterministic actions.
+ :return: the model's action and the next hidden state
+ (used in recurrent policies)
+ """
+ # TODO (GH/1): add support for RNN policies
+ # if state is None:
+ # state = self.initial_state
+ # if episode_start is None:
+ # episode_start = [False for _ in range(self.n_envs)]
+ # Switch to eval mode (this affects batch norm / dropout)
+ self.set_training_mode(False)
+
+ observation, vectorized_env = self.obs_to_tensor(observation)
+
+ with th.no_grad():
+ actions = self._predict(observation, deterministic=deterministic)
+ # Convert to numpy
+ actions = actions.cpu().numpy()
+
+ if isinstance(self.action_space, gym.spaces.Box):
+ if self.squash_output:
+ # Rescale to proper domain when using squashing
+ actions = self.unscale_action(actions)
+ else:
+ # Actions could be on arbitrary scale, so clip the actions to avoid
+ # out of bound error (e.g. if sampling from a Gaussian distribution)
+ actions = np.clip(actions, self.action_space.low, self.action_space.high)
+
+ # Remove batch dimension if needed
+ if not vectorized_env:
+ actions = actions[0]
+
+ return actions, state
+
+ def scale_action(self, action: np.ndarray) -> np.ndarray:
+ """
+ Rescale the action from [low, high] to [-1, 1]
+ (no need for symmetric action space)
+
+ :param action: Action to scale
+ :return: Scaled action
+ """
+ low, high = self.action_space.low, self.action_space.high
+ return 2.0 * ((action - low) / (high - low)) - 1.0
+
+ def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
+ """
+ Rescale the action from [-1, 1] to [low, high]
+ (no need for symmetric action space)
+
+        :param scaled_action: Action to un-scale
+        :return: The un-scaled action
+        """
+ low, high = self.action_space.low, self.action_space.high
+ return low + (0.5 * (scaled_action + 1.0) * (high - low))
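+
+    # Round-trip sketch for the two helpers above (illustrative comment only).
+    # For a Box action space with low=-2 and high=2:
+    #
+    #     policy.scale_action(np.array([1.0]))    # -> array([0.5])
+    #     policy.unscale_action(np.array([0.5]))  # -> array([1.0])
+    #
+    # scale_action and unscale_action are exact inverses on [low, high].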
+
+
+class ActorCriticPolicy(BasePolicy):
+ """
+ Policy class for actor-critic algorithms (has both policy and value prediction).
+ Used by A2C, PPO and the likes.
+
+ :param observation_space: Observation space
+ :param action_space: Action space
+ :param lr_schedule: Learning rate schedule (could be constant)
+ :param net_arch: The specification of the policy and value networks.
+ :param activation_fn: Activation function
+ :param ortho_init: Whether to use or not orthogonal initialization
+ :param use_sde: Whether to use State Dependent Exploration or not
+ :param log_std_init: Initial value for the log standard deviation
+ :param full_std: Whether to use (n_features x n_actions) parameters
+ for the std instead of only (n_features,) when using gSDE
+ :param sde_net_arch: Network architecture for extracting features
+ when using gSDE. If None, the latent features from the policy will be used.
+ Pass an empty list to use the states as features.
+ :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
+ a positive standard deviation (cf paper). It allows to keep variance
+ above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
+ :param squash_output: Whether to squash the output using a tanh function,
+ this allows to ensure boundaries when using gSDE.
+ :param features_extractor_class: Features extractor to use.
+ :param features_extractor_kwargs: Keyword arguments
+ to pass to the features extractor.
+ :param normalize_images: Whether to normalize images or not,
+ dividing by 255.0 (True by default)
+ :param optimizer_class: The optimizer to use,
+ ``th.optim.Adam`` by default
+ :param optimizer_kwargs: Additional keyword arguments,
+ excluding the learning rate, to pass to the optimizer
+ """
+
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ lr_schedule: Schedule,
+ net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
+ activation_fn: Type[nn.Module] = nn.Tanh,
+ ortho_init: bool = True,
+ use_sde: bool = False,
+ log_std_init: float = 0.0,
+ full_std: bool = True,
+ sde_net_arch: Optional[List[int]] = None,
+ use_expln: bool = False,
+ squash_output: bool = False,
+ features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
+ features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+ normalize_images: bool = True,
+ optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+
+ if optimizer_kwargs is None:
+ optimizer_kwargs = {}
+ # Small values to avoid NaN in Adam optimizer
+ if optimizer_class == th.optim.Adam:
+ optimizer_kwargs["eps"] = 1e-5
+
+ super().__init__(
+ observation_space,
+ action_space,
+ features_extractor_class,
+ features_extractor_kwargs,
+ optimizer_class=optimizer_class,
+ optimizer_kwargs=optimizer_kwargs,
+ squash_output=squash_output,
+ )
+
+ # Default network architecture, from stable-baselines
+ if net_arch is None:
+ if features_extractor_class == NatureCNN:
+ net_arch = []
+ else:
+ net_arch = [dict(pi=[64, 64], vf=[64, 64])]
+
+ self.net_arch = net_arch
+ self.activation_fn = activation_fn
+ self.ortho_init = ortho_init
+
+ self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
+ self.features_dim = self.features_extractor.features_dim
+
+ self.normalize_images = normalize_images
+ self.log_std_init = log_std_init
+ dist_kwargs = None
+ # Keyword arguments for gSDE distribution
+ if use_sde:
+ dist_kwargs = {
+ "full_std": full_std,
+ "squash_output": squash_output,
+ "use_expln": use_expln,
+ "learn_features": False,
+ }
+
+ if sde_net_arch is not None:
+ warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
+
+ self.use_sde = use_sde
+ self.dist_kwargs = dist_kwargs
+
+ # Action distribution
+ self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
+
+ self._build(lr_schedule)
+
+ def _get_constructor_parameters(self) -> Dict[str, Any]:
+ data = super()._get_constructor_parameters()
+
+ default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)
+
+ data.update(
+ dict(
+ net_arch=self.net_arch,
+ activation_fn=self.activation_fn,
+ use_sde=self.use_sde,
+ log_std_init=self.log_std_init,
+ squash_output=default_none_kwargs["squash_output"],
+ full_std=default_none_kwargs["full_std"],
+ use_expln=default_none_kwargs["use_expln"],
+ lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
+ ortho_init=self.ortho_init,
+ optimizer_class=self.optimizer_class,
+ optimizer_kwargs=self.optimizer_kwargs,
+ features_extractor_class=self.features_extractor_class,
+ features_extractor_kwargs=self.features_extractor_kwargs,
+ )
+ )
+ return data
+
+ def reset_noise(self, n_envs: int = 1) -> None:
+ """
+ Sample new weights for the exploration matrix.
+
+ :param n_envs:
+ """
+ assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE"
+ self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
+
+ def _build_mlp_extractor(self) -> None:
+ """
+ Create the policy and value networks.
+ Part of the layers can be shared.
+ """
+ # Note: If net_arch is None and some features extractor is used,
+ # net_arch here is an empty list and mlp_extractor does not
+ # really contain any layers (acts like an identity module).
+ self.mlp_extractor = MlpExtractor(
+ self.features_dim,
+ net_arch=self.net_arch,
+ activation_fn=self.activation_fn,
+ device=self.device,
+ )
+
+ def _build(self, lr_schedule: Schedule) -> None:
+ """
+ Create the networks and the optimizer.
+
+ :param lr_schedule: Learning rate schedule
+ lr_schedule(1) is the initial learning rate
+ """
+ self._build_mlp_extractor()
+
+ latent_dim_pi = self.mlp_extractor.latent_dim_pi
+
+ if isinstance(self.action_dist, DiagGaussianDistribution):
+ self.action_net, self.log_std = self.action_dist.proba_distribution_net(
+ latent_dim=latent_dim_pi, log_std_init=self.log_std_init
+ )
+ elif isinstance(self.action_dist, StateDependentNoiseDistribution):
+ self.action_net, self.log_std = self.action_dist.proba_distribution_net(
+ latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init
+ )
+ elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)):
+ self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
+ else:
+ raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.")
+
+ self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
+ # Init weights: use orthogonal initialization
+ # with small initial weight for the output
+        if self.ortho_init:
+ # TODO: check for features_extractor
+ # Values from stable-baselines.
+ # features_extractor/mlp values are
+ # originally from openai/baselines (default gains/init_scales).
+ module_gains = {
+ self.features_extractor: np.sqrt(2),
+ self.mlp_extractor: np.sqrt(2),
+ self.action_net: 0.01,
+ self.value_net: 1,
+ }
+ for module, gain in module_gains.items():
+ module.apply(partial(self.init_weights, gain=gain))
+
+        # Setup optimizer with initial learning rate. Only parameters that
+        # require gradients are passed, so frozen feature extractors are
+        # excluded from the update.
+        self.optimizer = self.optimizer_class(
+            filter(lambda p: p.requires_grad, self.parameters()),
+            lr=lr_schedule(1),
+            **self.optimizer_kwargs,
+        )
+
+ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+ """
+ Forward pass in all the networks (actor and critic)
+
+ :param obs: Observation
+ :param deterministic: Whether to sample or use deterministic actions
+ :return: action, value and log probability of the action
+ """
+ # Preprocess the observation if needed
+ features = self.extract_features(obs)
+ latent_pi, latent_vf = self.mlp_extractor(features)
+ # Evaluate the values for the given observations
+ values = self.value_net(latent_vf)
+ distribution = self._get_action_dist_from_latent(latent_pi)
+ actions = distribution.get_actions(deterministic=deterministic)
+ log_prob = distribution.log_prob(actions)
+ return actions, values, log_prob
+
+ def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
+ """
+ Retrieve action distribution given the latent codes.
+
+ :param latent_pi: Latent code for the actor
+ :return: Action distribution
+ """
+ mean_actions = self.action_net(latent_pi)
+
+ if isinstance(self.action_dist, DiagGaussianDistribution):
+ return self.action_dist.proba_distribution(mean_actions, self.log_std)
+ elif isinstance(self.action_dist, CategoricalDistribution):
+ # Here mean_actions are the logits before the softmax
+ return self.action_dist.proba_distribution(action_logits=mean_actions)
+ elif isinstance(self.action_dist, MultiCategoricalDistribution):
+ # Here mean_actions are the flattened logits
+ return self.action_dist.proba_distribution(action_logits=mean_actions)
+ elif isinstance(self.action_dist, BernoulliDistribution):
+ # Here mean_actions are the logits (before rounding to get the binary actions)
+ return self.action_dist.proba_distribution(action_logits=mean_actions)
+ elif isinstance(self.action_dist, StateDependentNoiseDistribution):
+ return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
+ else:
+ raise ValueError("Invalid action distribution")
+
+ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+ """
+ Get the action according to the policy for a given observation.
+
+ :param observation:
+ :param deterministic: Whether to use stochastic or deterministic actions
+ :return: Taken action according to the policy
+ """
+ return self.get_distribution(observation).get_actions(deterministic=deterministic)
+
+ def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+ """
+ Evaluate actions according to the current policy,
+ given the observations.
+
+ :param obs:
+ :param actions:
+ :return: estimated value, log likelihood of taking those actions
+ and entropy of the action distribution.
+ """
+ # Preprocess the observation if needed
+ features = self.extract_features(obs)
+ latent_pi, latent_vf = self.mlp_extractor(features)
+ distribution = self._get_action_dist_from_latent(latent_pi)
+ log_prob = distribution.log_prob(actions)
+ values = self.value_net(latent_vf)
+ return values, log_prob, distribution.entropy()
+
+ def get_distribution(self, obs: th.Tensor) -> Distribution:
+ """
+ Get the current policy distribution given the observations.
+
+ :param obs:
+ :return: the action distribution.
+ """
+ features = self.extract_features(obs)
+ latent_pi = self.mlp_extractor.forward_actor(features)
+ return self._get_action_dist_from_latent(latent_pi)
+
+ def predict_values(self, obs: th.Tensor) -> th.Tensor:
+ """
+ Get the estimated values according to the current policy given the observations.
+
+ :param obs:
+ :return: the estimated values.
+ """
+ features = self.extract_features(obs)
+ latent_vf = self.mlp_extractor.forward_critic(features)
+ return self.value_net(latent_vf)
+
+
+class ActorCriticCnnPolicy(ActorCriticPolicy):
+ """
+ CNN policy class for actor-critic algorithms (has both policy and value prediction).
+ Used by A2C, PPO and the likes.
+
+ :param observation_space: Observation space
+ :param action_space: Action space
+ :param lr_schedule: Learning rate schedule (could be constant)
+ :param net_arch: The specification of the policy and value networks.
+ :param activation_fn: Activation function
+ :param ortho_init: Whether to use or not orthogonal initialization
+ :param use_sde: Whether to use State Dependent Exploration or not
+ :param log_std_init: Initial value for the log standard deviation
+ :param full_std: Whether to use (n_features x n_actions) parameters
+ for the std instead of only (n_features,) when using gSDE
+ :param sde_net_arch: Network architecture for extracting features
+ when using gSDE. If None, the latent features from the policy will be used.
+ Pass an empty list to use the states as features.
+ :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
+ a positive standard deviation (cf paper). It allows to keep variance
+ above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
+ :param squash_output: Whether to squash the output using a tanh function,
+ this allows to ensure boundaries when using gSDE.
+ :param features_extractor_class: Features extractor to use.
+ :param features_extractor_kwargs: Keyword arguments
+ to pass to the features extractor.
+ :param normalize_images: Whether to normalize images or not,
+ dividing by 255.0 (True by default)
+ :param optimizer_class: The optimizer to use,
+ ``th.optim.Adam`` by default
+ :param optimizer_kwargs: Additional keyword arguments,
+ excluding the learning rate, to pass to the optimizer
+ """
+
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ lr_schedule: Schedule,
+ net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
+ activation_fn: Type[nn.Module] = nn.Tanh,
+ ortho_init: bool = True,
+ use_sde: bool = False,
+ log_std_init: float = 0.0,
+ full_std: bool = True,
+ sde_net_arch: Optional[List[int]] = None,
+ use_expln: bool = False,
+ squash_output: bool = False,
+ features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
+ features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+ normalize_images: bool = True,
+ optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ super().__init__(
+ observation_space,
+ action_space,
+ lr_schedule,
+ net_arch,
+ activation_fn,
+ ortho_init,
+ use_sde,
+ log_std_init,
+ full_std,
+ sde_net_arch,
+ use_expln,
+ squash_output,
+ features_extractor_class,
+ features_extractor_kwargs,
+ normalize_images,
+ optimizer_class,
+ optimizer_kwargs,
+ )
+
+
+class MultiInputActorCriticPolicy(ActorCriticPolicy):
+ """
+    Multi-input policy class for actor-critic algorithms (has both policy and value prediction).
+ Used by A2C, PPO and the likes.
+
+    :param observation_space: Observation space (Dict)
+ :param action_space: Action space
+ :param lr_schedule: Learning rate schedule (could be constant)
+ :param net_arch: The specification of the policy and value networks.
+ :param activation_fn: Activation function
+ :param ortho_init: Whether to use or not orthogonal initialization
+ :param use_sde: Whether to use State Dependent Exploration or not
+ :param log_std_init: Initial value for the log standard deviation
+ :param full_std: Whether to use (n_features x n_actions) parameters
+ for the std instead of only (n_features,) when using gSDE
+ :param sde_net_arch: Network architecture for extracting features
+ when using gSDE. If None, the latent features from the policy will be used.
+ Pass an empty list to use the states as features.
+ :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
+ a positive standard deviation (cf paper). It allows to keep variance
+ above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
+ :param squash_output: Whether to squash the output using a tanh function,
+ this allows to ensure boundaries when using gSDE.
+    :param features_extractor_class: Features extractor to use
+        (``CombinedExtractor`` by default)
+    :param features_extractor_kwargs: Keyword arguments
+        to pass to the features extractor.
+ :param normalize_images: Whether to normalize images or not,
+ dividing by 255.0 (True by default)
+ :param optimizer_class: The optimizer to use,
+ ``th.optim.Adam`` by default
+ :param optimizer_kwargs: Additional keyword arguments,
+ excluding the learning rate, to pass to the optimizer
+ """
+
+ def __init__(
+ self,
+ observation_space: gym.spaces.Dict,
+ action_space: gym.spaces.Space,
+ lr_schedule: Schedule,
+ net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
+ activation_fn: Type[nn.Module] = nn.Tanh,
+ ortho_init: bool = True,
+ use_sde: bool = False,
+ log_std_init: float = 0.0,
+ full_std: bool = True,
+ sde_net_arch: Optional[List[int]] = None,
+ use_expln: bool = False,
+ squash_output: bool = False,
+ features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
+ features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+ normalize_images: bool = True,
+ optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ super().__init__(
+ observation_space,
+ action_space,
+ lr_schedule,
+ net_arch,
+ activation_fn,
+ ortho_init,
+ use_sde,
+ log_std_init,
+ full_std,
+ sde_net_arch,
+ use_expln,
+ squash_output,
+ features_extractor_class,
+ features_extractor_kwargs,
+ normalize_images,
+ optimizer_class,
+ optimizer_kwargs,
+ )
+
+
+class ContinuousCritic(BaseModel):
+ """
+ Critic network(s) for DDPG/SAC/TD3.
+ It represents the action-state value function (Q-value function).
+ Compared to A2C/PPO critics, this one represents the Q-value
+ and takes the continuous action as input. It is concatenated with the state
+ and then fed to the network which outputs a single value: Q(s, a).
+ For more recent algorithms like SAC/TD3, multiple networks
+ are created to give different estimates.
+
+ By default, it creates two critic networks used to reduce overestimation
+ thanks to clipped Q-learning (cf TD3 paper).
+
+    :param observation_space: Observation space
+ :param action_space: Action space
+ :param net_arch: Network architecture
+ :param features_extractor: Network to extract features
+ (a CNN when using images, a nn.Flatten() layer otherwise)
+ :param features_dim: Number of features
+ :param activation_fn: Activation function
+ :param normalize_images: Whether to normalize images or not,
+ dividing by 255.0 (True by default)
+ :param n_critics: Number of critic networks to create.
+ :param share_features_extractor: Whether the features extractor is shared or not
+ between the actor and the critic (this saves computation time)
+ """
+
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ net_arch: List[int],
+ features_extractor: nn.Module,
+ features_dim: int,
+ activation_fn: Type[nn.Module] = nn.ReLU,
+ normalize_images: bool = True,
+ n_critics: int = 2,
+ share_features_extractor: bool = True,
+ ):
+ super().__init__(
+ observation_space,
+ action_space,
+ features_extractor=features_extractor,
+ normalize_images=normalize_images,
+ )
+
+ action_dim = get_action_dim(self.action_space)
+
+ self.share_features_extractor = share_features_extractor
+ self.n_critics = n_critics
+ self.q_networks = []
+ for idx in range(n_critics):
+ q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
+ q_net = nn.Sequential(*q_net)
+ self.add_module(f"qf{idx}", q_net)
+ self.q_networks.append(q_net)
+
+ def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
+ # Learn the features extractor using the policy loss only
+ # when the features_extractor is shared with the actor
+ with th.set_grad_enabled(not self.share_features_extractor):
+ features = self.extract_features(obs)
+ qvalue_input = th.cat([features, actions], dim=1)
+ return tuple(q_net(qvalue_input) for q_net in self.q_networks)
+
+ def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
+ """
+ Only predict the Q-value using the first network.
+ This allows to reduce computation when all the estimates are not needed
+ (e.g. when updating the policy in TD3).
+ """
+ with th.no_grad():
+ features = self.extract_features(obs)
+ return self.q_networks[0](th.cat([features, actions], dim=1))
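+
+
+# Minimal smoke test (a hedged sketch, not part of the SB3 API): build an
+# ActorCriticPolicy on toy Box spaces and run one forward pass. The constant
+# lambda stands in for a real learning-rate schedule.
+if __name__ == "__main__":
+    obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
+    act_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
+    policy = ActorCriticPolicy(obs_space, act_space, lr_schedule=lambda _: 3e-4)
+    obs = th.as_tensor(obs_space.sample()).unsqueeze(0)  # add a batch dimension
+    acts, values, log_prob = policy(obs)
+    print(acts.shape, values.shape, log_prob.shape)  # (1, 2), (1, 1), (1,)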
diff --git a/dexart-release/stable_baselines3/common/preprocessing.py b/dexart-release/stable_baselines3/common/preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2a2f9466fc26ce3b288bc2fac275aeb3d6b336c
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/preprocessing.py
@@ -0,0 +1,217 @@
+import warnings
+from typing import Dict, Tuple, Union
+
+import numpy as np
+import torch as th
+from gym import spaces
+from torch.nn import functional as F
+
+
+def is_image_space_channels_first(observation_space: spaces.Box) -> bool:
+ """
+ Check if an image observation space (see ``is_image_space``)
+ is channels-first (CxHxW, True) or channels-last (HxWxC, False).
+
+ Use a heuristic that channel dimension is the smallest of the three.
+ If second dimension is smallest, raise an exception (no support).
+
+ :param observation_space:
+ :return: True if observation space is channels-first image, False if channels-last.
+ """
+ smallest_dimension = np.argmin(observation_space.shape).item()
+ if smallest_dimension == 1:
+ warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.")
+ return smallest_dimension == 0
+
+
+def is_image_space(
+ observation_space: spaces.Space,
+ check_channels: bool = False,
+) -> bool:
+ """
+    Check if an observation space has the shape, limits and dtype
+ of a valid image.
+ The check is conservative, so that it returns False if there is a doubt.
+
+ Valid images: RGB, RGBD, GrayScale with values in [0, 255]
+
+ :param observation_space:
+ :param check_channels: Whether to do or not the check for the number of channels.
+ e.g., with frame-stacking, the observation space may have more channels than expected.
+ :return:
+ """
+ if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3:
+ # Check the type
+ if observation_space.dtype != np.uint8:
+ return False
+
+ # Check the value range
+ if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
+ return False
+
+ # Skip channels check
+ if not check_channels:
+ return True
+ # Check the number of channels
+ if is_image_space_channels_first(observation_space):
+ n_channels = observation_space.shape[0]
+ else:
+ n_channels = observation_space.shape[-1]
+ # RGB, RGBD, GrayScale
+ return n_channels in [1, 3, 4]
+ return False
+
+
+def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> np.ndarray:
+ """
+ Handle the different cases for images as PyTorch use channel first format.
+
+ :param observation:
+ :param observation_space:
+ :return: channel first observation if observation is an image
+ """
+ # Avoid circular import
+ from stable_baselines3.common.vec_env import VecTransposeImage
+
+ if is_image_space(observation_space):
+ if not (observation.shape == observation_space.shape or observation.shape[1:] == observation_space.shape):
+ # Try to re-order the channels
+ transpose_obs = VecTransposeImage.transpose_image(observation)
+ if transpose_obs.shape == observation_space.shape or transpose_obs.shape[1:] == observation_space.shape:
+ observation = transpose_obs
+ return observation
+
+
+def preprocess_obs(
+ obs: th.Tensor,
+ observation_space: spaces.Space,
+ normalize_images: bool = True
+) -> Union[th.Tensor, Dict[str, th.Tensor]]:
+ """
+    Preprocess observation to be fed to a neural network.
+    For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]).
+    For discrete observations, it creates a one-hot vector.
+
+ :param obs: Observation
+ :param observation_space:
+ :param normalize_images: Whether to normalize images or not
+ (True by default)
+ :return:
+ """
+ if isinstance(observation_space, spaces.Box):
+ if is_image_space(observation_space) and normalize_images:
+ return obs.float() / 255.0
+ return obs.float()
+
+ elif isinstance(observation_space, spaces.Discrete):
+ # One hot encoding and convert to float to avoid errors
+ return F.one_hot(obs.long(), num_classes=observation_space.n).float()
+
+ elif isinstance(observation_space, spaces.MultiDiscrete):
+ # Tensor concatenation of one hot encodings of each Categorical sub-space
+ return th.cat(
+ [
+ F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float()
+ for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1))
+ ],
+ dim=-1,
+ ).view(obs.shape[0], sum(observation_space.nvec))
+
+ elif isinstance(observation_space, spaces.MultiBinary):
+ return obs.float()
+
+ elif isinstance(observation_space, spaces.Dict):
+ # Do not modify by reference the original observation
+ preprocessed_obs = {}
+ for key, _obs in obs.items():
+        if key in observation_space.spaces:
+ preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
+ return preprocessed_obs
+
+ else:
+ raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")
+
+
+def get_obs_shape(
+ observation_space: spaces.Space,
+) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]:
+ """
+ Get the shape of the observation (useful for the buffers).
+
+ :param observation_space:
+ :return:
+ """
+ if isinstance(observation_space, spaces.Box):
+ return observation_space.shape
+ elif isinstance(observation_space, spaces.Discrete):
+ # Observation is an int
+ return (1,)
+ elif isinstance(observation_space, spaces.MultiDiscrete):
+ # Number of discrete features
+ return (int(len(observation_space.nvec)),)
+ elif isinstance(observation_space, spaces.MultiBinary):
+ # Number of binary features
+ return (int(observation_space.n),)
+ elif isinstance(observation_space, spaces.Dict):
+ return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()}
+
+ else:
+ raise NotImplementedError(f"{observation_space} observation space is not supported")
+
+
+def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
+ """
+ Get the dimension of the observation space when flattened.
+ It does not apply to image observation space.
+
+ Used by the ``FlattenExtractor`` to compute the input shape.
+
+ :param observation_space:
+ :return:
+ """
+ # See issue https://github.com/openai/gym/issues/1915
+ # it may be a problem for Dict/Tuple spaces too...
+ if isinstance(observation_space, spaces.MultiDiscrete):
+ return sum(observation_space.nvec)
+ else:
+ # Use Gym internal method
+ return spaces.utils.flatdim(observation_space)
+
+
+def get_action_dim(action_space: spaces.Space) -> int:
+ """
+ Get the dimension of the action space.
+
+ :param action_space:
+ :return:
+ """
+ if isinstance(action_space, spaces.Box):
+ return int(np.prod(action_space.shape))
+ elif isinstance(action_space, spaces.Discrete):
+ # Action is an int
+ return 1
+ elif isinstance(action_space, spaces.MultiDiscrete):
+ # Number of discrete actions
+ return int(len(action_space.nvec))
+ elif isinstance(action_space, spaces.MultiBinary):
+ # Number of binary actions
+ return int(action_space.n)
+ else:
+ raise NotImplementedError(f"{action_space} action space is not supported")
+
+
+def check_for_nested_spaces(obs_space: spaces.Space):
+ """
+ Make sure the observation space does not have nested spaces (Dicts/Tuples inside Dicts/Tuples).
+ If so, raise an Exception informing that there is no support for this.
+
+ :param obs_space: an observation space
+ :return:
+ """
+ if isinstance(obs_space, (spaces.Dict, spaces.Tuple)):
+ sub_spaces = obs_space.spaces.values() if isinstance(obs_space, spaces.Dict) else obs_space.spaces
+ for sub_space in sub_spaces:
+ if isinstance(sub_space, (spaces.Dict, spaces.Tuple)):
+ raise NotImplementedError(
+ "Nested observation spaces are not supported (Tuple/Dict space inside Tuple/Dict space)."
+ )
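+
+
+# Quick illustration (sketch): a Discrete(4) observation with value 2 becomes
+# a one-hot float vector, and get_obs_shape() reports the shape the rollout
+# buffers will use for it.
+if __name__ == "__main__":
+    space = spaces.Discrete(4)
+    obs = th.as_tensor([2])
+    print(preprocess_obs(obs, space))  # tensor([[0., 0., 1., 0.]])
+    print(get_obs_shape(space))  # (1,)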
diff --git a/dexart-release/stable_baselines3/common/running_mean_std.py b/dexart-release/stable_baselines3/common/running_mean_std.py
new file mode 100644
index 0000000000000000000000000000000000000000..b48f9223c9a1a6814ab5db0fe416e30522374797
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/running_mean_std.py
@@ -0,0 +1,57 @@
+from typing import Tuple, Union
+
+import numpy as np
+
+
+class RunningMeanStd:
+ def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
+ """
+        Calculates the running mean and std of a data stream
+ https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+
+ :param epsilon: helps with arithmetic issues
+ :param shape: the shape of the data stream's output
+ """
+ self.mean = np.zeros(shape, np.float64)
+ self.var = np.ones(shape, np.float64)
+ self.count = epsilon
+
+ def copy(self) -> "RunningMeanStd":
+ """
+ :return: Return a copy of the current object.
+ """
+ new_object = RunningMeanStd(shape=self.mean.shape)
+ new_object.mean = self.mean.copy()
+ new_object.var = self.var.copy()
+ new_object.count = float(self.count)
+ return new_object
+
+ def combine(self, other: "RunningMeanStd") -> None:
+ """
+ Combine stats from another ``RunningMeanStd`` object.
+
+ :param other: The other object to combine with.
+ """
+ self.update_from_moments(other.mean, other.var, other.count)
+
+ def update(self, arr: np.ndarray) -> None:
+ batch_mean = np.mean(arr, axis=0)
+ batch_var = np.var(arr, axis=0)
+ batch_count = arr.shape[0]
+ self.update_from_moments(batch_mean, batch_var, batch_count)
+
+    def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: Union[int, float]) -> None:
+        delta = batch_mean - self.mean
+        tot_count = self.count + batch_count
+
+        new_mean = self.mean + delta * batch_count / tot_count
+        m_a = self.var * self.count
+        m_b = batch_var * batch_count
+        m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / tot_count
+        new_var = m_2 / tot_count
+
+        self.mean = new_mean
+        self.var = new_var
+        self.count = tot_count
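+
+
+# Example (illustrative): streaming updates over two batches match the
+# statistics of their concatenation, which is exactly what the parallel
+# update formulas above guarantee (up to the epsilon pseudo-count).
+if __name__ == "__main__":
+    rms = RunningMeanStd(shape=())
+    a = np.random.randn(100)
+    b = np.random.randn(100) + 5.0
+    rms.update(a)
+    rms.update(b)
+    full = np.concatenate([a, b])
+    print(rms.mean, full.mean())  # nearly identical
+    print(rms.var, full.var())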
diff --git a/dexart-release/stable_baselines3/common/save_util.py b/dexart-release/stable_baselines3/common/save_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3e3acd0b84d9427ab1f4cf5ffe0b58ea3885f02
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/save_util.py
@@ -0,0 +1,447 @@
+"""
+Save util taken from stable_baselines
+used to serialize data (class parameters) of model classes
+"""
+import base64
+import functools
+import io
+import json
+import os
+import pathlib
+import pickle
+import warnings
+import zipfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import cloudpickle
+import torch as th
+
+import stable_baselines3 as sb3
+from stable_baselines3.common.type_aliases import TensorDict
+from stable_baselines3.common.utils import get_device, get_system_info
+
+
+def recursive_getattr(obj: Any, attr: str, *args) -> Any:
+ """
+ Recursive version of getattr
+ taken from https://stackoverflow.com/questions/31174295
+
+ Ex:
+ > MyObject.sub_object = SubObject(name='test')
+    > recursive_getattr(MyObject, 'sub_object.name')  # returns 'test'
+ :param obj:
+ :param attr: Attribute to retrieve
+ :return: The attribute
+ """
+
+ def _getattr(obj: Any, attr: str) -> Any:
+ return getattr(obj, attr, *args)
+
+ return functools.reduce(_getattr, [obj] + attr.split("."))
+
+
+def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
+ """
+ Recursive version of setattr
+ taken from https://stackoverflow.com/questions/31174295
+
+ Ex:
+ > MyObject.sub_object = SubObject(name='test')
+ > recursive_setattr(MyObject, 'sub_object.name', 'hello')
+ :param obj:
+ :param attr: Attribute to set
+ :param val: New value of the attribute
+ """
+ pre, _, post = attr.rpartition(".")
+ return setattr(recursive_getattr(obj, pre) if pre else obj, post, val)
+
+
+def is_json_serializable(item: Any) -> bool:
+ """
+ Test if an object is serializable into JSON
+
+ :param item: The object to be tested for JSON serialization.
+ :return: True if object is JSON serializable, false otherwise.
+ """
+    # Probe with a try-except block.
+ json_serializable = True
+ try:
+ _ = json.dumps(item)
+ except TypeError:
+ json_serializable = False
+ return json_serializable
+
+
+def data_to_json(data: Dict[str, Any]) -> str:
+ """
+ Turn data (class parameters) into a JSON string for storing
+
+ :param data: Dictionary of class parameters to be
+ stored. Items that are not JSON serializable will be
+ pickled with Cloudpickle and stored as bytearray in
+ the JSON file
+ :return: JSON string of the data serialized.
+ """
+ # First, check what elements can not be JSONfied,
+ # and turn them into byte-strings
+ serializable_data = {}
+ for data_key, data_item in data.items():
+ # See if object is JSON serializable
+ if is_json_serializable(data_item):
+ # All good, store as it is
+ serializable_data[data_key] = data_item
+ else:
+ # Not serializable, cloudpickle it into
+ # bytes and convert to base64 string for storing.
+ # Also store type of the class for consumption
+ # from other languages/humans, so we have an
+ # idea what was being stored.
+ base64_encoded = base64.b64encode(cloudpickle.dumps(data_item)).decode()
+
+ # Use ":" to make sure we do
+ # not override these keys
+ # when we include variables of the object later
+ cloudpickle_serialization = {
+ ":type:": str(type(data_item)),
+ ":serialized:": base64_encoded,
+ }
+
+ # Add first-level JSON-serializable items of the
+ # object for further details (but not deeper than this to
+ # avoid deep nesting).
+ # First we check that object has attributes (not all do,
+ # e.g. numpy scalars)
+ if hasattr(data_item, "__dict__") or isinstance(data_item, dict):
+ # Take elements from __dict__ for custom classes
+ item_generator = data_item.items if isinstance(data_item, dict) else data_item.__dict__.items
+ for variable_name, variable_item in item_generator():
+ # Check if serializable. If not, just include the
+ # string-representation of the object.
+ if is_json_serializable(variable_item):
+ cloudpickle_serialization[variable_name] = variable_item
+ else:
+ cloudpickle_serialization[variable_name] = str(variable_item)
+
+ serializable_data[data_key] = cloudpickle_serialization
+ json_string = json.dumps(serializable_data, indent=4)
+ return json_string
+
+
+def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+ """
+ Turn JSON serialization of class-parameters back into dictionary.
+
+ :param json_string: JSON serialization of the class-parameters
+ that should be loaded.
+ :param custom_objects: Dictionary of objects to replace
+ upon loading. If a variable is present in this dictionary as a
+ key, it will not be deserialized and the corresponding item
+ will be used instead. Similar to custom_objects in
+ ``keras.models.load_model``. Useful when you have an object in
+ file that can not be deserialized.
+ :return: Loaded class parameters.
+ """
+ if custom_objects is not None and not isinstance(custom_objects, dict):
+ raise ValueError("custom_objects argument must be a dict or None")
+
+ json_dict = json.loads(json_string)
+ # This will be filled with deserialized data
+ return_data = {}
+ for data_key, data_item in json_dict.items():
+ if custom_objects is not None and data_key in custom_objects.keys():
+ # If item is provided in custom_objects, replace
+ # the one from JSON with the one in custom_objects
+ return_data[data_key] = custom_objects[data_key]
+ elif isinstance(data_item, dict) and ":serialized:" in data_item.keys():
+ # If item is dictionary with ":serialized:"
+ # key, this means it is serialized with cloudpickle.
+ serialization = data_item[":serialized:"]
+ # Try-except deserialization in case we run into
+ # errors. If so, we can tell bit more information to
+ # user.
+            try:
+                base64_object = base64.b64decode(serialization.encode())
+                deserialized_object = cloudpickle.loads(base64_object)
+            except (RuntimeError, TypeError):
+                warnings.warn(
+                    f"Could not deserialize object {data_key}. "
+                    + "Consider using `custom_objects` argument to replace "
+                    + "this object."
+                )
+            else:
+                # Only store the result if deserialization succeeded; otherwise
+                # `deserialized_object` would be unbound here.
+                return_data[data_key] = deserialized_object
+ else:
+ # Read as it is
+ return_data[data_key] = data_item
+ return return_data
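+
+# Editor's illustration (assumed usage, not part of the library): ``json_to_data``
+# round-trips the output of ``data_to_json``, and ``custom_objects`` overrides an
+# entry instead of deserializing it.
+#
+# >>> json_to_data(data_to_json({"gamma": 0.99}))
+# {'gamma': 0.99}
+# >>> json_to_data(data_to_json({"gamma": 0.99}), custom_objects={"gamma": 1.0})
+# {'gamma': 1.0}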
+
+
+@functools.singledispatch
+def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose: int = 0, suffix: Optional[str] = None):
+ """
+ Opens a path for reading or writing with a preferred suffix and emits debug information.
+ If the provided path is a derivative of io.BufferedIOBase it ensures that the file
+ matches the provided mode, i.e. if the mode is read ("r", "read") it checks that the path is readable.
+ If the mode is write ("w", "write") it checks that the file is writable.
+
+ If the provided path is a string or a pathlib.Path, it ensures that it exists. If the mode is "read",
+ it checks that the path exists; if it doesn't, it attempts to read path.suffix if a suffix is provided.
+ If the mode is "write" and the path does not exist, it creates all the parent folders. If the path
+ points to a folder, it changes the path to path_2. If the path already exists and verbose == 2,
+ it emits a warning.
+
+ :param path: the path to open.
+ if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
+ path actually exists. If path is a io.BufferedIOBase the path exists.
+ :param mode: how to open the file. "w"|"write" for writing, "r"|"read" for reading.
+ :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
+ :param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
+ If mode is "r" then we attempt to open the path. If an error is raised and the suffix
+ is not None, we attempt to open the path with the suffix.
+ :return:
+ """
+ if not isinstance(path, io.BufferedIOBase):
+ raise TypeError("Path parameter has invalid type.", io.BufferedIOBase)
+ if path.closed:
+ raise ValueError("File stream is closed.")
+ mode = mode.lower()
+ try:
+ mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode]
+ except KeyError:
+ raise ValueError("Expected mode to be either 'w' or 'r'.")
+ if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable():
+ e1 = "writable" if "w" == mode else "readable"
+ raise ValueError(f"Expected a {e1} file.")
+ return path
+
+
+@open_path.register(str)
+def open_path_str(path: str, mode: str, verbose: int = 0, suffix: Optional[str] = None) -> io.BufferedIOBase:
+ """
+ Open a path given by a string. If writing to the path, the function ensures
+ that the path exists.
+
+ :param path: the path to open. If mode is "w" then it ensures that the path exists
+ by creating the necessary folders and renaming path if it points to a folder.
+ :param mode: how to open the file. "w" for writing, "r" for reading.
+ :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
+ :param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
+ If mode is "r" then we attempt to open the path. If an error is raised and the suffix
+ is not None, we attempt to open the path with the suffix.
+ :return:
+ """
+ return open_path(pathlib.Path(path), mode, verbose, suffix)
+
+
+@open_path.register(pathlib.Path)
+def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: Optional[str] = None) -> io.BufferedIOBase:
+ """
+ Open a path given by a pathlib.Path object. If writing to the path, the function ensures
+ that the path exists.
+
+ :param path: the path to check. If mode is "w" then it
+ ensures that the path exists by creating the necessary folders and
+ renaming path if it points to a folder.
+ :param mode: how to open the file. "w" for writing, "r" for reading.
+ :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
+ :param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
+ If mode is "r" then we attempt to open the path. If an error is raised and the suffix
+ is not None, we attempt to open the path with the suffix.
+ :return:
+ """
+ if mode not in ("w", "r"):
+ raise ValueError("Expected mode to be either 'w' or 'r'.")
+
+ if mode == "r":
+ try:
+ path = path.open("rb")
+ except FileNotFoundError as error:
+ if suffix is not None and suffix != "":
+ newpath = pathlib.Path(f"{path}.{suffix}")
+ if verbose == 2:
+ warnings.warn(f"Path '{path}' not found. Attempting {newpath}.")
+ path, suffix = newpath, None
+ else:
+ raise error
+ else:
+ try:
+ if path.suffix == "" and suffix is not None and suffix != "":
+ path = pathlib.Path(f"{path}.{suffix}")
+ if path.exists() and path.is_file() and verbose == 2:
+ warnings.warn(f"Path '{path}' exists, will overwrite it.")
+ path = path.open("wb")
+ except IsADirectoryError:
+ warnings.warn(f"Path '{path}' is a folder. Will save instead to {path}_2")
+ path = pathlib.Path(f"{path}_2")
+ except FileNotFoundError: # Occurs when the parent folder doesn't exist
+ warnings.warn(f"Path '{path.parent}' does not exist. Will create it.")
+ path.parent.mkdir(exist_ok=True, parents=True)
+
+ # if opening was successful uses the identity function
+ # if opening failed with IsADirectory|FileNotFound, calls open_path_pathlib
+ # with corrections
+ # if reading failed with FileNotFoundError, calls open_path_pathlib with suffix
+
+ return open_path(path, mode, verbose, suffix)
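+
+# Editor's usage sketch (hypothetical path, assuming a writable working directory,
+# not part of the library): when writing, the preferred suffix is appended and
+# missing parent folders are created on the retry triggered by FileNotFoundError.
+#
+# >>> stream = open_path("checkpoints/model", "w", suffix="zip")  # opens checkpoints/model.zip
+# >>> stream.close()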
+
+
+def save_to_zip_file(
+ save_path: Union[str, pathlib.Path, io.BufferedIOBase],
+ data: Optional[Dict[str, Any]] = None,
+ params: Optional[Dict[str, Any]] = None,
+ pytorch_variables: Optional[Dict[str, Any]] = None,
+ verbose: int = 0,
+) -> None:
+ """
+ Save model data to a zip archive.
+
+ :param save_path: Where to store the model.
+ if save_path is a str or pathlib.Path ensures that the path actually exists.
+ :param data: Class parameters being stored (non-PyTorch variables)
+ :param params: Model parameters being stored expected to contain an entry for every
+ state_dict with its name and the state_dict.
+ :param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
+ :param verbose: Verbosity level, 0 means only warnings, 2 means debug information
+ """
+ save_path = open_path(save_path, "w", verbose=0, suffix="zip")
+ # data/params can be None, so do not
+ # try to serialize them blindly
+ if data is not None:
+ serialized_data = data_to_json(data)
+
+ # Create a zip-archive and write our objects there.
+ with zipfile.ZipFile(save_path, mode="w") as archive:
+ # Do not try to save "None" elements
+ if data is not None:
+ archive.writestr("data", serialized_data)
+ if pytorch_variables is not None:
+ with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
+ th.save(pytorch_variables, pytorch_variables_file)
+ if params is not None:
+ for file_name, dict_ in params.items():
+ with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
+ th.save(dict_, param_file)
+ # Save metadata: library version when file was saved
+ archive.writestr("_stable_baselines3_version", sb3.__version__)
+ # Save system info about the current python env
+ # archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
+
+
+def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None:
+ """
+ Save an object to path creating the necessary folders along the way.
+ If the path exists and is a directory, it will raise a warning and rename the path.
+ If a suffix is provided in the path, it will use that suffix, otherwise, it will use '.pkl'.
+
+ :param path: the path to open.
+ if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
+ path actually exists. If path is a io.BufferedIOBase the path exists.
+ :param obj: The object to save.
+ :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
+ """
+ with open_path(path, "w", verbose=verbose, suffix="pkl") as file_handler:
+ # Use protocol>=4 to support saving replay buffers >= 4Gb
+ # See https://docs.python.org/3/library/pickle.html
+ pickle.dump(obj, file_handler, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: int = 0) -> Any:
+ """
+ Load an object from the path. If a suffix is provided in the path, it will use that suffix.
+ If the path does not exist, it will attempt to load using the .pkl suffix.
+
+ :param path: the path to open.
+ if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
+ path actually exists. If path is a io.BufferedIOBase the path exists.
+ :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
+ """
+ with open_path(path, "r", verbose=verbose, suffix="pkl") as file_handler:
+ return pickle.load(file_handler)
+
+
+def load_from_zip_file(
+ load_path: Union[str, pathlib.Path, io.BufferedIOBase],
+ load_data: bool = True,
+ custom_objects: Optional[Dict[str, Any]] = None,
+ device: Union[th.device, str] = "auto",
+ verbose: int = 0,
+ print_system_info: bool = False,
+) -> Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]:
+ """
+ Load model data from a .zip archive
+
+ :param load_path: Where to load the model from
+ :param load_data: Whether we should load and return data
+ (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
+ :param custom_objects: Dictionary of objects to replace
+ upon loading. If a variable is present in this dictionary as a
+ key, it will not be deserialized and the corresponding item
+ will be used instead. Similar to custom_objects in
+ ``keras.models.load_model``. Useful when you have an object in
+ file that can not be deserialized.
+ :param device: Device on which the code should run.
+ :param verbose: Verbosity level, 0 means only warnings, 2 means debug information.
+ :param print_system_info: Whether to print or not the system info
+ about the saved model.
+ :return: Class parameters, model state_dicts (aka "params", dict of state_dict)
+ and dict of pytorch variables
+ """
+ load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")
+
+ # set device to cpu if cuda is not available
+ device = get_device(device=device)
+
+ # Open the zip archive and load data
+ try:
+ with zipfile.ZipFile(load_path) as archive:
+ namelist = archive.namelist()
+ # If data or parameters is not in the
+ # zip archive, assume they were stored
+ # as None (_save_to_file_zip allows this).
+ data = None
+ pytorch_variables = None
+ params = {}
+
+ # Debug system info first
+ if print_system_info:
+ if "system_info.txt" in namelist:
+ print("== SAVED MODEL SYSTEM INFO ==")
+ print(archive.read("system_info.txt").decode())
+ else:
+ warnings.warn(
+ "The model was saved with SB3 <= 1.2.0 and thus cannot print system information.",
+ UserWarning,
+ )
+
+ if "data" in namelist and load_data:
+ # Load class parameters that are stored
+ # with either JSON or pickle (not PyTorch variables).
+ json_data = archive.read("data").decode()
+ data = json_to_data(json_data, custom_objects=custom_objects)
+
+ # Check for all .pth files and load them using th.load.
+ # "pytorch_variables.pth" stores PyTorch variables, and any other .pth
+ # files store state_dicts of variables with custom names (e.g. policy, policy.optimizer)
+ pth_files = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth"]
+ for file_path in pth_files:
+ with archive.open(file_path, mode="r") as param_file:
+ # File has to be seekable, but param_file is not, so load in BytesIO first
+ # fixed in python >= 3.7
+ file_content = io.BytesIO()
+ file_content.write(param_file.read())
+ # go to start of file
+ file_content.seek(0)
+ # Load the parameters with the right ``map_location``.
+ # Remove ".pth" ending with splitext
+ th_object = th.load(file_content, map_location=device)
+ # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138
+ if file_path == "pytorch_variables.pth" or file_path == "tensors.pth":
+ # PyTorch variables (not state_dicts)
+ pytorch_variables = th_object
+ else:
+ # State dicts. Store into params dictionary
+ # with same name as in .zip file (without .pth)
+ params[os.path.splitext(file_path)[0]] = th_object
+ except zipfile.BadZipFile:
+ # load_path wasn't a zip file
+ raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
+ return data, params, pytorch_variables
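+
+# Editor's round-trip sketch (hypothetical file name, not part of the library):
+#
+# >>> save_to_zip_file("demo_model", data={"gamma": 0.99})  # writes demo_model.zip
+# >>> data, params, pytorch_variables = load_from_zip_file("demo_model")
+# >>> data["gamma"]
+# 0.99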
diff --git a/dexart-release/stable_baselines3/common/torch_layers.py b/dexart-release/stable_baselines3/common/torch_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c558ef63e3def98cdadd1938f0a93a857be96e44
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/torch_layers.py
@@ -0,0 +1,401 @@
+from itertools import zip_longest
+from typing import Dict, List, Tuple, Type, Union, Optional
+
+import gym
+import torch
+import torch as th
+from torch import nn
+
+from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
+from stable_baselines3.common.type_aliases import TensorDict
+from stable_baselines3.common.utils import get_device
+
+
+class BaseFeaturesExtractor(nn.Module):
+ """
+ Base class that represents a features extractor.
+
+ :param observation_space:
+ :param features_dim: Number of features extracted.
+ """
+
+ def __init__(self, observation_space: gym.Space, features_dim: int = 0):
+ super().__init__()
+ assert features_dim > 0
+ self._observation_space = observation_space
+ self._features_dim = features_dim
+
+ @property
+ def features_dim(self) -> int:
+ return self._features_dim
+
+ def forward(self, observations: th.Tensor) -> th.Tensor:
+ raise NotImplementedError()
+
+
+class FlattenExtractor(BaseFeaturesExtractor):
+ """
+ Feature extractor that flattens the input.
+ Used as a placeholder when feature extraction is not needed.
+
+ :param observation_space:
+ """
+
+ def __init__(self, observation_space: gym.Space):
+ super().__init__(observation_space, get_flattened_obs_dim(observation_space))
+ self.flatten = nn.Flatten()
+
+ def forward(self, observations: th.Tensor) -> th.Tensor:
+ return self.flatten(observations)
+
+
+class NatureCNN(BaseFeaturesExtractor):
+ """
+ CNN from DQN nature paper:
+ Mnih, Volodymyr, et al.
+ "Human-level control through deep reinforcement learning."
+ Nature 518.7540 (2015): 529-533.
+
+ :param observation_space:
+ :param features_dim: Number of features extracted.
+ This corresponds to the number of units of the last layer.
+ """
+
+ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
+ super().__init__(observation_space, features_dim)
+ # We assume CxHxW images (channels first)
+ # Re-ordering will be done by preprocessing or a wrapper
+ assert is_image_space(observation_space, check_channels=False), (
+ "You should use NatureCNN "
+ f"only with images not with {observation_space}\n"
+ "(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n"
+ "If you are using a custom environment,\n"
+ "please check it using our env checker:\n"
+ "https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html"
+ )
+ n_input_channels = observation_space.shape[0]
+ self.cnn = nn.Sequential(
+ nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
+ nn.ReLU(),
+ nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
+ nn.ReLU(),
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
+ nn.ReLU(),
+ nn.Flatten(),
+ )
+
+ # Compute shape by doing one forward pass
+ with th.no_grad():
+ n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]
+
+ self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
+
+ def forward(self, observations: th.Tensor) -> th.Tensor:
+ return self.linear(self.cnn(observations))
+
+
+def create_mlp(
+ input_dim: int,
+ output_dim: int,
+ net_arch: List[int],
+ activation_fn: Type[nn.Module] = nn.ReLU,
+ squash_output: bool = False,
+) -> List[nn.Module]:
+ """
+ Create a multi layer perceptron (MLP), which is
+ a collection of fully-connected layers each followed by an activation function.
+
+ :param input_dim: Dimension of the input vector
+ :param output_dim:
+ :param net_arch: Architecture of the neural net
+ It represents the number of units per layer.
+ The length of this list is the number of layers.
+ :param activation_fn: The activation function
+ to use after each layer.
+ :param squash_output: Whether to squash the output using a Tanh
+ activation function
+ :return:
+ """
+
+ if len(net_arch) > 0:
+ modules = [nn.Linear(input_dim, net_arch[0]), activation_fn()]
+ else:
+ modules = []
+
+ for idx in range(len(net_arch) - 1):
+ modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1]))
+ modules.append(activation_fn())
+
+ if output_dim > 0:
+ last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim
+ modules.append(nn.Linear(last_layer_dim, output_dim))
+ if squash_output:
+ modules.append(nn.Tanh())
+ return modules
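+
+# Editor's doctest sketch (not part of the library): a single hidden layer of 8
+# units between a 4-dim input and a 2-dim output.
+#
+# >>> modules = create_mlp(input_dim=4, output_dim=2, net_arch=[8])
+# >>> [type(m).__name__ for m in modules]
+# ['Linear', 'ReLU', 'Linear']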
+
+
+class MlpExtractor(nn.Module):
+ """
+ Constructs an MLP that receives the output from a previous feature extractor (i.e. a CNN) or directly
+ the observations (if no feature extractor is applied) as an input and outputs a latent representation
+ for the policy and a value network.
+ The ``net_arch`` parameter allows you to specify the number and size of the hidden layers and how many
+ of them are shared between the policy network and the value network. It is assumed to be a list with the following
+ structure:
+
+ 1. An arbitrary length (zero allowed) number of integers each specifying the number of units in a shared layer.
+ If the number of ints is zero, there will be no shared layers.
+ 2. An optional dict, to specify the following non-shared layers for the value network and the policy network.
+ It is formatted like ``dict(vf=[], pi=[])``.
+ If any of the keys (pi or vf) is missing, an empty list of non-shared layers is assumed.
+
+ For example, to construct a network with one shared layer of size 55 followed by two non-shared layers of size 255
+ for the value network and a single non-shared layer of size 128 for the policy network, the following ``net_arch``
+ would be used: ``[55, dict(vf=[255, 255], pi=[128])]``. A simple shared network topology with two layers of size 128
+ would be specified as ``[128, 128]``.
+
+ Adapted from Stable Baselines.
+
+ :param feature_dim: Dimension of the feature vector (can be the output of a CNN)
+ :param net_arch: The specification of the policy and value networks.
+ See above for details on its formatting.
+ :param activation_fn: The activation function to use for the networks.
+ :param device:
+ """
+
+ def __init__(
+ self,
+ feature_dim: int,
+ net_arch: List[Union[int, Dict[str, List[int]]]],
+ activation_fn: Type[nn.Module],
+ device: Union[th.device, str] = "auto",
+ ):
+ super().__init__()
+ device = get_device(device)
+ shared_net, policy_net, value_net = [], [], []
+ policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network
+ value_only_layers = [] # Layer sizes of the network that only belongs to the value network
+ last_layer_dim_shared = feature_dim
+
+ # Iterate through the shared layers and build the shared parts of the network
+ for layer in net_arch:
+ if isinstance(layer, int): # Check that this is a shared layer
+ # TODO: give layer a meaningful name
+ shared_net.append(nn.Linear(last_layer_dim_shared, layer)) # add linear of size layer
+ shared_net.append(activation_fn())
+ last_layer_dim_shared = layer
+ else:
+ assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts"
+ if "pi" in layer:
+ assert isinstance(layer["pi"], list), "Error: net_arch[-1]['pi'] must contain a list of integers."
+ policy_only_layers = layer["pi"]
+
+ if "vf" in layer:
+ assert isinstance(layer["vf"], list), "Error: net_arch[-1]['vf'] must contain a list of integers."
+ value_only_layers = layer["vf"]
+ break # From here on the network splits up in policy and value network
+
+ last_layer_dim_pi = last_layer_dim_shared
+ last_layer_dim_vf = last_layer_dim_shared
+
+ # Build the non-shared part of the network
+ for pi_layer_size, vf_layer_size in zip_longest(policy_only_layers, value_only_layers):
+ if pi_layer_size is not None:
+ assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
+ policy_net.append(nn.Linear(last_layer_dim_pi, pi_layer_size))
+ policy_net.append(activation_fn())
+ last_layer_dim_pi = pi_layer_size
+
+ if vf_layer_size is not None:
+ assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers."
+ value_net.append(nn.Linear(last_layer_dim_vf, vf_layer_size))
+ value_net.append(activation_fn())
+ last_layer_dim_vf = vf_layer_size
+
+ # Save dim, used to create the distributions
+ self.latent_dim_pi = last_layer_dim_pi
+ self.latent_dim_vf = last_layer_dim_vf
+
+ # Create networks
+ # If the list of layers is empty, the network will just act as an Identity module
+ self.shared_net = nn.Sequential(*shared_net).to(device)
+ self.policy_net = nn.Sequential(*policy_net).to(device)
+ self.value_net = nn.Sequential(*value_net).to(device)
+
+ def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+ """
+ :return: latent_policy, latent_value of the specified network.
+ If all layers are shared, then ``latent_policy == latent_value``
+ """
+ shared_latent = self.shared_net(features)
+ return self.policy_net(shared_latent), self.value_net(shared_latent)
+
+ def forward_actor(self, features: th.Tensor) -> th.Tensor:
+ return self.policy_net(self.shared_net(features))
+
+ def forward_critic(self, features: th.Tensor) -> th.Tensor:
+ return self.value_net(self.shared_net(features))
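+
+# Editor's usage sketch (assumed shapes, not part of the library): one shared
+# layer of 16 units, then separate policy and value heads of 32 units each.
+#
+# >>> extractor = MlpExtractor(8, [16, dict(pi=[32], vf=[32])], nn.ReLU, device="cpu")
+# >>> latent_pi, latent_vf = extractor(th.zeros(1, 8))
+# >>> latent_pi.shape, latent_vf.shape
+# (torch.Size([1, 32]), torch.Size([1, 32]))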
+
+
+class CombinedExtractor(BaseFeaturesExtractor):
+ """
+ Combined feature extractor for Dict observation spaces.
+ Builds a feature extractor for each key of the space. Input from each space
+ is fed through a separate submodule (CNN or MLP, depending on input shape),
+ the output features are concatenated and fed through additional MLP network ("combined").
+
+ :param observation_space:
+ :param cnn_output_dim: Number of features to output from each CNN submodule(s). Defaults to
+ 256 to avoid exploding network sizes.
+ """
+
+ def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256):
+ # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
+ super().__init__(observation_space, features_dim=1)
+
+ extractors = {}
+
+ total_concat_size = 0
+ for key, subspace in observation_space.spaces.items():
+ if is_image_space(subspace):
+ extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim)
+ total_concat_size += cnn_output_dim
+ else:
+ # The observation key is a vector, flatten it if needed
+ extractors[key] = nn.Flatten()
+ total_concat_size += get_flattened_obs_dim(subspace)
+
+ self.extractors = nn.ModuleDict(extractors)
+
+ # Update the features dim manually
+ self._features_dim = total_concat_size
+
+ def forward(self, observations: TensorDict) -> th.Tensor:
+ encoded_tensor_list = []
+
+ for key, extractor in self.extractors.items():
+ encoded_tensor_list.append(extractor(observations[key]))
+ return th.cat(encoded_tensor_list, dim=1)
+
+
+def get_actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> Tuple[List[int], List[int]]:
+ """
+ Get the actor and critic network architectures for off-policy actor-critic algorithms (SAC, TD3, DDPG).
+
+ The ``net_arch`` parameter allows to specify the amount and size of the hidden layers,
+ which can be different for the actor and the critic.
+ It is assumed to be a list of ints or a dict.
+
+ 1. If it is a list, actor and critic networks will have the same architecture.
+ The architecture is represented by a list of integers (of arbitrary length (zero allowed))
+ each specifying the number of units per layer.
+ If the number of ints is zero, the network will be linear.
+ 2. If it is a dict, it should have the following structure:
+ ``dict(qf=[], pi=[])``.
+ where the network architecture is a list as described in 1.
+
+ For example, to have actor and critic that share the same network architecture,
+ you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256 units each).
+
+ If you want a different architecture for the actor and the critic,
+ then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``.
+
+ .. note::
+ Compared to their on-policy counterparts, no shared layers (other than the features extractor)
+ between the actor and the critic are allowed (to prevent issues with target networks).
+
+ :param net_arch: The specification of the actor and critic networks.
+ See above for details on its formatting.
+ :return: The network architectures for the actor and the critic
+ """
+ if isinstance(net_arch, list):
+ actor_arch, critic_arch = net_arch, net_arch
+ else:
+ assert isinstance(net_arch, dict), "Error: the net_arch can only be a list of ints or a dict"
+ assert "pi" in net_arch, "Error: no key 'pi' was provided in net_arch for the actor network"
+ assert "qf" in net_arch, "Error: no key 'qf' was provided in net_arch for the critic network"
+ actor_arch, critic_arch = net_arch["pi"], net_arch["qf"]
+ return actor_arch, critic_arch
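+
+# Editor's doctest sketch (not part of the library):
+#
+# >>> get_actor_critic_arch([256, 256])
+# ([256, 256], [256, 256])
+# >>> get_actor_critic_arch(dict(qf=[400, 300], pi=[64, 64]))
+# ([64, 64], [400, 300])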
+
+
+class PointNetImaginationExtractorGP(BaseFeaturesExtractor):
+ def __init__(self, observation_space: gym.spaces.Dict, pc_key: str, feat_key: Optional[str] = None,
+ out_channel=256, extractor_name="smallpn",
+ gt_key: Optional[str] = None, imagination_keys=("imagination_robot",), state_key="state",
+ state_mlp_size=(64, 64), state_mlp_activation_fn=nn.ReLU, **kwargs):
+ self.imagination_key = imagination_keys
+ # Init state representation
+ self.use_state = state_key is not None
+ self.state_key = state_key
+
+ print(f"extractor use state = {self.use_state}")
+ if self.use_state:
+ if state_key not in observation_space.spaces.keys():
+ raise RuntimeError(f"State key {state_key} not in observation space: {observation_space}")
+ self.state_space = observation_space[self.state_key]
+ if feat_key is not None:
+ if feat_key not in list(observation_space.keys()):
+ raise RuntimeError(f"Feature key {feat_key} not in observation space.")
+ if pc_key not in list(observation_space.keys()):
+ raise RuntimeError(f"Point cloud key {pc_key} not in observation space.")
+
+ super().__init__(observation_space, out_channel)
+ # Point cloud input should have size (n, 3), spec size (n, 3), feat size (n, m)
+ self.pc_key = pc_key
+ self.has_feat = feat_key is not None
+ self.feat_key = feat_key
+ self.gt_key = gt_key
+
+ if extractor_name == "smallpn":
+ from stable_baselines3.networks.pretrain_nets import PointNet
+ self.extractor = PointNet()
+ elif extractor_name == "mediumpn":
+ from stable_baselines3.networks.pretrain_nets import PointNetMedium
+ self.extractor = PointNetMedium()
+ elif extractor_name == "largepn":
+ from stable_baselines3.networks.pretrain_nets import PointNetLarge
+ self.extractor = PointNetLarge()
+ else:
+ raise NotImplementedError(f"Extractor {extractor_name} not implemented. Available:\
+ smallpn, mediumpn, largepn")
+
+ # self.n_input_channels = n_input_channels
+ self.n_output_channels = out_channel
+ assert self.n_output_channels == 256
+
+ if self.use_state:
+ self.state_dim = self.state_space.shape[0]
+ if len(state_mlp_size) == 0:
+ raise RuntimeError(f"State mlp size is empty")
+ elif len(state_mlp_size) == 1:
+ net_arch = []
+ else:
+ net_arch = state_mlp_size[:-1]
+ output_dim = state_mlp_size[-1]
+
+ self.n_output_channels = out_channel + output_dim
+ self._features_dim = self.n_output_channels
+ self.state_mlp = nn.Sequential(*create_mlp(self.state_dim, output_dim, net_arch, state_mlp_activation_fn))
+
+ def forward(self, observations: TensorDict) -> th.Tensor:
+ # Extract the xyz coordinates from the raw point cloud observation
+ points = observations[self.pc_key][..., :3] # B * N * 3
+
+ b, _, _ = points.shape
+ if len(self.imagination_key) > 0:
+ for key in self.imagination_key:
+ obs = observations[key]
+ if len(obs.shape) == 2:
+ obs = obs.unsqueeze(0)
+ img_points = obs[:, :, :3]
+ points = torch.concat([points, img_points], dim=1)
+
+ # points: B * (N + sum(Ni)) * 3
+ pn_feat = self.extractor(points) # B * 256
+ if self.use_state:
+ state_feat = self.state_mlp(observations[self.state_key])
+ return torch.cat([pn_feat, state_feat], dim=-1)
+ else:
+ return pn_feat
+
diff --git a/dexart-release/stable_baselines3/common/type_aliases.py b/dexart-release/stable_baselines3/common/type_aliases.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f110d6d24f56845b4dc92a420a33bc54ccb7e3a
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/type_aliases.py
@@ -0,0 +1,82 @@
+"""Common aliases for type hints"""
+
+from enum import Enum
+from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union
+
+import gym
+import numpy as np
+import torch as th
+
+from stable_baselines3.common import callbacks, vec_env
+
+GymEnv = Union[gym.Env, vec_env.VecEnv]
+GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
+GymStepReturn = Tuple[GymObs, float, bool, Dict]
+TensorDict = Dict[Union[str, int], th.Tensor]
+OptimizerStateDict = Dict[str, Any]
+MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback]
+
+# A schedule takes the remaining progress as input
+# and outputs a scalar (e.g. learning rate, clip range, ...)
+Schedule = Callable[[float], float]
+
+
+class RolloutBufferSamples(NamedTuple):
+ observations: th.Tensor
+ actions: th.Tensor
+ old_values: th.Tensor
+ old_log_prob: th.Tensor
+ advantages: th.Tensor
+ returns: th.Tensor
+
+
+class DictRolloutBufferSamples(RolloutBufferSamples):
+ observations: TensorDict
+ actions: th.Tensor
+ old_values: th.Tensor
+ old_log_prob: th.Tensor
+ advantages: th.Tensor
+ returns: th.Tensor
+
+
+class DictSSLRolloutBufferSamples(NamedTuple):
+ observations: TensorDict
+ next_observations: List[TensorDict]
+ actions: th.Tensor
+ next_actions: List[th.Tensor]
+ old_values: th.Tensor
+ old_log_prob: th.Tensor
+ advantages: th.Tensor
+ returns: th.Tensor
+
+
+class ReplayBufferSamples(NamedTuple):
+ observations: th.Tensor
+ actions: th.Tensor
+ next_observations: th.Tensor
+ dones: th.Tensor
+ rewards: th.Tensor
+
+
+class DictReplayBufferSamples(ReplayBufferSamples):
+ observations: TensorDict
+ actions: th.Tensor
+ next_observations: th.Tensor
+ dones: th.Tensor
+ rewards: th.Tensor
+
+
+class RolloutReturn(NamedTuple):
+ episode_timesteps: int
+ n_episodes: int
+ continue_training: bool
+
+
+class TrainFrequencyUnit(Enum):
+ STEP = "step"
+ EPISODE = "episode"
+
+
+class TrainFreq(NamedTuple):
+ frequency: int
+ unit: TrainFrequencyUnit # either "step" or "episode"
diff --git a/dexart-release/stable_baselines3/common/utils.py b/dexart-release/stable_baselines3/common/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e73ffe4efdfa88868d0b67c6c65a9f0553fcb50b
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/utils.py
@@ -0,0 +1,508 @@
+import glob
+import os
+import platform
+import random
+from collections import deque
+from itertools import zip_longest
+from typing import Dict, Iterable, Optional, Tuple, Union
+
+import gym
+import numpy as np
+import torch as th
+
+import stable_baselines3 as sb3
+
+# Check if tensorboard is available for pytorch
+try:
+ from torch.utils.tensorboard import SummaryWriter
+except ImportError:
+ SummaryWriter = None
+
+from stable_baselines3.common.logger import Logger, configure
+from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit
+
+
+def set_random_seed(seed: int, using_cuda: bool = False) -> None:
+ """
+ Seed the different random generators.
+
+ :param seed:
+ :param using_cuda:
+ """
+ # Seed python RNG
+ random.seed(seed)
+ # Seed numpy RNG
+ np.random.seed(seed)
+ # seed the RNG for all devices (both CPU and CUDA)
+ th.manual_seed(seed)
+
+ if using_cuda:
+ # Deterministic operations for CuDNN; this may impact performance
+ th.backends.cudnn.deterministic = True
+ th.backends.cudnn.benchmark = False
+
+
+# From stable baselines
+def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
+ """
+ Computes fraction of variance that ypred explains about y.
+ Returns 1 - Var[y-ypred] / Var[y]
+
+ interpretation:
+ ev=0 => might as well have predicted zero
+ ev=1 => perfect prediction
+ ev<0 => worse than just predicting zero
+
+ :param y_pred: the prediction
+ :param y_true: the expected value
+ :return: explained variance of ypred and y
+ """
+ assert y_true.ndim == 1 and y_pred.ndim == 1
+ var_y = np.var(y_true)
+ return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
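+
+# Editor's doctest sketch (not part of the library): a perfect prediction has
+# explained variance 1.0.
+#
+# >>> explained_variance(np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 3.0]))
+# 1.0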
+
+
+def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
+ """
+ Update the learning rate for a given optimizer.
+ Useful when doing linear schedule.
+
+ :param optimizer:
+ :param learning_rate:
+ """
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = learning_rate
+
+
+def get_schedule_fn(value_schedule: Union[Schedule, float, int]) -> Schedule:
+ """
+ Transform (if needed) learning rate and clip range (for PPO)
+ to callable.
+
+ :param value_schedule:
+ :return:
+ """
+ # If the passed schedule is a float
+ # create a constant function
+ if isinstance(value_schedule, (float, int)):
+ # Cast to float to avoid errors
+ value_schedule = constant_fn(float(value_schedule))
+ else:
+ assert callable(value_schedule)
+ return value_schedule
+
+
+def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
+ """
+ Create a function that interpolates linearly between start and end
+ between ``progress_remaining`` = 1 and ``progress_remaining`` = ``end_fraction``.
+ This is used in DQN for linearly annealing the exploration fraction
+ (epsilon for the epsilon-greedy strategy).
+
+ :param start: value to start with if ``progress_remaining`` = 1
+ :param end: value to end with if ``progress_remaining`` = 0
+ :param end_fraction: fraction of ``progress_remaining``
+ where end is reached, e.g. 0.1 means that end is reached after 10%
+ of the complete training process.
+ :return:
+ """
+
+ def func(progress_remaining: float) -> float:
+ if (1 - progress_remaining) > end_fraction:
+ return end
+ else:
+ return start + (1 - progress_remaining) * (end - start) / end_fraction
+
+ return func
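+
+# Editor's doctest sketch (not part of the library): epsilon annealed from 1.0
+# to 0.05 over the first 10% of training.
+#
+# >>> eps = get_linear_fn(start=1.0, end=0.05, end_fraction=0.1)
+# >>> eps(1.0)  # beginning of training (progress_remaining = 1)
+# 1.0
+# >>> eps(0.5)  # past end_fraction, stays at the end value
+# 0.05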
+
+
+def constant_fn(val: float) -> Schedule:
+ """
+ Create a function that returns a constant
+ It is useful for learning rate schedule (to avoid code duplication)
+
+ :param val:
+ :return:
+ """
+
+ def func(_):
+ return val
+
+ return func
+
+
+def get_device(device: Union[th.device, str] = "auto") -> th.device:
+ """
+ Retrieve PyTorch device.
+ It checks that the requested device is available first.
+ For now, it supports only cpu and cuda.
+ By default, it tries to use the gpu.
+
+ :param device: One of 'auto', 'cuda', 'cpu'
+ :return:
+ """
+ # Cuda by default
+ if device == "auto":
+ device = "cuda"
+ # Force conversion to th.device
+ device = th.device(device)
+
+ # Cuda not available
+ if device.type == th.device("cuda").type and not th.cuda.is_available():
+ return th.device("cpu")
+
+ return device
+
+
+def get_latest_run_id(log_path: str = "", log_name: str = "") -> int:
+ """
+ Returns the latest run number for the given log name and log path,
+ by finding the greatest number in the directories.
+
+ :param log_path: Path to the log folder containing several runs.
+ :param log_name: Name of the experiment. Each run is stored
+ in a folder named ``log_name_1``, ``log_name_2``, ...
+ :return: latest run number
+ """
+ max_run_id = 0
+ for path in glob.glob(os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")):
+ file_name = path.split(os.sep)[-1]
+ ext = file_name.split("_")[-1]
+ if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id:
+ max_run_id = int(ext)
+ return max_run_id
+
+
+def configure_logger(
+ verbose: int = 0,
+ tensorboard_log: Optional[str] = None,
+ tb_log_name: str = "",
+ reset_num_timesteps: bool = True,
+) -> Logger:
+ """
+ Configure the logger's outputs.
+
+ :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
+ :param tensorboard_log: the log location for tensorboard (if None, no logging)
+ :param tb_log_name: tensorboard run name (used as the log directory prefix)
+ :param reset_num_timesteps: Whether the ``num_timesteps`` attribute is reset or not.
+ It allows to continue a previous learning curve (``reset_num_timesteps=False``)
+ or start from t=0 (``reset_num_timesteps=True``, the default).
+ :return: The logger object
+ """
+ save_path, format_strings = None, ["stdout"]
+
+ if tensorboard_log is not None and SummaryWriter is None:
+ raise ImportError("Trying to log data to tensorboard but tensorboard is not installed.")
+
+ if tensorboard_log is not None and SummaryWriter is not None:
+ latest_run_id = get_latest_run_id(tensorboard_log, tb_log_name)
+ if not reset_num_timesteps:
+ # Continue training in the same directory
+ latest_run_id -= 1
+ save_path = os.path.join(tensorboard_log, f"{tb_log_name}_{latest_run_id + 1}")
+ if verbose >= 1:
+ format_strings = ["stdout", "tensorboard"]
+ else:
+ format_strings = ["tensorboard"]
+ elif verbose == 0:
+ format_strings = [""]
+ return configure(save_path, format_strings=format_strings)
+
+
+def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, action_space: gym.spaces.Space) -> None:
+ """
+ Checks that the environment has same spaces as provided ones. Used by BaseAlgorithm to check if
+ spaces match after loading the model with given env.
+ Checked parameters:
+ - observation_space
+ - action_space
+
+ :param env: Environment to check for valid spaces
+ :param observation_space: Observation space to check against
+ :param action_space: Action space to check against
+ """
+ if observation_space != env.observation_space:
+ raise ValueError(f"Observation spaces do not match: {observation_space} != {env.observation_space}")
+ if action_space != env.action_space:
+ raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}")
+
+
+def is_vectorized_box_observation(observation: np.ndarray, observation_space: gym.spaces.Box) -> bool:
+ """
+ For box observation type, detects and validates the shape,
+ then returns whether or not the observation is vectorized.
+
+ :param observation: the input observation to validate
+ :param observation_space: the observation space
+ :return: whether the given observation is vectorized or not
+ """
+ if observation.shape == observation_space.shape:
+ return False
+ elif observation.shape[1:] == observation_space.shape:
+ return True
+ else:
+ raise ValueError(
+ f"Error: Unexpected observation shape {observation.shape} for "
+ + f"Box environment, please use {observation_space.shape} "
+ + "or (n_env, {}) for the observation shape.".format(", ".join(map(str, observation_space.shape)))
+ )
+
+
+def is_vectorized_discrete_observation(observation: Union[int, np.ndarray], observation_space: gym.spaces.Discrete) -> bool:
+ """
+ For discrete observation type, detects and validates the shape,
+ then returns whether or not the observation is vectorized.
+
+ :param observation: the input observation to validate
+ :param observation_space: the observation space
+ :return: whether the given observation is vectorized or not
+ """
+ if isinstance(observation, int) or observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
+ return False
+ elif len(observation.shape) == 1:
+ return True
+ else:
+ raise ValueError(
+ f"Error: Unexpected observation shape {observation.shape} for "
+ + "Discrete environment, please use () or (n_env,) for the observation shape."
+ )
+
+
+def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation_space: gym.spaces.MultiDiscrete) -> bool:
+ """
+ For multidiscrete observation type, detects and validates the shape,
+ then returns whether or not the observation is vectorized.
+
+ :param observation: the input observation to validate
+ :param observation_space: the observation space
+ :return: whether the given observation is vectorized or not
+ """
+ if observation.shape == (len(observation_space.nvec),):
+ return False
+ elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
+ return True
+ else:
+ raise ValueError(
+ f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete "
+ + f"environment, please use ({len(observation_space.nvec)},) or "
+ + f"(n_env, {len(observation_space.nvec)}) for the observation shape."
+ )
+
+
+def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: gym.spaces.MultiBinary) -> bool:
+ """
+ For multibinary observation type, detects and validates the shape,
+ then returns whether or not the observation is vectorized.
+
+ :param observation: the input observation to validate
+ :param observation_space: the observation space
+ :return: whether the given observation is vectorized or not
+ """
+ if observation.shape == (observation_space.n,):
+ return False
+ elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
+ return True
+ else:
+ raise ValueError(
+ f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
+ + f"environment, please use ({observation_space.n},) or "
+ + f"(n_env, {observation_space.n}) for the observation shape."
+ )
+
+
+def is_vectorized_dict_observation(observation: np.ndarray, observation_space: gym.spaces.Dict) -> bool:
+ """
+ For dict observation type, detects and validates the shape,
+ then returns whether or not the observation is vectorized.
+
+ :param observation: the input observation to validate
+ :param observation_space: the observation space
+ :return: whether the given observation is vectorized or not
+ """
+ # We first assume that all observations are not vectorized
+ all_non_vectorized = True
+ for key, subspace in observation_space.spaces.items():
+ # This fails when the observation is not vectorized
+ # or when it has the wrong shape
+ if observation[key].shape != subspace.shape:
+ all_non_vectorized = False
+ break
+
+ if all_non_vectorized:
+ return False
+
+ all_vectorized = True
+ # Now we check that all observation are vectorized and have the correct shape
+ for key, subspace in observation_space.spaces.items():
+ if observation[key].shape[1:] != subspace.shape:
+ all_vectorized = False
+ break
+
+ if all_vectorized:
+ return True
+ else:
+ # Retrieve error message
+ error_msg = ""
+ try:
+ is_vectorized_observation(observation[key], observation_space.spaces[key])
+ except ValueError as e:
+ error_msg = f"{e}"
+ raise ValueError(
+ f"There seems to be a mix of vectorized and non-vectorized observations. "
+ f"Unexpected observation shape {observation[key].shape} for key {key} "
+ f"of type {observation_space.spaces[key]}. {error_msg}"
+ )
+
+
+def is_vectorized_observation(observation: Union[int, np.ndarray], observation_space: gym.spaces.Space) -> bool:
+ """
+ For every observation type, detects and validates the shape,
+ then returns whether or not the observation is vectorized.
+
+ :param observation: the input observation to validate
+ :param observation_space: the observation space
+ :return: whether the given observation is vectorized or not
+ """
+
+ is_vec_obs_func_dict = {
+ gym.spaces.Box: is_vectorized_box_observation,
+ gym.spaces.Discrete: is_vectorized_discrete_observation,
+ gym.spaces.MultiDiscrete: is_vectorized_multidiscrete_observation,
+ gym.spaces.MultiBinary: is_vectorized_multibinary_observation,
+ gym.spaces.Dict: is_vectorized_dict_observation,
+ }
+
+ for space_type, is_vec_obs_func in is_vec_obs_func_dict.items():
+ if isinstance(observation_space, space_type):
+ return is_vec_obs_func(observation, observation_space)
+ else:
+ # for-else happens if no break is called
+ raise ValueError(f"Error: Cannot determine if the observation is vectorized with the space type {observation_space}.")
+
+
+def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
+ """
+ Compute the mean of an array if there is at least one element.
+ For empty array, return NaN. It is used for logging only.
+
+ :param arr:
+ :return:
+ """
+ return np.nan if len(arr) == 0 else np.mean(arr)
+
+
+def zip_strict(*iterables: Iterable) -> Iterable:
+ r"""
+ ``zip()`` function but enforces that iterables are of equal length.
+ Raises ``ValueError`` if iterables not of equal length.
+ Code inspired by Stackoverflow answer for question #32954486.
+
+ :param \*iterables: iterables to ``zip()``
+ """
+ # As in Stackoverflow #32954486, use
+ # new object for "empty" in case we have
+ # Nones in iterable.
+ sentinel = object()
+ for combo in zip_longest(*iterables, fillvalue=sentinel):
+ if sentinel in combo:
+ raise ValueError("Iterables have different lengths")
+ yield combo
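+
+# Editor's doctest sketch (not part of the library):
+#
+# >>> list(zip_strict([1, 2], ["a", "b"]))
+# [(1, 'a'), (2, 'b')]
+# >>> list(zip_strict([1, 2], ["a"]))  # doctest: +IGNORE_EXCEPTION_DETAIL
+# Traceback (most recent call last):
+# ValueError: Iterables have different lengths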
+
+
+def polyak_update(
+ params: Iterable[th.nn.Parameter],
+ target_params: Iterable[th.nn.Parameter],
+ tau: float,
+) -> None:
+ """
+ Perform a Polyak average update on ``target_params`` using ``params``:
+ target parameters are slowly updated towards the src parameters.
+ ``tau``, the soft update coefficient controls the interpolation:
+ ``tau=1`` corresponds to copying the parameters to the target ones whereas nothing happens when ``tau=0``.
+ The Polyak update is done in place, with ``no_grad``, and therefore does not create intermediate tensors,
+ or a computation graph, reducing memory cost and improving performance. We scale the target params
+ by ``1-tau`` (in-place), add the new weights, scaled by ``tau`` and store the result of the sum in the target
+ params (in place).
+ See https://github.com/DLR-RM/stable-baselines3/issues/93
+
+ :param params: parameters to use to update the target params
+ :param target_params: parameters to update
+ :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
+ """
+ with th.no_grad():
+ # zip does not raise an exception if length of parameters does not match.
+ for param, target_param in zip_strict(params, target_params):
+ target_param.data.mul_(1 - tau)
+ th.add(target_param.data, param.data, alpha=tau, out=target_param.data)
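+
+# Editor's usage sketch (hypothetical networks, not part of the library): softly
+# move the target network 0.5% of the way towards the online network.
+#
+# >>> online, target = th.nn.Linear(2, 2), th.nn.Linear(2, 2)
+# >>> polyak_update(online.parameters(), target.parameters(), tau=0.005)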
+
+
+def obs_as_tensor(
+ obs: Union[np.ndarray, Dict[Union[str, int], np.ndarray]], device: th.device
+) -> Union[th.Tensor, TensorDict]:
+ """
+ Moves the observation to the given device.
+
+ :param obs:
+ :param device: PyTorch device
+ :return: PyTorch tensor of the observation on a desired device.
+ """
+ if isinstance(obs, np.ndarray):
+ return th.as_tensor(obs).to(device)
+ elif isinstance(obs, dict):
+ return {key: th.as_tensor(_obs).to(device) for (key, _obs) in obs.items()}
+ else:
+ raise Exception(f"Unrecognized type of observation {type(obs)}")
+
+
+def should_collect_more_steps(
+ train_freq: TrainFreq,
+ num_collected_steps: int,
+ num_collected_episodes: int,
+) -> bool:
+ """
+ Helper used in ``collect_rollouts()`` of off-policy algorithms
+ to determine the termination condition.
+
+ :param train_freq: How much experience should be collected before updating the policy.
+ :param num_collected_steps: The number of already collected steps.
+ :param num_collected_episodes: The number of already collected episodes.
+ :return: Whether to continue or not collecting experience
+ by doing rollouts of the current policy.
+ """
+ if train_freq.unit == TrainFrequencyUnit.STEP:
+ return num_collected_steps < train_freq.frequency
+
+ elif train_freq.unit == TrainFrequencyUnit.EPISODE:
+ return num_collected_episodes < train_freq.frequency
+
+ else:
+ raise ValueError(
+ "The unit of the `train_freq` must be either TrainFrequencyUnit.STEP "
+ f"or TrainFrequencyUnit.EPISODE not '{train_freq.unit}'!"
+ )
+
+
+def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
+ """
+ Retrieve system and python env info for the current system.
+
+ :param print_info: Whether to print or not those infos
+ :return: Dictionary summing up the version for each relevant package
+ and a formatted string.
+ """
+ env_info = {
+ "OS": f"{platform.platform()} {platform.version()}",
+ "Python": platform.python_version(),
+ "Stable-Baselines3": sb3.__version__,
+ "PyTorch": th.__version__,
+ "GPU Enabled": str(th.cuda.is_available()),
+ "Numpy": np.__version__,
+ "Gym": gym.__version__,
+ }
+ env_info_str = ""
+ for key, value in env_info.items():
+ env_info_str += f"{key}: {value}\n"
+ if print_info:
+ print(env_info_str)
+ return env_info, env_info_str
diff --git a/dexart-release/stable_baselines3/common/vec_env/__init__.py b/dexart-release/stable_baselines3/common/vec_env/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3880fbd53d9f7ed4683428149de2538042385877
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/__init__.py
@@ -0,0 +1,74 @@
+# flake8: noqa F401
+import typing
+from copy import deepcopy
+from typing import Optional, Type, Union
+
+from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
+from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
+from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
+from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
+from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
+from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs
+from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
+from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
+from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
+from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
+from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
+
+# Avoid circular import
+if typing.TYPE_CHECKING:
+ from stable_baselines3.common.type_aliases import GymEnv
+
+
+def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]:
+ """
+ Retrieve a ``VecEnvWrapper`` object by recursively searching.
+
+ :param env:
+ :param vec_wrapper_class:
+ :return:
+ """
+ env_tmp = env
+ while isinstance(env_tmp, VecEnvWrapper):
+ if isinstance(env_tmp, vec_wrapper_class):
+ return env_tmp
+ env_tmp = env_tmp.venv
+ return None
+
+
+def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]:
+ """
+ :param env:
+ :return:
+ """
+ return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type
+
+
+def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
+ """
+ Check if an environment is already wrapped by a given ``VecEnvWrapper``.
+
+ :param env:
+ :param vec_wrapper_class:
+ :return:
+ """
+ return unwrap_vec_wrapper(env, vec_wrapper_class) is not None
+
+
+# Define here to avoid circular import
+def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
+ """
+ Sync eval env and train env when using VecNormalize
+
+ :param env:
+ :param eval_env:
+ """
+ env_tmp, eval_env_tmp = env, eval_env
+ while isinstance(env_tmp, VecEnvWrapper):
+ if isinstance(env_tmp, VecNormalize):
+ # Only synchronize if observation normalization exists
+ if hasattr(env_tmp, "obs_rms"):
+ eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
+ eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
+ env_tmp = env_tmp.venv
+ eval_env_tmp = eval_env_tmp.venv
diff --git a/dexart-release/stable_baselines3/common/vec_env/base_vec_env.py b/dexart-release/stable_baselines3/common/vec_env/base_vec_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..98706050c7c48684b623de6637a95b5bde40bf51
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/base_vec_env.py
@@ -0,0 +1,374 @@
+import inspect
+import warnings
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union
+
+import cloudpickle
+import gym
+import numpy as np
+
+# Define type aliases here to avoid circular import
+# Used when we want to access one or more VecEnv
+VecEnvIndices = Union[None, int, Iterable[int]]
+# VecEnvObs is what is returned by the reset() method
+# it contains the observation for each env
+VecEnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]]
+# VecEnvStepReturn is what is returned by the step() method
+# it contains the observation, reward, done, info for each env
+VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]]
+
+
+def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover
+ """
+ Tile N images into one big PxQ image
+ (P,Q) are chosen to be as close as possible, and if N
+ is square, then P=Q.
+
+ :param img_nhwc: list or array of images, ndim=4 once turned into array. img nhwc
+ n = batch index, h = height, w = width, c = channel
+ :return: img_HWc, ndim=3
+ """
+ img_nhwc = np.asarray(img_nhwc)
+ n_images, height, width, n_channels = img_nhwc.shape
+ # new_height was named H before
+ new_height = int(np.ceil(np.sqrt(n_images)))
+ # new_width was named W before
+ new_width = int(np.ceil(float(n_images) / new_height))
+ img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(n_images, new_height * new_width)])
+ # img_HWhwc
+ out_image = img_nhwc.reshape((new_height, new_width, height, width, n_channels))
+ # img_HhWwc
+ out_image = out_image.transpose(0, 2, 1, 3, 4)
+ # img_Hh_Ww_c
+ out_image = out_image.reshape((new_height * height, new_width * width, n_channels))
+ return out_image
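+
+# Editor's doctest sketch (not part of the library): five 64x64 RGB frames are
+# padded with blank images and tiled onto a 3x2 grid.
+#
+# >>> frames = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(5)]
+# >>> tile_images(frames).shape
+# (192, 128, 3)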
+
+
+class VecEnv(ABC):
+ """
+ An abstract asynchronous, vectorized environment.
+
+ :param num_envs: the number of environments
+ :param observation_space: the observation space
+ :param action_space: the action space
+ """
+
+ metadata = {"render.modes": ["human", "rgb_array"]}
+
+ def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_space: gym.spaces.Space):
+ self.num_envs = num_envs
+ self.observation_space = observation_space
+ self.action_space = action_space
+
+ @abstractmethod
+ def reset(self) -> VecEnvObs:
+ """
+ Reset all the environments and return an array of
+ observations, or a tuple of observation arrays.
+
+ If step_async is still doing work, that work will
+ be cancelled and step_wait() should not be called
+ until step_async() is invoked again.
+
+ :return: observation
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def step_async(self, actions: np.ndarray) -> None:
+ """
+ Tell all the environments to start taking a step
+ with the given actions.
+ Call step_wait() to get the results of the step.
+
+ You should not call this if a step_async run is
+ already pending.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def step_wait(self) -> VecEnvStepReturn:
+ """
+ Wait for the step taken with step_async().
+
+ :return: observation, reward, done, information
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def close(self) -> None:
+ """
+ Clean up the environment's resources.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
+ """
+ Return attribute from vectorized environment.
+
+ :param attr_name: The name of the attribute whose value to return
+ :param indices: Indices of envs to get attribute from
+ :return: List of values of 'attr_name' in all environments
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
+ """
+ Set attribute inside vectorized environments.
+
+ :param attr_name: The name of attribute to assign new value
+ :param value: Value to assign to `attr_name`
+ :param indices: Indices of envs to assign value
+ :return:
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
+ """
+ Call instance methods of vectorized environments.
+
+ :param method_name: The name of the environment method to invoke.
+ :param indices: Indices of envs whose method to call
+ :param method_args: Any positional arguments to provide in the call
+ :param method_kwargs: Any keyword arguments to provide in the call
+ :return: List of items returned by the environment's method call
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
+ """
+ Check if environments are wrapped with a given wrapper.
+
+ :param wrapper_class: The wrapper class to look for.
+ :param indices: Indices of envs to check
+ :return: True if the env is wrapped, False otherwise, for each env queried.
+ """
+ raise NotImplementedError()
+
+ def step(self, actions: np.ndarray) -> VecEnvStepReturn:
+ """
+ Step the environments with the given action
+
+ :param actions: the action
+ :return: observation, reward, done, information
+ """
+ self.step_async(actions)
+ return self.step_wait()
+
+ def get_images(self) -> Sequence[np.ndarray]:
+ """
+ Return RGB images from each environment
+ """
+ raise NotImplementedError
+
+ def render(self, mode: str = "human") -> Optional[np.ndarray]:
+ """
+ Gym environment rendering
+
+ :param mode: the rendering type
+ """
+ try:
+ imgs = self.get_images()
+ except NotImplementedError:
+ warnings.warn(f"Render not defined for {self}")
+ return
+
+ # Create a big image by tiling images from subprocesses
+ bigimg = tile_images(imgs)
+ if mode == "human":
+ import cv2 # pytype:disable=import-error
+
+ cv2.imshow("vecenv", bigimg[:, :, ::-1])
+ cv2.waitKey(1)
+ elif mode == "rgb_array":
+ return bigimg
+ else:
+ raise NotImplementedError(f"Render mode {mode} is not supported by VecEnvs")
+
+ @abstractmethod
+ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
+ """
+ Sets the random seeds for all environments, based on a given seed.
+ Each individual environment will still get its own seed, by incrementing the given seed.
+
+ :param seed: The random seed. May be None for completely random seeding.
+ :return: Returns a list containing the seeds for each individual env.
+ Note that all list elements may be None, if the env does not return anything when being seeded.
+ """
+ pass
+
+ @property
+ def unwrapped(self) -> "VecEnv":
+ if isinstance(self, VecEnvWrapper):
+ return self.venv.unwrapped
+ else:
+ return self
+
+ def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]:
+ """Check if an attribute reference is being hidden in a recursive call to __getattr__
+
+ :param name: name of attribute to check for
+ :param already_found: whether this attribute has already been found in a wrapper
+ :return: name of module whose attribute is being shadowed, if any.
+ """
+ if hasattr(self, name) and already_found:
+ return f"{type(self).__module__}.{type(self).__name__}"
+ else:
+ return None
+
+ def _get_indices(self, indices: VecEnvIndices) -> Iterable[int]:
+ """
+ Convert a flexibly-typed reference to environment indices to an implied list of indices.
+
+ :param indices: refers to indices of envs.
+ :return: the implied list of indices.
+ """
+ if indices is None:
+ indices = range(self.num_envs)
+ elif isinstance(indices, int):
+ indices = [indices]
+ return indices
+
+
+class VecEnvWrapper(VecEnv):
+ """
+ Vectorized environment base class
+
+ :param venv: the vectorized environment to wrap
+ :param observation_space: the observation space (can be None to load from venv)
+ :param action_space: the action space (can be None to load from venv)
+ """
+
+ def __init__(
+ self,
+ venv: VecEnv,
+ observation_space: Optional[gym.spaces.Space] = None,
+ action_space: Optional[gym.spaces.Space] = None,
+ ):
+ self.venv = venv
+ VecEnv.__init__(
+ self,
+ num_envs=venv.num_envs,
+ observation_space=observation_space or venv.observation_space,
+ action_space=action_space or venv.action_space,
+ )
+ self.class_attributes = dict(inspect.getmembers(self.__class__))
+
+ def step_async(self, actions: np.ndarray) -> None:
+ self.venv.step_async(actions)
+
+ @abstractmethod
+ def reset(self) -> VecEnvObs:
+ pass
+
+ @abstractmethod
+ def step_wait(self) -> VecEnvStepReturn:
+ pass
+
+ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
+ return self.venv.seed(seed)
+
+ def close(self) -> None:
+ return self.venv.close()
+
+ def render(self, mode: str = "human") -> Optional[np.ndarray]:
+ return self.venv.render(mode=mode)
+
+ def get_images(self) -> Sequence[np.ndarray]:
+ return self.venv.get_images()
+
+ def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
+ return self.venv.get_attr(attr_name, indices)
+
+ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
+ return self.venv.set_attr(attr_name, value, indices)
+
+ def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
+ return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)
+
+ def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
+ return self.venv.env_is_wrapped(wrapper_class, indices=indices)
+
+ def __getattr__(self, name: str) -> Any:
+ """Find attribute from wrapped venv(s) if this wrapper does not have it.
+ Useful for accessing attributes from venvs which are wrapped with multiple wrappers
+ which have unique attributes of interest.
+ """
+ blocked_class = self.getattr_depth_check(name, already_found=False)
+ if blocked_class is not None:
+ own_class = f"{type(self).__module__}.{type(self).__name__}"
+ error_str = (
+ f"Error: Recursive attribute lookup for {name} from {own_class} is "
+ f"ambiguous and hides attribute from {blocked_class}"
+ )
+ raise AttributeError(error_str)
+
+ return self.getattr_recursive(name)
+
+ def _get_all_attributes(self) -> Dict[str, Any]:
+ """Get all (inherited) instance and class attributes
+
+ :return: all_attributes
+ """
+ all_attributes = self.__dict__.copy()
+ all_attributes.update(self.class_attributes)
+ return all_attributes
+
+ def getattr_recursive(self, name: str) -> Any:
+ """Recursively check wrappers to find attribute.
+
+ :param name: name of attribute to look for
+ :return: attribute
+ """
+ all_attributes = self._get_all_attributes()
+ if name in all_attributes: # attribute is present in this wrapper
+ attr = getattr(self, name)
+ elif hasattr(self.venv, "getattr_recursive"):
+ # Attribute not present, child is wrapper. Call getattr_recursive rather than getattr
+ # to avoid a duplicate call to getattr_depth_check.
+ attr = self.venv.getattr_recursive(name)
+ else: # attribute not present, child is an unwrapped VecEnv
+ attr = getattr(self.venv, name)
+
+ return attr
+
+ def getattr_depth_check(self, name: str, already_found: bool) -> str:
+ """See base class.
+
+ :return: name of module whose attribute is being shadowed, if any.
+ """
+ all_attributes = self._get_all_attributes()
+ if name in all_attributes and already_found:
+ # this venv's attribute is being hidden because of a higher venv.
+ shadowed_wrapper_class = f"{type(self).__module__}.{type(self).__name__}"
+ elif name in all_attributes and not already_found:
+ # we have found the first reference to the attribute. Now check for duplicates.
+ shadowed_wrapper_class = self.venv.getattr_depth_check(name, True)
+ else:
+ # this wrapper does not have the attribute. Keep searching.
+ shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found)
+
+ return shadowed_wrapper_class
+
+
+class CloudpickleWrapper:
+ """
+ Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
+
+ :param var: the variable you wish to wrap for pickling with cloudpickle
+ """
+
+ def __init__(self, var: Any):
+ self.var = var
+
+ def __getstate__(self) -> Any:
+ return cloudpickle.dumps(self.var)
+
+ def __setstate__(self, var: Any) -> None:
+ self.var = cloudpickle.loads(var)
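+
+
+# Usage sketch (illustrative, not part of the original file): stdlib pickle
+# cannot serialize lambdas or closures, so env factories are wrapped before
+# being handed to multiprocessing:
+#
+#   import pickle
+#   wrapped = CloudpickleWrapper(lambda: gym.make("CartPole-v1"))
+#   restored = pickle.loads(pickle.dumps(wrapped))
+#   env = restored.var()  # the callable survives the round trip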
diff --git a/dexart-release/stable_baselines3/common/vec_env/dummy_vec_env.py b/dexart-release/stable_baselines3/common/vec_env/dummy_vec_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0efc8cafcdaca9596e21d3014b3db32ac92cdd3
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/dummy_vec_env.py
@@ -0,0 +1,127 @@
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Any, Callable, List, Optional, Sequence, Type, Union
+
+import gym
+import numpy as np
+
+from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
+from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
+
+
+class DummyVecEnv(VecEnv):
+ """
+    Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
+    Python process. This is useful for computationally simple environments such as ``CartPole-v1``,
+    where the overhead of multiprocessing or multithreading outweighs the environment computation time.
+    This can also be used for RL methods that
+    require a vectorized environment but that you want to train with a single environment.
+
+ :param env_fns: a list of functions
+ that return environments to vectorize
+ """
+
+ def __init__(self, env_fns: List[Callable[[], gym.Env]]):
+ self.envs = [fn() for fn in env_fns]
+ env = self.envs[0]
+ VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
+ obs_space = env.observation_space
+ self.keys, shapes, dtypes = obs_space_info(obs_space)
+
+ self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys])
+ self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
+ self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
+ self.buf_infos = [{} for _ in range(self.num_envs)]
+ self.actions = None
+ self.metadata = env.metadata
+
+ def step_async(self, actions: np.ndarray) -> None:
+ self.actions = actions
+
+ def step_wait(self) -> VecEnvStepReturn:
+ for env_idx in range(self.num_envs):
+ obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
+ self.actions[env_idx]
+ )
+ if self.buf_dones[env_idx]:
+ # save final observation where user can get it, then reset
+ self.buf_infos[env_idx]["terminal_observation"] = obs
+ obs = self.envs[env_idx].reset()
+ self._save_obs(env_idx, obs)
+ return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
+
+ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
+ if seed is None:
+ seed = np.random.randint(0, 2**32 - 1)
+ seeds = []
+ for idx, env in enumerate(self.envs):
+ seeds.append(env.seed(seed + idx))
+ return seeds
+
+ def reset(self) -> VecEnvObs:
+ for env_idx in range(self.num_envs):
+ obs = self.envs[env_idx].reset()
+ self._save_obs(env_idx, obs)
+ return self._obs_from_buf()
+
+ def close(self) -> None:
+ for env in self.envs:
+ env.close()
+
+ def get_images(self) -> Sequence[np.ndarray]:
+ return [env.render(mode="rgb_array") for env in self.envs]
+
+ def render(self, mode: str = "human") -> Optional[np.ndarray]:
+ """
+ Gym environment rendering. If there are multiple environments then
+ they are tiled together in one image via ``BaseVecEnv.render()``.
+ Otherwise (if ``self.num_envs == 1``), we pass the render call directly to the
+ underlying environment.
+
+ Therefore, some arguments such as ``mode`` will have values that are valid
+ only when ``num_envs == 1``.
+
+ :param mode: The rendering type.
+ """
+ if self.num_envs == 1:
+ return self.envs[0].render(mode=mode)
+ else:
+ return super().render(mode=mode)
+
+ def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
+ for key in self.keys:
+ if key is None:
+ self.buf_obs[key][env_idx] = obs
+ else:
+ self.buf_obs[key][env_idx] = obs[key]
+
+ def _obs_from_buf(self) -> VecEnvObs:
+ return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))
+
+ def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
+ """Return attribute from vectorized environment (see base class)."""
+ target_envs = self._get_target_envs(indices)
+ return [getattr(env_i, attr_name) for env_i in target_envs]
+
+ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
+ """Set attribute inside vectorized environments (see base class)."""
+ target_envs = self._get_target_envs(indices)
+ for env_i in target_envs:
+ setattr(env_i, attr_name, value)
+
+ def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
+ """Call instance methods of vectorized environments."""
+ target_envs = self._get_target_envs(indices)
+ return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
+
+ def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
+ """Check if worker environments are wrapped with a given wrapper"""
+ target_envs = self._get_target_envs(indices)
+ # Import here to avoid a circular import
+ from stable_baselines3.common import env_util
+
+ return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]
+
+ def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
+ indices = self._get_indices(indices)
+ return [self.envs[i] for i in indices]
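+
+
+# Minimal usage sketch (illustrative; assumes a registered Gym env id and the
+# classic Gym step/reset API used throughout this module):
+#
+#   venv = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
+#   obs = venv.reset()                                 # shape (2, 4) for CartPole
+#   obs, rewards, dones, infos = venv.step(np.array([0, 1]))
+#   venv.close()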
diff --git a/dexart-release/stable_baselines3/common/vec_env/maniskill2_utils_common.py b/dexart-release/stable_baselines3/common/vec_env/maniskill2_utils_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6b90c0de331a975f73a6ab4cf5e5bfcbdc835b8
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/maniskill2_utils_common.py
@@ -0,0 +1,233 @@
+from collections import defaultdict, OrderedDict
+from typing import Dict, Sequence
+
+import gym
+import numpy as np
+from gym import spaces
+
+
+
+# -------------------------------------------------------------------------- #
+# Basic
+# -------------------------------------------------------------------------- #
+def merge_dicts(ds: Sequence[Dict], asarray=False):
+ """Merge multiple dicts with the same keys to a single one."""
+ # NOTE(jigu): To be compatible with generator, we only iterate once.
+ ret = defaultdict(list)
+ for d in ds:
+ for k in d:
+ ret[k].append(d[k])
+ ret = dict(ret)
+ # Sanity check (length)
+    assert len(set(len(v) for v in ret.values())) == 1, "Not all dicts share the same keys."
+ if asarray:
+ ret = {k: np.concatenate(v) for k, v in ret.items()}
+ return ret
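+
+# Example: merge_dicts([{"a": 1}, {"a": 2}]) -> {"a": [1, 2]}.
+# With asarray=True, each list of (array-like) values is instead
+# concatenated into a single numpy array.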
+
+
+# -------------------------------------------------------------------------- #
+# Numpy
+# -------------------------------------------------------------------------- #
+def normalize_vector(x, eps=1e-6):
+ x = np.asarray(x)
+ assert x.ndim == 1, x.ndim
+ norm = np.linalg.norm(x)
+ return np.zeros_like(x) if norm < eps else (x / norm)
+
+
+def compute_angle_between(x1, x2):
+ """Compute angle (radian) between two vectors."""
+ x1, x2 = normalize_vector(x1), normalize_vector(x2)
+ dot_prod = np.clip(np.dot(x1, x2), -1, 1)
+ return np.arccos(dot_prod).item()
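+
+# Example: compute_angle_between([1, 0, 0], [0, 1, 0]) -> pi / 2.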
+
+
+class np_random:
+ """Context manager for numpy random state"""
+
+ def __init__(self, seed):
+ self.seed = seed
+ self.state = None
+
+ def __enter__(self):
+ self.state = np.random.get_state()
+ np.random.seed(self.seed)
+ return self.state
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ np.random.set_state(self.state)
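+
+# Usage sketch (illustrative): draw reproducible samples without disturbing
+# the global numpy RNG state:
+#
+#   with np_random(0):
+#       x = np.random.rand()  # deterministic given the seed
+#   # the previous global RNG state is restored here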
+
+
+def random_choice(x: Sequence, rng: np.random.RandomState = np.random):
+ assert len(x) > 0
+ if len(x) == 1:
+ return x[0]
+ else:
+ return x[rng.randint(len(x))]
+
+
+def get_dtype_bounds(dtype: np.dtype):
+ if np.issubdtype(dtype, np.floating):
+ info = np.finfo(dtype)
+ return info.min, info.max
+ elif np.issubdtype(dtype, np.integer):
+ info = np.iinfo(dtype)
+ return info.min, info.max
+ elif np.issubdtype(dtype, np.bool_):
+ return 0, 1
+ else:
+ raise TypeError(dtype)
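+
+# Example: get_dtype_bounds(np.uint8) -> (0, 255);
+# get_dtype_bounds(np.float32) -> (float32 min, float32 max).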
+
+
+# ---------------------------------------------------------------------------- #
+# OpenAI gym
+# ---------------------------------------------------------------------------- #
+def convert_observation_to_space(observation, prefix=""):
+ """Convert observation to OpenAI gym observation space (recursively).
+ Modified from `gym.envs.mujoco_env`
+ """
+ if isinstance(observation, (dict)):
+ space = spaces.Dict(
+ {
+ k: convert_observation_to_space(v, prefix + "/" + k)
+ for k, v in observation.items()
+ }
+ )
+ elif isinstance(observation, np.ndarray):
+ shape = observation.shape
+ dtype = observation.dtype
+ low, high = get_dtype_bounds(dtype)
+ if np.issubdtype(dtype, np.floating):
+ low, high = -np.inf, np.inf
+ space = spaces.Box(low, high, shape=shape, dtype=dtype)
+ elif isinstance(observation, (float, np.float32, np.float64)):
+ print(f"The observation ({prefix}) is a (float) scalar")
+ space = spaces.Box(-np.inf, np.inf, shape=[1], dtype=np.float32)
+ elif isinstance(observation, (int, np.int32, np.int64)):
+ print(f"The observation ({prefix}) is a (integer) scalar")
+ space = spaces.Box(-np.inf, np.inf, shape=[1], dtype=int)
+ elif isinstance(observation, (bool, np.bool_)):
+ print(f"The observation ({prefix}) is a (bool) scalar")
+ space = spaces.Box(0, 1, shape=[1], dtype=np.bool_)
+ else:
+ raise NotImplementedError(type(observation), observation)
+
+ return space
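+
+# Example: convert_observation_to_space({"qpos": np.zeros(7, np.float32)})
+# -> Dict("qpos": Box(-inf, inf, (7,), float32))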
+
+
+def normalize_action_space(action_space: spaces.Box):
+ assert isinstance(action_space, spaces.Box), type(action_space)
+ return spaces.Box(-1, 1, shape=action_space.shape, dtype=action_space.dtype)
+
+
+def clip_and_scale_action(action, low, high):
+ """Clip action to [-1, 1] and scale according to a range [low, high]."""
+ low, high = np.asarray(low), np.asarray(high)
+ action = np.clip(action, -1, 1)
+ return 0.5 * (high + low) + 0.5 * (high - low) * action
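+
+# Worked example: clip_and_scale_action(0.5, low=0.0, high=2.0)
+# -> 0.5 * (2 + 0) + 0.5 * (2 - 0) * 0.5 = 1.5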
+
+
+def inv_clip_and_scale_action(action, low, high):
+ """Inverse of `clip_and_scale_action`."""
+ low, high = np.asarray(low), np.asarray(high)
+ action = (action - 0.5 * (high + low)) / (0.5 * (high - low))
+ return np.clip(action, -1.0, 1.0)
+
+
+def inv_scale_action(action, low, high):
+ """Inverse of `clip_and_scale_action` without clipping."""
+ low, high = np.asarray(low), np.asarray(high)
+ return (action - 0.5 * (high + low)) / (0.5 * (high - low))
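+
+# Round trip: for any a in [-1, 1],
+# inv_scale_action(clip_and_scale_action(a, low, high), low, high) recovers a.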
+
+
+def flatten_state_dict(state_dict: dict) -> np.ndarray:
+ """Flatten a dictionary containing states recursively.
+ Args:
+ state_dict: a dictionary containing scalars or 1-dim vectors.
+ Raises:
+ AssertionError: If a value of @state_dict is an ndarray with ndim > 2.
+ Returns:
+ np.ndarray: flattened states.
+ Notes:
+ The input is recommended to be ordered (e.g. OrderedDict).
+ However, since python 3.7, dictionary order is guaranteed to be insertion order.
+ """
+ states = []
+ for key, value in state_dict.items():
+ if isinstance(value, dict):
+ state = flatten_state_dict(value)
+ if state.size == 0:
+ state = None
+ elif isinstance(value, (tuple, list)):
+ state = None if len(value) == 0 else value
+ elif isinstance(value, (bool, np.bool_, int, np.int32, np.int64)):
+ # x = np.array(1) > 0 is np.bool_ instead of ndarray
+ state = int(value)
+ elif isinstance(value, (float, np.float32, np.float64)):
+ state = np.float32(value)
+ elif isinstance(value, np.ndarray):
+ if value.ndim > 2:
+ raise AssertionError(
+ "The dimension of {} should not be more than 2.".format(key)
+ )
+ state = value if value.size > 0 else None
+ else:
+ raise TypeError("Unsupported type: {}".format(type(value)))
+ if state is not None:
+ states.append(state)
+ if len(states) == 0:
+ return np.empty(0)
+ else:
+ return np.hstack(states)
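+
+# Example: flatten_state_dict({"qpos": np.array([0.1, 0.2]), "grasped": True})
+# -> array([0.1, 0.2, 1.0])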
+
+
+def flatten_dict_keys(d: dict, prefix=""):
+ """Flatten a dict by expanding its keys recursively."""
+ out = dict()
+ for k, v in d.items():
+ if isinstance(v, dict):
+ out.update(flatten_dict_keys(v, prefix + k + "/"))
+ else:
+ out[prefix + k] = v
+ return out
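+
+# Example: flatten_dict_keys({"agent": {"qpos": 1}}) -> {"agent/qpos": 1}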
+
+
+def extract_scalars_from_info(info: dict, blacklist=()) -> Dict[str, float]:
+ """Recursively extract scalar metrics from info dict.
+ Args:
+ info (dict): info dict
+ blacklist (tuple, optional): keys to exclude.
+ Returns:
+ Dict[str, float]: scalar metrics
+ """
+ ret = {}
+ for k, v in info.items():
+ if k in blacklist:
+ continue
+
+ # Ignore placeholder
+ if v is None:
+ continue
+
+ # Recursively extract scalars
+ elif isinstance(v, dict):
+ ret2 = extract_scalars_from_info(v, blacklist=blacklist)
+ ret2 = {f"{k}.{k2}": v2 for k2, v2 in ret2.items()}
+            ret2 = {k2: v2 for k2, v2 in ret2.items() if k2 not in blacklist}
+            ret.update(ret2)
+
+ # Things that are scalar-like will have an np.size of 1.
+ # Strings also have an np.size of 1, so explicitly ban those
+ elif np.size(v) == 1 and not isinstance(v, str):
+ ret[k] = float(v)
+ return ret
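+
+# Example: extract_scalars_from_info({"success": True, "stats": {"dist": 0.1}})
+# -> {"success": 1.0, "stats.dist": 0.1}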
+
+
+def flatten_dict_space_keys(space: spaces.Dict, prefix="") -> spaces.Dict:
+ """Flatten a dict of spaces by expanding its keys recursively."""
+ out = OrderedDict()
+ for k, v in space.spaces.items():
+ if isinstance(v, spaces.Dict):
+ out.update(flatten_dict_space_keys(v, prefix + k + "/").spaces)
+ else:
+ out[prefix + k] = v
+ return spaces.Dict(out)
\ No newline at end of file
diff --git a/dexart-release/stable_baselines3/common/vec_env/maniskill2_utils_wrappers_obs.py b/dexart-release/stable_baselines3/common/vec_env/maniskill2_utils_wrappers_obs.py
new file mode 100644
index 0000000000000000000000000000000000000000..c989378ac6684aa9cbfaf96e195b42fe4e8587a0
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/maniskill2_utils_wrappers_obs.py
@@ -0,0 +1,233 @@
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Sequence
+
+import gym
+import numpy as np
+from gym import spaces
+
+from maniskill2_utils_common import (flatten_dict_keys, flatten_dict_space_keys, merge_dicts)
+
+
+class RGBDObservationWrapper(gym.ObservationWrapper):
+ """Map raw textures (Color and Position) to rgb and depth."""
+
+ def __init__(self, env):
+ super().__init__(env)
+ self.observation_space = deepcopy(env.observation_space)
+ self.update_observation_space(self.observation_space)
+
+ @staticmethod
+ def update_observation_space(space: spaces.Dict):
+ # Update image observation space
+ image_space: spaces.Dict = space.spaces["image"]
+ for cam_uid in image_space:
+ ori_cam_space = image_space[cam_uid]
+ new_cam_space = OrderedDict()
+ for key in ori_cam_space:
+ if key == "Color":
+ height, width = ori_cam_space[key].shape[:2]
+ new_cam_space["rgb"] = spaces.Box(
+ low=0, high=255, shape=(height, width, 3), dtype=np.uint8
+ )
+ elif key == "Position":
+ height, width = ori_cam_space[key].shape[:2]
+ new_cam_space["depth"] = spaces.Box(
+ low=0, high=np.inf, shape=(height, width, 1), dtype=np.float32
+ )
+ else:
+ new_cam_space[key] = ori_cam_space[key]
+ image_space.spaces[cam_uid] = spaces.Dict(new_cam_space)
+
+ def observation(self, observation: dict):
+ image_obs = observation["image"]
+ for cam_uid, ori_images in image_obs.items():
+ new_images = OrderedDict()
+ for key in ori_images:
+ if key == "Color":
+                    rgb = ori_images[key][..., :3]  # [H, W, 4] -> [H, W, 3]
+                    rgb = np.clip(rgb * 255, 0, 255).astype(np.uint8)
+                    new_images["rgb"] = rgb  # [H, W, 3]
+ elif key == "Position":
+ depth = -ori_images[key][..., [2]] # [H, W, 1]
+ new_images["depth"] = depth
+ else:
+ new_images[key] = ori_images[key]
+ image_obs[cam_uid] = new_images
+ return observation
+
+
+def merge_dict_spaces(dict_spaces: Sequence[spaces.Dict]):
+ reverse_spaces = merge_dicts([x.spaces for x in dict_spaces])
+ for key in reverse_spaces:
+ low, high = [], []
+ for x in reverse_spaces[key]:
+ assert isinstance(x, spaces.Box), type(x)
+ low.append(x.low)
+ high.append(x.high)
+ low = np.concatenate(low)
+ high = np.concatenate(high)
+ new_space = spaces.Box(low=low, high=high, dtype=low.dtype)
+ reverse_spaces[key] = new_space
+ return spaces.Dict(OrderedDict(reverse_spaces))
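+
+# Example: merging two Dict spaces that each contain a Box "xyzw" of shape
+# (N_i, 4) yields one Dict space with a single Box "xyzw" of shape (N_1 + N_2, 4).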
+
+
+class PointCloudObservationWrapper(gym.ObservationWrapper):
+ """Convert Position textures to world-space point cloud."""
+
+ def __init__(self, env):
+ super().__init__(env)
+ self.observation_space = deepcopy(env.observation_space)
+ self.update_observation_space(self.observation_space)
+ self._buffer = {}
+
+ @staticmethod
+ def update_observation_space(space: spaces.Dict):
+ # Replace image observation spaces with point cloud ones
+ image_space: spaces.Dict = space.spaces.pop("image")
+ space.spaces.pop("camera_param")
+ pcd_space = OrderedDict()
+
+ for cam_uid in image_space:
+ cam_image_space = image_space[cam_uid]
+ cam_pcd_space = OrderedDict()
+
+ h, w = cam_image_space["Position"].shape[:2]
+ cam_pcd_space["xyzw"] = spaces.Box(
+ low=-np.inf, high=np.inf, shape=(h * w, 4), dtype=np.float32
+ )
+
+ # Extra keys
+ if "Color" in cam_image_space.spaces:
+ cam_pcd_space["rgb"] = spaces.Box(
+ low=0, high=255, shape=(h * w, 3), dtype=np.uint8
+ )
+ if "Segmentation" in cam_image_space.spaces:
+ cam_pcd_space["Segmentation"] = spaces.Box(
+ low=0, high=(2 ** 32 - 1), shape=(h * w, 4), dtype=np.uint32
+ )
+
+ pcd_space[cam_uid] = spaces.Dict(cam_pcd_space)
+
+ pcd_space = merge_dict_spaces(pcd_space.values())
+ space.spaces["pointcloud"] = pcd_space
+
+ def observation(self, observation: dict):
+ image_obs = observation.pop("image")
+ camera_params = observation.pop("camera_param")
+ pointcloud_obs = OrderedDict()
+
+ for cam_uid, images in image_obs.items():
+ cam_pcd = {}
+
+ # Each pixel is (x, y, z, z_buffer_depth) in OpenGL camera space
+ position = images["Position"]
+ # position[..., 3] = position[..., 3] < 1
+ position[..., 3] = position[..., 2] < 0
+
+ # Convert to world space
+ cam2world = camera_params[cam_uid]["cam2world_gl"]
+ xyzw = position.reshape(-1, 4) @ cam2world.T
+ cam_pcd["xyzw"] = xyzw
+
+ # Extra keys
+ if "Color" in images:
+ rgb = images["Color"][..., :3]
+ rgb = np.clip(rgb * 255, 0, 255).astype(np.uint8)
+ cam_pcd["rgb"] = rgb.reshape(-1, 3)
+ if "Segmentation" in images:
+ cam_pcd["Segmentation"] = images["Segmentation"].reshape(-1, 4)
+
+ pointcloud_obs[cam_uid] = cam_pcd
+
+ pointcloud_obs = merge_dicts(pointcloud_obs.values())
+ for key, value in pointcloud_obs.items():
+ buffer = self._buffer.get(key, None)
+ pointcloud_obs[key] = np.concatenate(value, out=buffer)
+ self._buffer[key] = pointcloud_obs[key]
+
+ observation["pointcloud"] = pointcloud_obs
+ return observation
+
+
+class RobotSegmentationObservationWrapper(gym.ObservationWrapper):
+ """Add a binary mask for robot links."""
+
+ def __init__(self, env, replace=True):
+ super().__init__(env)
+ self.observation_space = deepcopy(env.observation_space)
+ self.init_observation_space(self.observation_space, replace=replace)
+ self.replace = replace
+ # Cache robot link ids
+ self.robot_link_ids = self.env.robot_link_ids
+
+ @staticmethod
+ def init_observation_space(space: spaces.Dict, replace: bool):
+ # Update image observation spaces
+ if "image" in space.spaces:
+ image_space = space["image"]
+ for cam_uid in image_space:
+ cam_space = image_space[cam_uid]
+ if "Segmentation" not in cam_space.spaces:
+ continue
+ height, width = cam_space["Segmentation"].shape[:2]
+ new_space = spaces.Box(
+ low=0, high=1, shape=(height, width, 1), dtype="bool"
+ )
+ if replace:
+ cam_space.spaces.pop("Segmentation")
+ cam_space.spaces["robot_seg"] = new_space
+
+ # Update pointcloud observation spaces
+ if "pointcloud" in space.spaces:
+ pcd_space = space["pointcloud"]
+ if "Segmentation" in pcd_space.spaces:
+ n = pcd_space["Segmentation"].shape[0]
+ new_space = spaces.Box(low=0, high=1, shape=(n, 1), dtype="bool")
+ if replace:
+ pcd_space.spaces.pop("Segmentation")
+ pcd_space.spaces["robot_seg"] = new_space
+
+ def reset(self, **kwargs):
+ observation = self.env.reset(**kwargs)
+ self.robot_link_ids = self.env.robot_link_ids
+ return self.observation(observation)
+
+ def observation_image(self, observation: dict):
+ image_obs = observation["image"]
+ for cam_images in image_obs.values():
+ if "Segmentation" not in cam_images:
+ continue
+ seg = cam_images["Segmentation"]
+ robot_seg = np.isin(seg[..., 1:2], self.robot_link_ids)
+ if self.replace:
+ cam_images.pop("Segmentation")
+ cam_images["robot_seg"] = robot_seg
+ return observation
+
+ def observation_pointcloud(self, observation: dict):
+ pointcloud_obs = observation["pointcloud"]
+ if "Segmentation" not in pointcloud_obs:
+ return observation
+ seg = pointcloud_obs["Segmentation"]
+ robot_seg = np.isin(seg[..., 1:2], self.robot_link_ids)
+ if self.replace:
+ pointcloud_obs.pop("Segmentation")
+ pointcloud_obs["robot_seg"] = robot_seg
+ return observation
+
+ def observation(self, observation: dict):
+ if "image" in observation:
+ observation = self.observation_image(observation)
+ if "pointcloud" in observation:
+ observation = self.observation_pointcloud(observation)
+ return observation
+
+
+class FlattenObservationWrapper(gym.ObservationWrapper):
+ def __init__(self, env) -> None:
+ super().__init__(env)
+ self.observation_space = flatten_dict_space_keys(self.observation_space)
+
+ def observation(self, observation):
+ return flatten_dict_keys(observation)
diff --git a/dexart-release/stable_baselines3/common/vec_env/maniskill2_vec_env.py b/dexart-release/stable_baselines3/common/vec_env/maniskill2_vec_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..dee1c7165fca39dd996c68faed9d65e3b348a5aa
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/maniskill2_vec_env.py
@@ -0,0 +1,607 @@
+"""ManiSkill2 vectorized environment.
+
+See also:
+ https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/subproc_vec_env.py
+"""
+
+import multiprocessing as mp
+import os
+from collections import defaultdict
+from copy import deepcopy
+from functools import partial
+from multiprocessing.connection import Connection
+from typing import Callable, Dict, List, Optional, Sequence, Type, Union
+
+import gym
+import numpy as np
+import sapien.core as sapien
+from gym import spaces
+from gym.vector.utils.shared_memory import *
+
+try:
+ import torch
+except ImportError:
+ raise ImportError("To use ManiSkill2 VecEnv, please install PyTorch first.")
+
+# from mani_skill2 import logger
+# from mani_skill2.envs.sapien_env import BaseEnv
+from gym import Env as BaseEnv
+
+def find_available_port():
+ # https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
+ import socket
+
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ port = s.getsockname()[1]
+ server_address = f"localhost:{port}"
+ return server_address
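+
+# Example: find_available_port() -> "localhost:<port>" with an OS-assigned free port.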
+
+
+def _worker(
+ rank: int,
+ remote: Connection,
+ parent_remote: Connection,
+ env_fn: Callable[..., BaseEnv],
+):
+ # NOTE(jigu): Set environment variables for ManiSkill2
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["MKL_NUM_THREADS"] = "1"
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
+ os.environ["OMP_NUM_THREADS"] = "1"
+
+ parent_remote.close()
+
+ try:
+ env = env_fn()
+ while True:
+ cmd, data = remote.recv()
+ if cmd == "step":
+ obs, reward, done, info = env.step(data)
+ remote.send((obs, reward, done, info))
+ elif cmd == "reset":
+ obs = env.reset()
+ remote.send(obs)
+ elif cmd == "close":
+ remote.close()
+ break
+ elif cmd == "env_method":
+ method = getattr(env, data[0])
+ remote.send(method(*data[1], **data[2]))
+ elif cmd == "get_attr":
+ remote.send(getattr(env, data))
+ elif cmd == "set_attr":
+ remote.send(setattr(env, data[0], data[1]))
+ elif cmd == "handshake":
+ remote.send(None)
+ else:
+ raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
+ except KeyboardInterrupt:
+ print("Worker KeyboardInterrupt")
+ except EOFError:
+ print("Worker EOF")
+ except Exception as err:
+ print(err)
+ finally:
+ env.close()
+
+
+class VecEnv:
+ """Vectorized environment modified from Stable Baselines3 for ManiSkill2.
+ Image observations can stay on GPU to avoid unnecessary data transfer.
+
+ Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
+ process, allowing significant speed up when the environment is computationally complex.
+
+ For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
+ number of logical cores on your CPU.
+
+ .. warning::
+
+ Only 'forkserver' and 'spawn' start methods are thread-safe,
+ which is important when TensorFlow sessions or other non thread-safe
+ libraries are used in the parent (see issue #217). However, compared to
+ 'fork' they incur a small start-up cost and have restrictions on
+ global variables. With those methods, users must wrap the code in an
+ ``if __name__ == "__main__":`` block.
+ For more information, see the multiprocessing documentation.
+
+ .. warning::
+        The tensor observations are buffered. Make a copy to avoid overwriting them.
+
+ :param env_fns: Environments to run in subprocesses
+ :param start_method: method used to start the subprocesses.
+ Must be one of the methods returned by multiprocessing.get_all_start_methods().
+ Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
+ :param server_address: The network address of the SAPIEN RenderServer.
+ If "auto", the server will be created automatically at an avaiable port.
+ Otherwise, it should be a networkd address, e.g. "localhost:12345".
+ :param server_kwargs: keyword arguments for sapien.RenderServer
+ """
+
+ device: torch.device
+
+ def __init__(
+ self,
+ env_fns: List[Callable[[], BaseEnv]],
+ start_method: Optional[str] = None,
+ server_address: str = "auto",
+ server_kwargs: dict = None,
+ ):
+ self.waiting = False
+ self.closed = False
+
+ if start_method is None:
+ # Fork is not a thread safe method (see issue #217)
+ # but is more user friendly (does not require to wrap the code in
+ # a `if __name__ == "__main__":`)
+ forkserver_available = "forkserver" in mp.get_all_start_methods()
+ start_method = "forkserver" if forkserver_available else "spawn"
+ ctx = mp.get_context(start_method)
+
+ # ---------------------------------------------------------------------------- #
+ # Acquire observation space to construct buffer
+    # NOTE(jigu): Use a separate process to avoid creating sapien resources in the main process
+ remote, work_remote = ctx.Pipe()
+ args = (0, work_remote, remote, env_fns[0])
+ process = ctx.Process(target=_worker, args=args, daemon=True)
+ process.start()
+ work_remote.close()
+ remote.send(("get_attr", "observation_space"))
+ self.observation_space: spaces.Dict = remote.recv()
+ remote.send(("get_attr", "action_space"))
+ self.action_space: spaces.Space = remote.recv()
+ remote.send(("close", None))
+ remote.close()
+ process.join()
+ # ---------------------------------------------------------------------------- #
+
+ n_envs = len(env_fns)
+ self.num_envs = n_envs
+
+ # Allocate numpy buffers
+ self.non_image_obs_space = deepcopy(self.observation_space)
+ self.image_obs_space = self.non_image_obs_space.spaces.pop("image")
+ self._last_obs_np = [None for _ in range(n_envs)]
+ self._obs_np_buffer = create_np_buffer(self.non_image_obs_space, n=n_envs)
+
+ # Start RenderServer
+ if server_address == "auto":
+ server_address = find_available_port()
+ self.server_address = server_address
+ server_kwargs = {} if server_kwargs is None else server_kwargs
+ self.server = sapien.RenderServer(**server_kwargs)
+ self.server.start(self.server_address)
+ print(f"RenderServer is running at: {server_address}")
+
+ # Wrap env_fn
+ for i, env_fn in enumerate(env_fns):
+ client_kwargs = {"address": self.server_address, "process_index": i}
+ env_fns[i] = partial(
+ env_fn, renderer="client", renderer_kwargs=client_kwargs
+ )
+
+ # Initialize workers
+ self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)])
+ self.processes = []
+ for rank in range(n_envs):
+ work_remote = self.work_remotes[rank]
+ remote = self.remotes[rank]
+ env_fn = env_fns[rank]
+ args = (rank, work_remote, remote, env_fn)
+            # daemon=True: if the main process crashes, we should not cause things to hang
+ process = ctx.Process(
+ target=_worker, args=args, daemon=True
+ ) # pytype:disable=attribute-error
+ process.start()
+ self.processes.append(process)
+ work_remote.close()
+
+ # To make sure environments are initialized in all workers
+ for remote in self.remotes:
+ remote.send(("handshake", None))
+ for remote in self.remotes:
+ remote.recv()
+
+ # Infer texture names
+ texture_names = set()
+ for cam_space in self.image_obs_space.spaces.values():
+ texture_names.update(cam_space.spaces.keys())
+ self.texture_names = tuple(texture_names)
+
+ # Allocate torch buffers
+ # A list of [n_envs, n_cams, H, W, C] tensors
+ self._obs_torch_buffer: List[
+ torch.Tensor
+ ] = self.server.auto_allocate_torch_tensors(self.texture_names)
+ self.device = self._obs_torch_buffer[0].device
+
+ # ---------------------------------------------------------------------------- #
+ # Observations
+ # ---------------------------------------------------------------------------- #
+ def _update_np_buffer(self, obs_list, indices=None):
+ indices = self._get_indices(indices)
+ for i, obs in zip(indices, obs_list):
+ self._last_obs_np[i] = obs
+ return stack_obs(
+ self._last_obs_np, self.non_image_obs_space, self._obs_np_buffer
+ )
+
+ @torch.no_grad()
+ def _get_torch_observations(self):
+ self.server.wait_all()
+
+ tensor_dict = {}
+ for i, name in enumerate(self.texture_names):
+ tensor_dict[name] = self._obs_torch_buffer[i]
+
+ # NOTE(jigu): Efficiency might not be optimized when using more cameras
+ image_obs = {}
+ for cam_idx, cam_uid in enumerate(self.image_obs_space.spaces.keys()):
+ image_obs[cam_uid] = {}
+ cam_space = self.image_obs_space[cam_uid]
+ for tex_name in cam_space:
+ tensor = tensor_dict[tex_name][:, cam_idx] # [B, H, W, C]
+ if tensor.shape[1:3] != cam_space[tex_name].shape[0:2]:
+ h, w = cam_space[tex_name].shape[0:2]
+ tensor = tensor[:, :h, :w]
+ image_obs[cam_uid][tex_name] = tensor
+
+ return dict(image=image_obs)
+
+ # ---------------------------------------------------------------------------- #
+ # Interfaces
+ # ---------------------------------------------------------------------------- #
+ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
+ if seed is None:
+ seed = np.random.randint(0, 2**32)
+ for idx, remote in enumerate(self.remotes):
+ remote.send(("env_method", ("seed", [seed + idx], {})))
+ return [remote.recv() for remote in self.remotes]
+
+ def reset_async(self, indices=None):
+ remotes = self._get_target_remotes(indices)
+ for remote in remotes:
+ remote.send(("reset", None))
+ self.waiting = True
+
+ def reset_wait(self, indices=None):
+ remotes = self._get_target_remotes(indices)
+ results = [remote.recv() for remote in remotes]
+ self.waiting = False
+ vec_obs = self._get_torch_observations()
+ self._update_np_buffer(results, indices)
+ vec_obs.update(deepcopy(self._obs_np_buffer))
+ return vec_obs
+
+ def reset(self, indices=None):
+ self.reset_async(indices=indices)
+ return self.reset_wait(indices=indices)
+
+ def step_async(self, actions: np.ndarray) -> None:
+ for remote, action in zip(self.remotes, actions):
+ remote.send(("step", action))
+ self.waiting = True
+
+ def step_wait(self):
+ results = [remote.recv() for remote in self.remotes]
+ self.waiting = False
+ obs_list, rews, dones, infos = zip(*results)
+ vec_obs = self._get_torch_observations()
+ self._update_np_buffer(obs_list)
+ vec_obs.update(deepcopy(self._obs_np_buffer))
+ return vec_obs, np.array(rews), np.array(dones), infos
+
+ def step(self, actions):
+ self.step_async(actions)
+ return self.step_wait()
+
+ def close(self) -> None:
+ if self.closed:
+ return
+ if self.waiting:
+ for remote in self.remotes:
+ remote.recv()
+ for remote in self.remotes:
+ remote.send(("close", None))
+ for process in self.processes:
+ process.join()
+ self.closed = True
+
+ def render(self, mode=""):
+ raise NotImplementedError
+
+ def get_attr(self, attr_name: str, indices=None) -> List:
+ """Return attribute from vectorized environment (see base class)."""
+ target_remotes = self._get_target_remotes(indices)
+ for remote in target_remotes:
+ remote.send(("get_attr", attr_name))
+ return [remote.recv() for remote in target_remotes]
+
+ def set_attr(self, attr_name: str, value, indices=None) -> None:
+ """Set attribute inside vectorized environments (see base class)."""
+ target_remotes = self._get_target_remotes(indices)
+ for remote in target_remotes:
+ remote.send(("set_attr", (attr_name, value)))
+ for remote in target_remotes:
+ remote.recv()
+
+ def env_method(
+ self,
+ method_name: str,
+ *method_args,
+ indices=None,
+ **method_kwargs,
+ ) -> List:
+ """Call instance methods of vectorized environments."""
+ target_remotes = self._get_target_remotes(indices)
+ for remote in target_remotes:
+ remote.send(("env_method", (method_name, method_args, method_kwargs)))
+ return [remote.recv() for remote in target_remotes]
+
+ def env_is_wrapped(
+ self, wrapper_class: Type[gym.Wrapper], indices=None
+ ) -> List[bool]:
+ """Check if environments are wrapped with a given wrapper."""
+ target_remotes = self._get_target_remotes(indices)
+ for remote in target_remotes:
+ remote.send(("is_wrapped", wrapper_class))
+ return [remote.recv() for remote in target_remotes]
+
+ @property
+ def unwrapped(self) -> "VecEnv":
+ if isinstance(self, VecEnvWrapper):
+ return self.venv.unwrapped
+ else:
+ return self
+
+ def _get_indices(self, indices) -> List[int]:
+ """
+ Convert a flexibly-typed reference to environment indices to an implied list of indices.
+
+ :param indices: refers to indices of envs.
+ :return: the implied list of indices.
+ """
+ if indices is None:
+ indices = list(range(self.num_envs))
+ elif isinstance(indices, int):
+ indices = [indices]
+ return indices
+
+ def _get_target_remotes(self, indices) -> List[Connection]:
+ """
+ Get the connection object needed to communicate with the wanted
+ envs that are in subprocesses.
+
+ :param indices: refers to indices of envs.
+ :return: Connection object to communicate between processes.
+ """
+ indices = self._get_indices(indices)
+ return [self.remotes[i] for i in indices]
+
+ def __repr__(self):
+ return "{}({})".format(
+ self.__class__.__name__, self.env_method("__repr__", indices=0)[0]
+ )
+
+
+def stack_observation_space(space: spaces.Space, n: int):
+ if isinstance(space, spaces.Dict):
+ sub_spaces = [
+ (key, stack_observation_space(subspace, n))
+ for key, subspace in space.spaces.items()
+ ]
+ return spaces.Dict(sub_spaces)
+ elif isinstance(space, spaces.Box):
+ shape = (n,) + space.shape
+ low = np.broadcast_to(space.low, shape)
+ high = np.broadcast_to(space.high, shape)
+ return spaces.Box(low=low, high=high, shape=shape, dtype=space.dtype)
+ else:
+ raise NotImplementedError(
+ "Unsupported observation space: {}".format(type(space))
+ )
+
+
+def create_np_buffer(space: spaces.Space, n: int):
+ if isinstance(space, spaces.Dict):
+ return {
+ key: create_np_buffer(subspace, n) for key, subspace in space.spaces.items()
+ }
+ elif isinstance(space, spaces.Box):
+ return np.zeros((n,) + space.shape, dtype=space.dtype)
+ else:
+ raise NotImplementedError(
+ "Unsupported observation space: {}".format(type(space))
+ )
+
+
+def stack_obs(obs: Sequence, space: spaces.Space, buffer: Optional[np.ndarray] = None):
+ if isinstance(space, spaces.Dict):
+ ret = {}
+ for key in space:
+ _obs = [o[key] for o in obs]
+ _buffer = None if buffer is None else buffer[key]
+ ret[key] = stack_obs(_obs, space[key], buffer=_buffer)
+ return ret
+ elif isinstance(space, spaces.Box):
+ return np.stack(obs, out=buffer)
+ else:
+ raise NotImplementedError(type(space))
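+
+# Illustrative example: with space = spaces.Dict({"state": spaces.Box(-1, 1, (3,))})
+# and obs = [{"state": np.zeros(3)}, {"state": np.ones(3)}],
+# stack_obs(obs, space)["state"] has shape (2, 3).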
+
+
+class RGBDVecEnv(VecEnv):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ from maniskill2_utils_wrappers_obs import RGBDObservationWrapper
+
+ RGBDObservationWrapper.update_observation_space(self.observation_space)
+
+ def _get_torch_observations(self):
+ observation = super()._get_torch_observations()
+
+ image_obs = observation["image"]
+ for cam_uid, ori_images in image_obs.items():
+ new_images = {}
+ for key in ori_images:
+ if key == "Color":
+ rgb = ori_images[key][..., :3]
+ rgb = torch.clamp(rgb * 255, 0, 255).to(dtype=torch.uint8)
+ new_images["rgb"] = rgb
+ elif key == "Position":
+ depth = -ori_images[key][..., [2]]
+ new_images["depth"] = depth
+ else:
+ new_images[key] = ori_images[key]
+ image_obs[cam_uid] = new_images
+ return observation
+
+
+class PointCloudVecEnv(VecEnv):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ from maniskill2_utils_wrappers_obs import PointCloudObservationWrapper
+
+ PointCloudObservationWrapper.update_observation_space(self.observation_space)
+ self._buffer = {}
+
+ def _get_torch_observations(self):
+ observation = super()._get_torch_observations()
+
+ image_obs = observation.pop("image")
+ pointcloud_obs = {}
+
+ for cam_uid, cam_images in image_obs.items():
+ cam_pcd = {}
+
+ # Each pixel is (x, y, z, z_buffer_depth) in OpenGL camera space
+ position = cam_images["Position"]
+ bs = position.size(0)
+
+ # Homogeneous coordinates
+ xyzw = torch.cat([position[..., :3], position[..., [3]] < 1], dim=-1)
+ cam_pcd["xyzw"] = xyzw.reshape(bs, -1, 4)
+
+ if "Color" in cam_images:
+ rgb = cam_images["Color"][..., :3]
+ rgb = torch.clamp(rgb * 255, 0, 255).to(torch.uint8)
+ cam_pcd["rgb"] = rgb.reshape(bs, -1, 3)
+
+ if "Segmentation" in cam_images:
+ seg = cam_images["Segmentation"]
+ cam_pcd["Segmentation"] = seg.reshape(bs, -1, 4)
+
+ pointcloud_obs[cam_uid] = cam_pcd
+
+ observation["pointcloud"] = pointcloud_obs
+ return observation
+
+ @torch.no_grad()
+ def observation(self, observation: dict):
+ # Move camera parameters to device
+ camera_params = observation.pop("camera_param")
+ camera_params2 = {}
+ for cam_uid in camera_params:
+ cam2world = camera_params[cam_uid]["cam2world_gl"]
+ cam2world = torch.from_numpy(cam2world).to(
+ device=self.device, non_blocking=True
+ )
+ camera_params2[cam_uid] = cam2world
+
+ pointcloud_obs = observation["pointcloud"]
+ pointcloud_obs2 = defaultdict(list)
+ for cam_uid, cam_pcd in pointcloud_obs.items():
+ # Transform coordinates to world space
+ xyzw = cam_pcd["xyzw"] # [B, H*W, 4]
+ cam2world = camera_params2[cam_uid] # [B, 4, 4]
+ cam_pcd["xyzw"] = torch.bmm(xyzw, cam2world.transpose(1, 2))
+ for k, v in cam_pcd.items():
+ pointcloud_obs2[k].append(v)
+
+ for key, value in pointcloud_obs2.items():
+ buffer = self._buffer.get(key, None)
+ self._buffer[key] = torch.cat(value, dim=1, out=buffer)
+ pointcloud_obs2[key] = self._buffer[key]
+
+ observation["pointcloud"] = pointcloud_obs2
+ return observation
+
+    def reset_wait(self, *args, **kwargs):
+        obs = super().reset_wait(*args, **kwargs)
+ return self.observation(obs)
+
+ def step_wait(self):
+ obs, rews, dones, infos = super().step_wait()
+ return self.observation(obs), rews, dones, infos
+
+
+class VecEnvWrapper(VecEnv):
+ def __init__(self, venv: VecEnv):
+ self.venv = venv
+ self.num_envs = venv.num_envs
+ self.observation_space = venv.observation_space
+ self.action_space = venv.action_space
+
+ def seed(self, seed: Optional[int] = None):
+ return self.venv.seed(seed)
+
+ def reset_async(self, *args, **kwargs):
+ self.venv.reset_async(*args, **kwargs)
+
+ def reset_wait(self, *args, **kwargs):
+ return self.venv.reset_wait(*args, **kwargs)
+
+ def step_async(self, actions: np.ndarray):
+ self.venv.step_async(actions)
+
+ def step_wait(self):
+ return self.venv.step_wait()
+
+ def close(self):
+ return self.venv.close()
+
+ def render(self, mode=""):
+ return self.venv.render(mode)
+
+ def get_attr(self, attr_name: str, indices=None) -> List:
+ return self.venv.get_attr(attr_name, indices)
+
+ def set_attr(self, attr_name: str, value, indices=None) -> None:
+ return self.venv.set_attr(attr_name, value, indices)
+
+ def env_method(
+ self,
+ method_name: str,
+ *method_args,
+ indices=None,
+ **method_kwargs,
+ ) -> List:
+ return self.venv.env_method(
+ method_name, *method_args, indices=indices, **method_kwargs
+ )
+
+ def env_is_wrapped(
+ self, wrapper_class: Type[gym.Wrapper], indices=None
+ ) -> List[bool]:
+ return self.venv.env_is_wrapped(wrapper_class, indices)
+
+ def __getattr__(self, name):
+ if name in self.__dict__:
+ return self.__dict__[name]
+ else:
+ return getattr(self.venv, name)
+
+
+class VecEnvObservationWrapper(VecEnvWrapper):
+ def reset_wait(self, **kwargs):
+ observation = self.venv.reset_wait(**kwargs)
+ return self.observation(observation)
+
+ def step_wait(self):
+ observation, reward, done, info = self.venv.step_wait()
+ return self.observation(observation), reward, done, info
+
+ def observation(self, observation):
+ raise NotImplementedError
\ No newline at end of file
diff --git a/dexart-release/stable_baselines3/common/vec_env/maniskill2_wrapper_obs.py b/dexart-release/stable_baselines3/common/vec_env/maniskill2_wrapper_obs.py
new file mode 100644
index 0000000000000000000000000000000000000000..5607dc49acaa50f67aaa024a284cffee2369fe4d
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/maniskill2_wrapper_obs.py
@@ -0,0 +1,102 @@
+from copy import deepcopy
+
+import torch
+
+from maniskill2_vec_env import VecEnv, VecEnvObservationWrapper
+
+
+def batch_isin(x: torch.Tensor, inds: torch.Tensor):
+ """A batch version of `torch.isin`.
+ Args:
+ x (torch.Tensor): [B, ...], integer
+ inds (torch.Tensor): [B, N], integer
+ Returns:
+ torch.Tensor: [B, ...], boolean
+ """
+
+ # # For-loop version
+ # out = []
+ # for x_i, inds_i in zip(x.unbind(0), inds.unbind(0)):
+ # out.append(torch.isin(x_i, inds_i))
+ # out = torch.stack(out, dim=0)
+
+ bs = x.size(0)
+ max_ind, _ = torch.max(x.reshape(bs, -1), dim=-1) # [B]
+ # Acquire the maximum index of each sample in batch
+ offset = torch.cumsum(max_ind + 1, dim=0)
+ offset = torch.nn.functional.pad(offset[:-1], (1, 0), value=0)
+ # Add offset to avoid indexing collision
+ _shape = (bs,) + (1,) * (x.dim() - 1)
+ # Remap indices
+ _x = x + offset.view(_shape)
+ _inds = inds + offset.view(bs, 1)
+ return torch.isin(_x, _inds)
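+
+# Worked example (illustrative):
+#   x = torch.tensor([[1, 2], [3, 4]])   # [B=2, 2]
+#   inds = torch.tensor([[2], [3]])      # [B=2, N=1]
+#   batch_isin(x, inds) -> tensor([[False, True], [True, False]])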
+
+
+class VecRobotSegmentationObservationWrapper(VecEnvObservationWrapper):
+ """Add a binary mask for robot links."""
+
+ def __init__(self, venv: VecEnv, replace=True):
+ super().__init__(venv)
+
+ from maniskill2_utils_wrappers_obs import (
+ RobotSegmentationObservationWrapper,
+ )
+
+ self.observation_space = deepcopy(venv.observation_space)
+ RobotSegmentationObservationWrapper.init_observation_space(
+ self.observation_space, replace=replace
+ )
+ self.replace = replace
+
+ # Cache robot link ids
+ # NOTE(jigu): Assume robots are the same and thus can be batched
+ robot_link_ids = self.get_attr("robot_link_ids")
+ self.robot_link_ids = torch.tensor(
+ robot_link_ids, dtype=torch.int32, device=self.device
+ )
+
+ @torch.no_grad()
+ def update_robot_link_ids(self, indices=None):
+ robot_link_ids = self.get_attr("robot_link_ids", indices=indices)
+ robot_link_ids = torch.tensor(robot_link_ids, dtype=torch.int32)
+ robot_link_ids = robot_link_ids.to(device=self.device, non_blocking=True)
+ indices = self._get_indices(indices)
+ self.robot_link_ids[indices] = robot_link_ids
+
+ def observation_image(self, observation: dict):
+ image_obs = observation["image"]
+ for cam_images in image_obs.values():
+ if "Segmentation" not in cam_images:
+ continue
+ seg = cam_images["Segmentation"] # [B, H, W, 4]
+ # [B, H, W, 1]
+ robot_seg = batch_isin(seg[..., 1:2], self.robot_link_ids)
+ if self.replace:
+ cam_images.pop("Segmentation")
+ cam_images["robot_seg"] = robot_seg
+ return observation
+
+ def observation_pointcloud(self, observation: dict):
+ pointcloud_obs = observation["pointcloud"]
+ if "Segmentation" not in pointcloud_obs:
+ return observation
+ seg = pointcloud_obs["Segmentation"] # [N, 4]
+ robot_seg = batch_isin(seg[..., 1:2], self.robot_link_ids) # [N, 1]
+ if self.replace:
+ pointcloud_obs.pop("Segmentation")
+ pointcloud_obs["robot_seg"] = robot_seg
+ return observation
+
+ @torch.no_grad()
+ def observation(self, observation: dict):
+ if "image" in observation:
+ observation = self.observation_image(observation)
+ if "pointcloud" in observation:
+ observation = self.observation_pointcloud(observation)
+ return observation
+
+ def reset_wait(self, indices=None, **kwargs):
+ obs = super().reset_wait(indices=indices, **kwargs)
+ self.update_robot_link_ids(indices=indices)
+ return self.observation(obs)
\ No newline at end of file
diff --git a/dexart-release/stable_baselines3/common/vec_env/stacked_observations.py b/dexart-release/stable_baselines3/common/vec_env/stacked_observations.py
new file mode 100644
index 0000000000000000000000000000000000000000..733b72833ea275eff50f09cd2a5a009e1884eab3
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/stacked_observations.py
@@ -0,0 +1,266 @@
+import warnings
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+from gym import spaces
+
+from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
+
+
+class StackedObservations:
+ """
+ Frame stacking wrapper for data.
+
+ Dimension to stack over is either first (channels-first) or
+ last (channels-last), which is detected automatically using
+ ``common.preprocessing.is_image_space_channels_first`` if
+ observation is an image space.
+
+ :param num_envs: number of environments
+ :param n_stack: Number of frames to stack
+ :param observation_space: Environment observation space.
+ :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
+ If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
+ """
+
+ def __init__(
+ self,
+ num_envs: int,
+ n_stack: int,
+ observation_space: spaces.Space,
+ channels_order: Optional[str] = None,
+ ):
+
+ self.n_stack = n_stack
+ (
+ self.channels_first,
+ self.stack_dimension,
+ self.stackedobs,
+ self.repeat_axis,
+ ) = self.compute_stacking(num_envs, n_stack, observation_space, channels_order)
+ super().__init__()
+
+ @staticmethod
+ def compute_stacking(
+ num_envs: int,
+ n_stack: int,
+ observation_space: spaces.Box,
+ channels_order: Optional[str] = None,
+ ) -> Tuple[bool, int, np.ndarray, int]:
+ """
+ Calculates the parameters in order to stack observations
+
+ :param num_envs: Number of environments in the stack
+ :param n_stack: The number of observations to stack
+ :param observation_space: The observation space
+ :param channels_order: The order of the channels
+ :return: tuple of channels_first, stack_dimension, stackedobs, repeat_axis
+ """
+ channels_first = False
+ if channels_order is None:
+ # Detect channel location automatically for images
+ if is_image_space(observation_space):
+ channels_first = is_image_space_channels_first(observation_space)
+ else:
+ # Default behavior for non-image space, stack on the last axis
+ channels_first = False
+ else:
+ assert channels_order in {
+ "last",
+ "first",
+ }, "`channels_order` must be one of following: 'last', 'first'"
+
+ channels_first = channels_order == "first"
+
+ # This includes the vec-env dimension (first)
+ stack_dimension = 1 if channels_first else -1
+ repeat_axis = 0 if channels_first else -1
+ low = np.repeat(observation_space.low, n_stack, axis=repeat_axis)
+ stackedobs = np.zeros((num_envs,) + low.shape, low.dtype)
+ return channels_first, stack_dimension, stackedobs, repeat_axis
+
+ def stack_observation_space(self, observation_space: spaces.Box) -> spaces.Box:
+ """
+ Given an observation space, returns a new observation space with stacked observations
+
+ :return: New observation space with stacked dimensions
+ """
+ low = np.repeat(observation_space.low, self.n_stack, axis=self.repeat_axis)
+ high = np.repeat(observation_space.high, self.n_stack, axis=self.repeat_axis)
+ return spaces.Box(low=low, high=high, dtype=observation_space.dtype)
+
+ def reset(self, observation: np.ndarray) -> np.ndarray:
+ """
+ Resets the stackedobs, adds the reset observation to the stack, and returns the stack
+
+ :param observation: Reset observation
+ :return: The stacked reset observation
+ """
+ self.stackedobs[...] = 0
+ if self.channels_first:
+ self.stackedobs[:, -observation.shape[self.stack_dimension] :, ...] = observation
+ else:
+ self.stackedobs[..., -observation.shape[self.stack_dimension] :] = observation
+ return self.stackedobs
+
+ def update(
+ self,
+ observations: np.ndarray,
+ dones: np.ndarray,
+ infos: List[Dict[str, Any]],
+ ) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
+ """
+ Adds the observations to the stack and uses the dones to update the infos.
+
+ :param observations: numpy array of observations
+ :param dones: numpy array of done info
+ :param infos: numpy array of info dicts
+ :return: tuple of the stacked observations and the updated infos
+ """
+ stack_ax_size = observations.shape[self.stack_dimension]
+ self.stackedobs = np.roll(self.stackedobs, shift=-stack_ax_size, axis=self.stack_dimension)
+ for i, done in enumerate(dones):
+ if done:
+ if "terminal_observation" in infos[i]:
+ old_terminal = infos[i]["terminal_observation"]
+ if self.channels_first:
+ new_terminal = np.concatenate(
+ (self.stackedobs[i, :-stack_ax_size, ...], old_terminal),
+ axis=0, # self.stack_dimension - 1, as there is no batch dim
+ )
+ else:
+ new_terminal = np.concatenate(
+ (self.stackedobs[i, ..., :-stack_ax_size], old_terminal),
+ axis=self.stack_dimension,
+ )
+ infos[i]["terminal_observation"] = new_terminal
+ else:
+ warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
+ self.stackedobs[i] = 0
+ if self.channels_first:
+ self.stackedobs[:, -observations.shape[self.stack_dimension] :, ...] = observations
+ else:
+ self.stackedobs[..., -observations.shape[self.stack_dimension] :] = observations
+ return self.stackedobs, infos
+
+
+class StackedDictObservations(StackedObservations):
+ """
+ Frame stacking wrapper for dictionary data.
+
+ Dimension to stack over is either first (channels-first) or
+ last (channels-last), which is detected automatically using
+ ``common.preprocessing.is_image_space_channels_first`` if
+ observation is an image space.
+
+ :param num_envs: number of environments
+ :param n_stack: Number of frames to stack
+ :param observation_space: Dict observation space to be stacked
+ :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
+ If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
+ """
+
+ def __init__(
+ self,
+ num_envs: int,
+ n_stack: int,
+ observation_space: spaces.Dict,
+ channels_order: Optional[Union[str, Dict[str, str]]] = None,
+ ):
+ self.n_stack = n_stack
+ self.channels_first = {}
+ self.stack_dimension = {}
+ self.stackedobs = {}
+ self.repeat_axis = {}
+
+ for key, subspace in observation_space.spaces.items():
+ assert isinstance(subspace, spaces.Box), "StackedDictObservations only works with nested gym.spaces.Box"
+ if isinstance(channels_order, str) or channels_order is None:
+ subspace_channel_order = channels_order
+ else:
+ subspace_channel_order = channels_order[key]
+ (
+ self.channels_first[key],
+ self.stack_dimension[key],
+ self.stackedobs[key],
+ self.repeat_axis[key],
+ ) = self.compute_stacking(num_envs, n_stack, subspace, subspace_channel_order)
+
+ def stack_observation_space(self, observation_space: spaces.Dict) -> spaces.Dict:
+ """
+ Returns the stacked version of a Dict observation space
+
+ :param observation_space: Dict observation space to stack
+ :return: stacked observation space
+ """
+ spaces_dict = {}
+ for key, subspace in observation_space.spaces.items():
+ low = np.repeat(subspace.low, self.n_stack, axis=self.repeat_axis[key])
+ high = np.repeat(subspace.high, self.n_stack, axis=self.repeat_axis[key])
+ spaces_dict[key] = spaces.Box(low=low, high=high, dtype=subspace.dtype)
+ return spaces.Dict(spaces=spaces_dict)
+
+ def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+ """
+ Resets the stacked observations, adds the reset observation to the stack, and returns the stack
+
+ :param observation: Reset observation
+ :return: Stacked reset observations
+ """
+ for key, obs in observation.items():
+ self.stackedobs[key][...] = 0
+ if self.channels_first[key]:
+ self.stackedobs[key][:, -obs.shape[self.stack_dimension[key]] :, ...] = obs
+ else:
+ self.stackedobs[key][..., -obs.shape[self.stack_dimension[key]] :] = obs
+ return self.stackedobs
+
+ def update(
+ self,
+ observations: Dict[str, np.ndarray],
+ dones: np.ndarray,
+ infos: List[Dict[str, Any]],
+ ) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]:
+ """
+ Adds the observations to the stack and uses the dones to update the infos.
+
+ :param observations: Dict of numpy arrays of observations
+ :param dones: numpy array of dones
+ :param infos: list of info dicts
+ :return: tuple of the stacked observations and the updated infos
+ """
+ for key in self.stackedobs.keys():
+ stack_ax_size = observations[key].shape[self.stack_dimension[key]]
+ self.stackedobs[key] = np.roll(
+ self.stackedobs[key],
+ shift=-stack_ax_size,
+ axis=self.stack_dimension[key],
+ )
+
+ for i, done in enumerate(dones):
+ if done:
+ if "terminal_observation" in infos[i]:
+ old_terminal = infos[i]["terminal_observation"][key]
+ if self.channels_first[key]:
+ new_terminal = np.vstack(
+ (
+ self.stackedobs[key][i, :-stack_ax_size, ...],
+ old_terminal,
+ )
+ )
+ else:
+ new_terminal = np.concatenate(
+ (
+ self.stackedobs[key][i, ..., :-stack_ax_size],
+ old_terminal,
+ ),
+ axis=self.stack_dimension[key],
+ )
+ infos[i]["terminal_observation"][key] = new_terminal
+ else:
+ warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
+ self.stackedobs[key][i] = 0
+ if self.channels_first[key]:
+ self.stackedobs[key][:, -stack_ax_size:, ...] = observations[key]
+ else:
+ self.stackedobs[key][..., -stack_ax_size:] = observations[key]
+ return self.stackedobs, infos
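
The rolling update used by both classes above is worth seeing in isolation: `np.roll` shifts the old frames toward the front of the stack axis, and the newest observation then overwrites the trailing slots. A minimal numpy-only sketch of that behavior (shapes chosen for illustration):

```python
import numpy as np

n_envs, n_stack, obs_dim = 2, 4, 3
stackedobs = np.zeros((n_envs, obs_dim * n_stack))  # channels-last layout

for t in range(3):
    new_obs = np.full((n_envs, obs_dim), t, dtype=np.float64)
    # Shift old frames toward the front, then write the newest frame at the end
    stackedobs = np.roll(stackedobs, shift=-obs_dim, axis=-1)
    stackedobs[..., -obs_dim:] = new_obs

print(stackedobs[0])  # zero padding first, then the frames for t = 0, 1, 2
```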
diff --git a/dexart-release/stable_baselines3/common/vec_env/subproc_vec_env.py b/dexart-release/stable_baselines3/common/vec_env/subproc_vec_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8a5f27964ce848ebd6ff96e4ecdb9d301c5d17d
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/subproc_vec_env.py
@@ -0,0 +1,221 @@
+import multiprocessing as mp
+from collections import OrderedDict
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
+
+import gym
+import numpy as np
+
+from stable_baselines3.common.vec_env.base_vec_env import (
+ CloudpickleWrapper,
+ VecEnv,
+ VecEnvIndices,
+ VecEnvObs,
+ VecEnvStepReturn,
+)
+
+
+def _worker(
+ remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper
+) -> None:
+ # Import here to avoid a circular import
+ from stable_baselines3.common.env_util import is_wrapped
+
+ parent_remote.close()
+ env = env_fn_wrapper.var()
+ while True:
+ try:
+ cmd, data = remote.recv()
+ if cmd == "step":
+ observation, reward, done, info = env.step(data)
+ if done:
+ # save final observation where user can get it, then reset
+ info["terminal_observation"] = observation
+ observation = env.reset()
+ remote.send((observation, reward, done, info))
+ elif cmd == "seed":
+ remote.send(env.seed(data))
+ elif cmd == "reset":
+ observation = env.reset()
+ remote.send(observation)
+ elif cmd == "render":
+ remote.send(env.render(data))
+ elif cmd == "close":
+ env.close()
+ remote.close()
+ break
+ elif cmd == "get_spaces":
+ remote.send((env.observation_space, env.action_space))
+ elif cmd == "env_method":
+ method = getattr(env, data[0])
+ remote.send(method(*data[1], **data[2]))
+ elif cmd == "get_attr":
+ remote.send(getattr(env, data))
+ elif cmd == "set_attr":
+ remote.send(setattr(env, data[0], data[1]))
+ elif cmd == "is_wrapped":
+ remote.send(is_wrapped(env, data))
+ else:
+ raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
+ except EOFError:
+ break
+
+
+class SubprocVecEnv(VecEnv):
+ """
+ Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
+ process, allowing significant speed up when the environment is computationally complex.
+
+ For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
+ number of logical cores on your CPU.
+
+ .. warning::
+
+ Only 'forkserver' and 'spawn' start methods are thread-safe,
+ which is important when TensorFlow sessions or other non thread-safe
+ libraries are used in the parent (see issue #217). However, compared to
+ 'fork' they incur a small start-up cost and have restrictions on
+ global variables. With those methods, users must wrap the code in an
+ ``if __name__ == "__main__":`` block.
+ For more information, see the multiprocessing documentation.
+
+ :param env_fns: Environments to run in subprocesses
+ :param start_method: method used to start the subprocesses.
+ Must be one of the methods returned by multiprocessing.get_all_start_methods().
+ Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
+ """
+
+ def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[str] = None):
+ self.waiting = False
+ self.closed = False
+ n_envs = len(env_fns)
+
+ if start_method is None:
+ # Fork is not a thread safe method (see issue #217)
+ # but is more user friendly (it does not require wrapping the code in
+ # an `if __name__ == "__main__":` block)
+ forkserver_available = "forkserver" in mp.get_all_start_methods()
+ start_method = "forkserver" if forkserver_available else "spawn"
+ ctx = mp.get_context(start_method)
+
+ self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)])
+ self.processes = []
+ for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns):
+ args = (work_remote, remote, CloudpickleWrapper(env_fn))
+ # daemon=True: if the main process crashes, we should not cause things to hang
+ process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error
+ process.start()
+ self.processes.append(process)
+ work_remote.close()
+
+ self.remotes[0].send(("get_spaces", None))
+ observation_space, action_space = self.remotes[0].recv()
+ VecEnv.__init__(self, len(env_fns), observation_space, action_space)
+
+ def step_async(self, actions: np.ndarray) -> None:
+ for remote, action in zip(self.remotes, actions):
+ remote.send(("step", action))
+ self.waiting = True
+
+ def step_wait(self) -> VecEnvStepReturn:
+ results = [remote.recv() for remote in self.remotes]
+ self.waiting = False
+ obs, rews, dones, infos = zip(*results)
+ return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos
+
+ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
+ if seed is None:
+ seed = np.random.randint(0, 2**32 - 1)
+ for idx, remote in enumerate(self.remotes):
+ remote.send(("seed", seed + idx))
+ return [remote.recv() for remote in self.remotes]
+
+ def reset(self) -> VecEnvObs:
+ for remote in self.remotes:
+ remote.send(("reset", None))
+ obs = [remote.recv() for remote in self.remotes]
+ return _flatten_obs(obs, self.observation_space)
+
+ def close(self) -> None:
+ if self.closed:
+ return
+ if self.waiting:
+ for remote in self.remotes:
+ remote.recv()
+ for remote in self.remotes:
+ remote.send(("close", None))
+ for process in self.processes:
+ process.join()
+ self.closed = True
+
+ def get_images(self) -> Sequence[np.ndarray]:
+ for pipe in self.remotes:
+ # gather images from subprocesses
+ # `mode` will be taken into account later
+ pipe.send(("render", "rgb_array"))
+ imgs = [pipe.recv() for pipe in self.remotes]
+ return imgs
+
+ def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
+ """Return attribute from vectorized environment (see base class)."""
+ target_remotes = self._get_target_remotes(indices)
+ for remote in target_remotes:
+ remote.send(("get_attr", attr_name))
+ return [remote.recv() for remote in target_remotes]
+
+ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
+ """Set attribute inside vectorized environments (see base class)."""
+ target_remotes = self._get_target_remotes(indices)
+ for remote in target_remotes:
+ remote.send(("set_attr", (attr_name, value)))
+ for remote in target_remotes:
+ remote.recv()
+
+ def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
+ """Call instance methods of vectorized environments."""
+ target_remotes = self._get_target_remotes(indices)
+ for remote in target_remotes:
+ remote.send(("env_method", (method_name, method_args, method_kwargs)))
+ return [remote.recv() for remote in target_remotes]
+
+ def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
+ """Check if worker environments are wrapped with a given wrapper"""
+ target_remotes = self._get_target_remotes(indices)
+ for remote in target_remotes:
+ remote.send(("is_wrapped", wrapper_class))
+ return [remote.recv() for remote in target_remotes]
+
+ def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]:
+ """
+ Get the connection object needed to communicate with the wanted
+ envs that are in subprocesses.
+
+ :param indices: refers to indices of envs.
+ :return: Connection object to communicate between processes.
+ """
+ indices = self._get_indices(indices)
+ return [self.remotes[i] for i in indices]
+
+
+def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: gym.spaces.Space) -> VecEnvObs:
+ """
+ Flatten observations, depending on the observation space.
+
+ :param obs: observations.
+ A list or tuple of observations, one per environment.
+ Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays.
+ :return: flattened observations.
+ A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays.
+ Each NumPy array has the environment index as its first axis.
+ """
+ assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment"
+ assert len(obs) > 0, "need observations from at least one environment"
+
+ if isinstance(space, gym.spaces.Dict):
+ assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces"
+ assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space"
+ return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])
+ elif isinstance(space, gym.spaces.Tuple):
+ assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
+ obs_len = len(space.spaces)
+ return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len))
+ else:
+ return np.stack(obs)
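
A hedged usage sketch for `SubprocVecEnv` (environment id illustrative, pre-0.26 `gym` API as used throughout this file). Because the default start methods (`forkserver`/`spawn`) re-import the main module, construction must sit under an `if __name__ == "__main__":` guard, as the class docstring warns:

```python
import gym
import numpy as np

from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv


def make_env(rank: int):
    def _init() -> gym.Env:
        env = gym.make("CartPole-v1")
        env.seed(rank)  # handled by the "seed" command in _worker
        return env

    return _init


if __name__ == "__main__":
    vec_env = SubprocVecEnv([make_env(i) for i in range(4)])
    obs = vec_env.reset()  # shape: (4,) + single-env observation shape
    actions = np.stack([vec_env.action_space.sample() for _ in range(4)])
    obs, rewards, dones, infos = vec_env.step(actions)
    vec_env.close()
```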
diff --git a/dexart-release/stable_baselines3/common/vec_env/util.py b/dexart-release/stable_baselines3/common/vec_env/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca590cb1c81221cc2860490b50e74ef59c9736ba
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/util.py
@@ -0,0 +1,76 @@
+"""
+Helpers for dealing with vectorized environments.
+"""
+from collections import OrderedDict
+from typing import Any, Dict, List, Tuple
+
+import gym
+import numpy as np
+
+from stable_baselines3.common.preprocessing import check_for_nested_spaces
+from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
+
+
+def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+ """
+ Deep-copy a dict of numpy arrays.
+
+ :param obs: a dict of numpy arrays.
+ :return: a dict of copied numpy arrays.
+ """
+ assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'"
+ return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])
+
+
+def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
+ """
+ Convert an internal representation raw_obs into the appropriate type
+ specified by space.
+
+ :param obs_space: an observation space.
+ :param obs_dict: a dict of numpy arrays.
+ :return: returns an observation of the same type as space.
+ If space is Dict, function is identity; if space is Tuple, converts dict to Tuple;
+ otherwise, space is unstructured and the value obs_dict[None] is returned.
+ """
+ if isinstance(obs_space, gym.spaces.Dict):
+ return obs_dict
+ elif isinstance(obs_space, gym.spaces.Tuple):
+ assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space"
+ return tuple(obs_dict[i] for i in range(len(obs_space.spaces)))
+ else:
+ assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
+ return obs_dict[None]
+
+
+def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]:
+ """
+ Get dict-structured information about a gym.Space.
+
+ Dict spaces are represented directly by their dict of subspaces.
+ Tuple spaces are converted into a dict with keys indexing into the tuple.
+ Unstructured spaces are represented by {None: obs_space}.
+
+ :param obs_space: an observation space
+ :return: A tuple (keys, shapes, dtypes):
+ keys: a list of dict keys.
+ shapes: a dict mapping keys to shapes.
+ dtypes: a dict mapping keys to dtypes.
+ """
+ check_for_nested_spaces(obs_space)
+ if isinstance(obs_space, gym.spaces.Dict):
+ assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
+ subspaces = obs_space.spaces
+ elif isinstance(obs_space, gym.spaces.Tuple):
+ subspaces = {i: space for i, space in enumerate(obs_space.spaces)}
+ else:
+ assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
+ subspaces = {None: obs_space}
+ keys = []
+ shapes = {}
+ dtypes = {}
+ for key, box in subspaces.items():
+ keys.append(key)
+ shapes[key] = box.shape
+ dtypes[key] = box.dtype
+ return keys, shapes, dtypes
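
A small sketch of what `obs_space_info` returns for the space kinds handled above (spaces chosen for illustration):

```python
import gym
import numpy as np

from stable_baselines3.common.vec_env.util import obs_space_info

dict_space = gym.spaces.Dict(
    {
        "image": gym.spaces.Box(low=0, high=255, shape=(3, 64, 64), dtype=np.uint8),
        "state": gym.spaces.Box(low=-1.0, high=1.0, shape=(7,), dtype=np.float32),
    }
)
keys, shapes, dtypes = obs_space_info(dict_space)
print(keys)             # ['image', 'state']
print(shapes["image"])  # (3, 64, 64)

box_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
keys, shapes, dtypes = obs_space_info(box_space)
print(keys)             # [None] -- unstructured spaces are keyed by None
```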
diff --git a/dexart-release/stable_baselines3/common/vec_env/vec_check_nan.py b/dexart-release/stable_baselines3/common/vec_env/vec_check_nan.py
new file mode 100644
index 0000000000000000000000000000000000000000..258f9c26bef5e355d4db3ec7e08410a676e3e709
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/vec_check_nan.py
@@ -0,0 +1,86 @@
+import warnings
+
+import numpy as np
+
+from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
+
+
+class VecCheckNan(VecEnvWrapper):
+ """
+ NaN and inf checking wrapper for vectorized environment, will raise a warning by default,
+ letting you know where the NaN or inf originated.
+
+ :param venv: the vectorized environment to wrap
+ :param raise_exception: Whether or not to raise a ValueError, instead of a UserWarning
+ :param warn_once: Whether or not to only warn once.
+ :param check_inf: Whether or not to check for +inf or -inf as well
+ """
+
+ def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True):
+ VecEnvWrapper.__init__(self, venv)
+ self.raise_exception = raise_exception
+ self.warn_once = warn_once
+ self.check_inf = check_inf
+ self._actions = None
+ self._observations = None
+ self._user_warned = False
+
+ def step_async(self, actions: np.ndarray) -> None:
+ self._check_val(async_step=True, actions=actions)
+
+ self._actions = actions
+ self.venv.step_async(actions)
+
+ def step_wait(self) -> VecEnvStepReturn:
+ observations, rewards, news, infos = self.venv.step_wait()
+
+ self._check_val(async_step=False, observations=observations, rewards=rewards, news=news)
+
+ self._observations = observations
+ return observations, rewards, news, infos
+
+ def reset(self) -> VecEnvObs:
+ observations = self.venv.reset()
+ self._actions = None
+
+ self._check_val(async_step=False, observations=observations)
+
+ self._observations = observations
+ return observations
+
+ def _check_val(self, *, async_step: bool, **kwargs) -> None:
+ # if warn and warn once and have warned once: then stop checking
+ if not self.raise_exception and self.warn_once and self._user_warned:
+ return
+
+ found = []
+ for name, val in kwargs.items():
+ has_nan = np.any(np.isnan(val))
+ has_inf = self.check_inf and np.any(np.isinf(val))
+ if has_inf:
+ found.append((name, "inf"))
+ if has_nan:
+ found.append((name, "nan"))
+
+ if found:
+ self._user_warned = True
+ msg = ""
+ for i, (name, type_val) in enumerate(found):
+ msg += f"found {type_val} in {name}"
+ if i != len(found) - 1:
+ msg += ", "
+
+ msg += ".\r\nOriginated from the "
+
+ if not async_step:
+ if self._actions is None:
+ msg += "environment observation (at reset)"
+ else:
+ msg += f"environment, Last given value was: \r\n\taction={self._actions}"
+ else:
+ msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}"
+
+ if self.raise_exception:
+ raise ValueError(msg)
+ else:
+ warnings.warn(msg, UserWarning)
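
A short usage sketch (environment id illustrative): wrap any `VecEnv` with `VecCheckNan` while debugging, and with `raise_exception=True` a NaN or inf in actions, observations, or rewards raises a `ValueError` whose message states whether the RL model or the environment produced it:

```python
import gym

from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan

venv = VecCheckNan(DummyVecEnv([lambda: gym.make("CartPole-v1")]), raise_exception=True)
obs = venv.reset()  # checked at reset
obs, rewards, dones, infos = venv.step([venv.action_space.sample()])  # actions and outputs checked
```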
diff --git a/dexart-release/stable_baselines3/common/vec_env/vec_extract_dict_obs.py b/dexart-release/stable_baselines3/common/vec_env/vec_extract_dict_obs.py
new file mode 100644
index 0000000000000000000000000000000000000000..8582b7a308c7d8695b8b93061e4119d5cdcf1f5b
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/vec_extract_dict_obs.py
@@ -0,0 +1,24 @@
+import numpy as np
+
+from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
+
+
+class VecExtractDictObs(VecEnvWrapper):
+ """
+ A vectorized wrapper for extracting dictionary observations.
+
+ :param venv: The vectorized environment
+ :param key: The key of the dictionary observation
+ """
+
+ def __init__(self, venv: VecEnv, key: str):
+ self.key = key
+ super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key])
+
+ def reset(self) -> np.ndarray:
+ obs = self.venv.reset()
+ return obs[self.key]
+
+ def step_wait(self) -> VecEnvStepReturn:
+ obs, reward, done, info = self.venv.step_wait()
+ return obs[self.key], reward, done, info
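
A minimal sketch of `VecExtractDictObs`: after wrapping, downstream code sees only the chosen subspace (the toy environment and the "rgb" key are illustrative):

```python
import gym
import numpy as np

from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs


class ToyDictEnv(gym.Env):
    """Toy environment with a Dict observation space (illustration only)."""

    observation_space = gym.spaces.Dict(
        {
            "rgb": gym.spaces.Box(0, 255, (8, 8, 3), np.uint8),
            "state": gym.spaces.Box(-1.0, 1.0, (4,), np.float32),
        }
    )
    action_space = gym.spaces.Discrete(2)

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        return self.observation_space.sample(), 0.0, False, {}


venv = VecExtractDictObs(DummyVecEnv([ToyDictEnv]), key="rgb")
print(venv.observation_space)  # only the "rgb" Box subspace remains
print(venv.reset().shape)      # (1, 8, 8, 3)
```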
diff --git a/dexart-release/stable_baselines3/common/vec_env/vec_frame_stack.py b/dexart-release/stable_baselines3/common/vec_env/vec_frame_stack.py
new file mode 100644
index 0000000000000000000000000000000000000000..e06d5125e0f4b666af20a1b86ef53c63b635c1b9
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/vec_frame_stack.py
@@ -0,0 +1,64 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+from gym import spaces
+
+from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
+from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations
+
+
+class VecFrameStack(VecEnvWrapper):
+ """
+ Frame stacking wrapper for vectorized environment. Designed for image observations.
+
+ Uses the StackedObservations class, or StackedDictObservations depending on the observations space
+
+ :param venv: the vectorized environment to wrap
+ :param n_stack: Number of frames to stack
+ :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
+ If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
+ Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces
+ """
+
+ def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None):
+ self.venv = venv
+ self.n_stack = n_stack
+
+ wrapped_obs_space = venv.observation_space
+
+ if isinstance(wrapped_obs_space, spaces.Box):
+ assert not isinstance(
+ channels_order, dict
+ ), f"Expected None or string for channels_order but received {channels_order}"
+ self.stackedobs = StackedObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order)
+
+ elif isinstance(wrapped_obs_space, spaces.Dict):
+ self.stackedobs = StackedDictObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order)
+
+ else:
+ raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces")
+
+ observation_space = self.stackedobs.stack_observation_space(wrapped_obs_space)
+ VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
+
+ def step_wait(
+ self,
+ ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]:
+
+ observations, rewards, dones, infos = self.venv.step_wait()
+
+ observations, infos = self.stackedobs.update(observations, dones, infos)
+
+ return observations, rewards, dones, infos
+
+ def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+ """
+ Reset all environments
+ """
+ observation = self.venv.reset() # pytype:disable=annotation-type-mismatch
+
+ observation = self.stackedobs.reset(observation)
+ return observation
+
+ def close(self) -> None:
+ self.venv.close()
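
A usage sketch for `VecFrameStack` (environment id illustrative). For non-image Box spaces the stacking defaults to the last axis, so a 4-dimensional observation stacked 4 times becomes 16-dimensional:

```python
import gym

from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack

venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])  # Box(4,) observations
venv = VecFrameStack(venv, n_stack=4)
print(venv.observation_space.shape)  # (16,): non-image spaces stack on the last axis
print(venv.reset().shape)            # (1, 16)
```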
diff --git a/dexart-release/stable_baselines3/common/vec_env/vec_monitor.py b/dexart-release/stable_baselines3/common/vec_env/vec_monitor.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddc099a12d048eb9c4658f76569939f565d38c81
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/vec_monitor.py
@@ -0,0 +1,100 @@
+import time
+import warnings
+from typing import Optional, Tuple
+
+import numpy as np
+
+from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
+
+
+class VecMonitor(VecEnvWrapper):
+ """
+ A vectorized monitor wrapper for *vectorized* Gym environments,
+ it is used to record the episode reward, length, time and other data.
+
+ Some environments like `openai/procgen <https://github.com/openai/procgen>`_
+ or `gym3 <https://github.com/openai/gym3>`_ directly initialize the
+ vectorized environments, without giving us a chance to use the ``Monitor``
+ wrapper. So this class simply does the job of the ``Monitor`` wrapper on
+ a vectorized level.
+
+ :param venv: The vectorized environment
+ :param filename: the location to save a log file, can be None for no log
+ :param info_keywords: extra information to log, from the information return of env.step()
+ """
+
+ def __init__(
+ self,
+ venv: VecEnv,
+ filename: Optional[str] = None,
+ info_keywords: Tuple[str, ...] = (),
+ ):
+ # Avoid circular import
+ from stable_baselines3.common.monitor import Monitor, ResultsWriter
+
+ # This check is not valid for special `VecEnv`
+ # like the ones created by Procgen, which do not completely follow
+ # the `VecEnv` interface
+ try:
+ is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0]
+ except AttributeError:
+ is_wrapped_with_monitor = False
+
+ if is_wrapped_with_monitor:
+ warnings.warn(
+ "The environment is already wrapped with a `Monitor` wrapper"
+ "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be"
+ "overwritten by the `VecMonitor` ones.",
+ UserWarning,
+ )
+
+ VecEnvWrapper.__init__(self, venv)
+ self.episode_returns = None
+ self.episode_lengths = None
+ self.episode_count = 0
+ self.t_start = time.time()
+
+ env_id = None
+ if hasattr(venv, "spec") and venv.spec is not None:
+ env_id = venv.spec.id
+
+ if filename:
+ self.results_writer = ResultsWriter(
+ filename, header={"t_start": self.t_start, "env_id": env_id}, extra_keys=info_keywords
+ )
+ else:
+ self.results_writer = None
+ self.info_keywords = info_keywords
+
+ def reset(self) -> VecEnvObs:
+ obs = self.venv.reset()
+ self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
+ self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
+ return obs
+
+ def step_wait(self) -> VecEnvStepReturn:
+ obs, rewards, dones, infos = self.venv.step_wait()
+ self.episode_returns += rewards
+ self.episode_lengths += 1
+ new_infos = list(infos[:])
+ for i in range(len(dones)):
+ if dones[i]:
+ info = infos[i].copy()
+ episode_return = self.episode_returns[i]
+ episode_length = self.episode_lengths[i]
+ episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)}
+ for key in self.info_keywords:
+ episode_info[key] = info[key]
+ info["episode"] = episode_info
+ self.episode_count += 1
+ self.episode_returns[i] = 0
+ self.episode_lengths[i] = 0
+ if self.results_writer:
+ self.results_writer.write_row(episode_info)
+ new_infos[i] = info
+ return obs, rewards, dones, new_infos
+
+ def close(self) -> None:
+ if self.results_writer:
+ self.results_writer.close()
+ return self.venv.close()
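
A usage sketch for `VecMonitor` (environment id and filename illustrative): at the end of each episode, an "episode" dict with return, length, and elapsed time is placed into that environment's info:

```python
import gym

from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor

venv = VecMonitor(DummyVecEnv([lambda: gym.make("CartPole-v1")]), filename="monitor.csv")
obs = venv.reset()
obs, rewards, dones, infos = venv.step([venv.action_space.sample()])
if dones[0]:
    print(infos[0]["episode"])  # {"r": ..., "l": ..., "t": ...}
```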
diff --git a/dexart-release/stable_baselines3/common/vec_env/vec_normalize.py b/dexart-release/stable_baselines3/common/vec_env/vec_normalize.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3ee588aba47e49691d5fb9ae51e097e6bc1bb3e
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/vec_normalize.py
@@ -0,0 +1,296 @@
+import pickle
+import warnings
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Union
+
+import gym
+import numpy as np
+
+from stable_baselines3.common import utils
+from stable_baselines3.common.running_mean_std import RunningMeanStd
+from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
+
+
+class VecNormalize(VecEnvWrapper):
+ """
+ A moving average, normalizing wrapper for vectorized environment.
+ It also supports saving/loading the moving averages.
+
+ :param venv: the vectorized environment to wrap
+ :param training: Whether to update or not the moving average
+ :param norm_obs: Whether to normalize observation or not (default: True)
+ :param norm_reward: Whether to normalize rewards or not (default: True)
+ :param clip_obs: Max absolute value for observation
+ :param clip_reward: Max absolute value for discounted reward
+ :param gamma: discount factor
+ :param epsilon: To avoid division by zero
+ :param norm_obs_keys: Which keys from observation dict to normalize.
+ If not specified, all keys will be normalized.
+ """
+
+ def __init__(
+ self,
+ venv: VecEnv,
+ training: bool = True,
+ norm_obs: bool = True,
+ norm_reward: bool = True,
+ clip_obs: float = 10.0,
+ clip_reward: float = 10.0,
+ gamma: float = 0.99,
+ epsilon: float = 1e-8,
+ norm_obs_keys: Optional[List[str]] = None,
+ ):
+ VecEnvWrapper.__init__(self, venv)
+
+ self.norm_obs = norm_obs
+ self.norm_obs_keys = norm_obs_keys
+ # Check observation spaces
+ if self.norm_obs:
+ self._sanity_checks()
+
+ if isinstance(self.observation_space, gym.spaces.Dict):
+ self.obs_spaces = self.observation_space.spaces
+ self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys}
+ else:
+ self.obs_spaces = None
+ self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
+
+ self.ret_rms = RunningMeanStd(shape=())
+ self.clip_obs = clip_obs
+ self.clip_reward = clip_reward
+ # Returns: discounted rewards
+ self.returns = np.zeros(self.num_envs)
+ self.gamma = gamma
+ self.epsilon = epsilon
+ self.training = training
+ self.norm_obs = norm_obs
+ self.norm_reward = norm_reward
+ self.old_obs = np.array([])
+ self.old_reward = np.array([])
+
+ def _sanity_checks(self) -> None:
+ """
+ Check the observations that are going to be normalized are of the correct type (spaces.Box).
+ """
+ if isinstance(self.observation_space, gym.spaces.Dict):
+ # By default, we normalize all keys
+ if self.norm_obs_keys is None:
+ self.norm_obs_keys = list(self.observation_space.spaces.keys())
+ # Check that all keys are of type Box
+ for obs_key in self.norm_obs_keys:
+ if not isinstance(self.observation_space.spaces[obs_key], gym.spaces.Box):
+ raise ValueError(
+ f"VecNormalize only supports `gym.spaces.Box` observation spaces but {obs_key} "
+ f"is of type {self.observation_space.spaces[obs_key]}. "
+ "You should probably explicitely pass the observation keys "
+ " that should be normalized via the `norm_obs_keys` parameter."
+ )
+
+ elif isinstance(self.observation_space, gym.spaces.Box):
+ if self.norm_obs_keys is not None:
+ raise ValueError("`norm_obs_keys` param is applicable only with `gym.spaces.Dict` observation spaces")
+
+ else:
+ raise ValueError(
+ "VecNormalize only supports `gym.spaces.Box` and `gym.spaces.Dict` observation spaces, "
+ f"not {self.observation_space}"
+ )
+
+ def __getstate__(self) -> Dict[str, Any]:
+ """
+ Gets state for pickling.
+
+ Excludes self.venv, as in general VecEnvs may not be pickleable."""
+ state = self.__dict__.copy()
+ # these attributes are not pickleable
+ del state["venv"]
+ del state["class_attributes"]
+ # these attributes depend on the above and so we would prefer not to pickle
+ del state["returns"]
+ return state
+
+ def __setstate__(self, state: Dict[str, Any]) -> None:
+ """
+ Restores pickled state.
+
+ User must call set_venv() after unpickling before using.
+
+ :param state:"""
+ # Backward compatibility
+ if "norm_obs_keys" not in state and isinstance(state["observation_space"], gym.spaces.Dict):
+ state["norm_obs_keys"] = list(state["observation_space"].spaces.keys())
+ self.__dict__.update(state)
+ assert "venv" not in state
+ self.venv = None
+
+ def set_venv(self, venv: VecEnv) -> None:
+ """
+ Sets the vector environment to wrap to venv.
+
+ Also sets attributes derived from this such as `num_envs`.
+
+ :param venv:
+ """
+ if self.venv is not None:
+ raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.")
+ VecEnvWrapper.__init__(self, venv)
+
+ # Check only that the observation_space matches
+ utils.check_for_correct_spaces(venv, self.observation_space, venv.action_space)
+ self.returns = np.zeros(self.num_envs)
+
+ def step_wait(self) -> VecEnvStepReturn:
+ """
+ Apply sequence of actions to sequence of environments
+ actions -> (observations, rewards, dones)
+
+ where ``dones`` is a boolean vector indicating whether each element is new.
+ """
+ obs, rewards, dones, infos = self.venv.step_wait()
+ self.old_obs = obs
+ self.old_reward = rewards
+
+ if self.training and self.norm_obs:
+ if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
+ for key in self.obs_rms.keys():
+ self.obs_rms[key].update(obs[key])
+ else:
+ self.obs_rms.update(obs)
+
+ obs = self.normalize_obs(obs)
+
+ if self.training:
+ self._update_reward(rewards)
+ rewards = self.normalize_reward(rewards)
+
+ # Normalize the terminal observations
+ for idx, done in enumerate(dones):
+ if not done:
+ continue
+ if "terminal_observation" in infos[idx]:
+ infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"])
+
+ self.returns[dones] = 0
+ return obs, rewards, dones, infos
+
+ def _update_reward(self, reward: np.ndarray) -> None:
+ """Update reward normalization statistics."""
+ self.returns = self.returns * self.gamma + reward
+ self.ret_rms.update(self.returns)
+
+ def _normalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray:
+ """
+ Helper to normalize observation.
+ :param obs:
+ :param obs_rms: associated statistics
+ :return: normalized observation
+ """
+ return np.clip((obs - obs_rms.mean) / np.sqrt(obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs)
+
+ def _unnormalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray:
+ """
+ Helper to unnormalize observation.
+ :param obs:
+ :param obs_rms: associated statistics
+ :return: unnormalized observation
+ """
+ return (obs * np.sqrt(obs_rms.var + self.epsilon)) + obs_rms.mean
+
+ def normalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+ """
+ Normalize observations using this VecNormalize's observations statistics.
+ Calling this method does not update statistics.
+ """
+ # Avoid modifying by reference the original object
+ obs_ = deepcopy(obs)
+ if self.norm_obs:
+ if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
+ # Only normalize the specified keys
+ for key in self.norm_obs_keys:
+ obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32)
+ else:
+ obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32)
+ return obs_
+
+ def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
+ """
+ Normalize rewards using this VecNormalize's rewards statistics.
+ Calling this method does not update statistics.
+ """
+ if self.norm_reward:
+ reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
+ return reward
+
+ def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+ # Avoid modifying by reference the original object
+ obs_ = deepcopy(obs)
+ if self.norm_obs:
+ if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
+ for key in self.norm_obs_keys:
+ obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key])
+ else:
+ obs_ = self._unnormalize_obs(obs, self.obs_rms)
+ return obs_
+
+ def unnormalize_reward(self, reward: np.ndarray) -> np.ndarray:
+ if self.norm_reward:
+ return reward * np.sqrt(self.ret_rms.var + self.epsilon)
+ return reward
+
+ def get_original_obs(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+ """
+ Returns an unnormalized version of the observations from the most recent
+ step or reset.
+ """
+ return deepcopy(self.old_obs)
+
+ def get_original_reward(self) -> np.ndarray:
+ """
+ Returns an unnormalized version of the rewards from the most recent step.
+ """
+ return self.old_reward.copy()
+
+ def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+ """
+ Reset all environments
+ :return: first observation of the episode
+ """
+ obs = self.venv.reset()
+ self.old_obs = obs
+ self.returns = np.zeros(self.num_envs)
+ if self.training and self.norm_obs:
+ if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
+ for key in self.obs_rms.keys():
+ self.obs_rms[key].update(obs[key])
+ else:
+ self.obs_rms.update(obs)
+ return self.normalize_obs(obs)
+
+ @staticmethod
+ def load(load_path: str, venv: VecEnv) -> "VecNormalize":
+ """
+ Loads a saved VecNormalize object.
+
+ :param load_path: the path to load from.
+ :param venv: the VecEnv to wrap.
+ :return:
+ """
+ with open(load_path, "rb") as file_handler:
+ vec_normalize = pickle.load(file_handler)
+ vec_normalize.set_venv(venv)
+ return vec_normalize
+
+ def save(self, save_path: str) -> None:
+ """
+ Save current VecNormalize object with
+ all running statistics and settings (e.g. clip_obs)
+
+ :param save_path: The path to save to
+ """
+ with open(save_path, "wb") as file_handler:
+ pickle.dump(self, file_handler)
+
+ @property
+ def ret(self) -> np.ndarray:
+ warnings.warn("`VecNormalize` `ret` attribute is deprecated. Please use `returns` instead.", DeprecationWarning)
+ return self.returns
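
A sketch of the intended save/load round trip (environment id and path illustrative): `__getstate__` drops the wrapped venv, so `load` must re-attach one via `set_venv`. At evaluation time the statistics are typically frozen and reward normalization disabled:

```python
import gym

from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize


def make_venv():
    return DummyVecEnv([lambda: gym.make("CartPole-v1")])


venv = VecNormalize(make_venv(), norm_obs=True, norm_reward=True)
venv.reset()
venv.save("vec_normalize.pkl")

eval_env = VecNormalize.load("vec_normalize.pkl", make_venv())
eval_env.training = False     # freeze the running statistics
eval_env.norm_reward = False  # report raw rewards during evaluation
```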
diff --git a/dexart-release/stable_baselines3/common/vec_env/vec_transpose.py b/dexart-release/stable_baselines3/common/vec_env/vec_transpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6b0ad832f71a876628aa022645e1078940dc30e
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/vec_transpose.py
@@ -0,0 +1,113 @@
+from copy import deepcopy
+from typing import Dict, Union
+
+import numpy as np
+from gym import spaces
+
+from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
+from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
+
+
+class VecTransposeImage(VecEnvWrapper):
+ """
+ Re-order channels, from HxWxC to CxHxW.
+ It is required for PyTorch convolution layers.
+
+ :param venv:
+ :param skip: Skip this wrapper if needed as we rely on heuristic to apply it or not,
+ which may result in unwanted behavior, see GH issue #671.
+ """
+
+ def __init__(self, venv: VecEnv, skip: bool = False):
+ assert is_image_space(venv.observation_space) or isinstance(
+ venv.observation_space, spaces.dict.Dict
+ ), "The observation space must be an image or dictionary observation space"
+
+ self.skip = skip
+ # Do nothing
+ if skip:
+ super().__init__(venv)
+ return
+
+ if isinstance(venv.observation_space, spaces.dict.Dict):
+ self.image_space_keys = []
+ observation_space = deepcopy(venv.observation_space)
+ for key, space in observation_space.spaces.items():
+ if is_image_space(space):
+ # Keep track of which keys should be transposed later
+ self.image_space_keys.append(key)
+ observation_space.spaces[key] = self.transpose_space(space, key)
+ else:
+ observation_space = self.transpose_space(venv.observation_space)
+ super().__init__(venv, observation_space=observation_space)
+
+ @staticmethod
+ def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box:
+ """
+ Transpose an observation space (re-order channels).
+
+ :param observation_space:
+ :param key: In case of dictionary space, the key of the observation space.
+ :return:
+ """
+ # Sanity checks
+ assert is_image_space(observation_space), "The observation space must be an image"
+ assert not is_image_space_channels_first(
+ observation_space
+ ), f"The observation space {key} must follow the channel last convention"
+ height, width, channels = observation_space.shape
+ new_shape = (channels, height, width)
+ return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype)
+
+ @staticmethod
+ def transpose_image(image: np.ndarray) -> np.ndarray:
+ """
+ Transpose an image or batch of images (re-order channels).
+
+ :param image:
+ :return:
+ """
+ if len(image.shape) == 3:
+ return np.transpose(image, (2, 0, 1))
+ return np.transpose(image, (0, 3, 1, 2))
+
+ def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]:
+ """
+ Transpose (if needed) and return new observations.
+
+ :param observations:
+ :return: Transposed observations
+ """
+ # Do nothing
+ if self.skip:
+ return observations
+
+ if isinstance(observations, dict):
+ # Avoid modifying the original object in place
+ observations = deepcopy(observations)
+ for k in self.image_space_keys:
+ observations[k] = self.transpose_image(observations[k])
+ else:
+ observations = self.transpose_image(observations)
+ return observations
+
+ def step_wait(self) -> VecEnvStepReturn:
+ observations, rewards, dones, infos = self.venv.step_wait()
+
+ # Transpose the terminal observations
+ for idx, done in enumerate(dones):
+ if not done:
+ continue
+ if "terminal_observation" in infos[idx]:
+ infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"])
+
+ return self.transpose_observations(observations), rewards, dones, infos
+
+ def reset(self) -> Union[np.ndarray, Dict]:
+ """
+ Reset all environments
+ """
+ return self.transpose_observations(self.venv.reset())
+
+ def close(self) -> None:
+ self.venv.close()
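
A numpy-only check of the transposes performed above: HxWxC becomes CxHxW for a single image, and NxHxWxC becomes NxCxHxW for a batch:

```python
import numpy as np

image = np.zeros((64, 64, 3), dtype=np.uint8)
batch = np.zeros((8, 64, 64, 3), dtype=np.uint8)
print(np.transpose(image, (2, 0, 1)).shape)     # (3, 64, 64)
print(np.transpose(batch, (0, 3, 1, 2)).shape)  # (8, 3, 64, 64)
```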
diff --git a/dexart-release/stable_baselines3/common/vec_env/vec_video_recorder.py b/dexart-release/stable_baselines3/common/vec_env/vec_video_recorder.py
new file mode 100644
index 0000000000000000000000000000000000000000..70d74ebe4c981fd0c1182bee97ff175802bfadb3
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/vec_video_recorder.py
@@ -0,0 +1,113 @@
+import os
+from typing import Callable
+
+from gym.wrappers.monitoring import video_recorder
+
+from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
+from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
+from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
+
+
+class VecVideoRecorder(VecEnvWrapper):
+ """
+ Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video.
+ It requires ffmpeg or avconv to be installed on the machine.
+
+ :param venv:
+ :param video_folder: Where to save videos
+ :param record_video_trigger: Function that defines when to start recording.
+ The function takes the current number of step,
+ and returns whether we should start recording or not.
+ :param video_length: Length of recorded videos
+ :param name_prefix: Prefix to the video name
+ """
+
+ def __init__(
+ self,
+ venv: VecEnv,
+ video_folder: str,
+ record_video_trigger: Callable[[int], bool],
+ video_length: int = 200,
+ name_prefix: str = "rl-video",
+ ):
+
+ VecEnvWrapper.__init__(self, venv)
+
+ self.env = venv
+ # Temp variable to retrieve metadata
+ temp_env = venv
+
+ # Unwrap to retrieve metadata dict
+ # that will be used by gym recorder
+ while isinstance(temp_env, VecEnvWrapper):
+ temp_env = temp_env.venv
+
+ if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv):
+ metadata = temp_env.get_attr("metadata")[0]
+ else:
+ metadata = temp_env.metadata
+
+ self.env.metadata = metadata
+
+ self.record_video_trigger = record_video_trigger
+ self.video_recorder = None
+
+ self.video_folder = os.path.abspath(video_folder)
+ # Create output folder if needed
+ os.makedirs(self.video_folder, exist_ok=True)
+
+ self.name_prefix = name_prefix
+ self.step_id = 0
+ self.video_length = video_length
+
+ self.recording = False
+ self.recorded_frames = 0
+
+ def reset(self) -> VecEnvObs:
+ obs = self.venv.reset()
+ self.start_video_recorder()
+ return obs
+
+ def start_video_recorder(self) -> None:
+ self.close_video_recorder()
+
+ video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}"
+ base_path = os.path.join(self.video_folder, video_name)
+ self.video_recorder = video_recorder.VideoRecorder(
+ env=self.env, base_path=base_path, metadata={"step_id": self.step_id}
+ )
+
+ self.video_recorder.capture_frame()
+ self.recorded_frames = 1
+ self.recording = True
+
+ def _video_enabled(self) -> bool:
+ return self.record_video_trigger(self.step_id)
+
+ def step_wait(self) -> VecEnvStepReturn:
+ obs, rews, dones, infos = self.venv.step_wait()
+
+ self.step_id += 1
+ if self.recording:
+ self.video_recorder.capture_frame()
+ self.recorded_frames += 1
+ if self.recorded_frames > self.video_length:
+ print(f"Saving video to {self.video_recorder.path}")
+ self.close_video_recorder()
+ elif self._video_enabled():
+ self.start_video_recorder()
+
+ return obs, rews, dones, infos
+
+ def close_video_recorder(self) -> None:
+ if self.recording:
+ self.video_recorder.close()
+ self.recording = False
+ self.recorded_frames = 1
+
+ def close(self) -> None:
+ VecEnvWrapper.close(self)
+ self.close_video_recorder()
+
+ def __del__(self):
+ self.close()
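
A usage sketch for `VecVideoRecorder` (environment id, trigger, and paths illustrative; ffmpeg must be installed, as the docstring notes). Here the recorder captures 500 frames starting at every 10,000th step:

```python
import gym

from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder

venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
venv = VecVideoRecorder(
    venv,
    video_folder="videos/",
    record_video_trigger=lambda step: step % 10_000 == 0,
    video_length=500,
)
obs = venv.reset()  # starts the first recording
```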
diff --git a/dexart-release/stable_baselines3/networks/pretrain_nets.py b/dexart-release/stable_baselines3/networks/pretrain_nets.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff286a3ea19b0e2f96ba9404bc16062bb404c8fb
--- /dev/null
+++ b/dexart-release/stable_baselines3/networks/pretrain_nets.py
@@ -0,0 +1,118 @@
+import torch
+import torch.nn as nn
+
+
+class PointNet(nn.Module): # actually pointnet
+ def __init__(self, point_channel=3, output_dim=256):
+ # NOTE: we require the output dim to be 256, in order to match the pretrained weights
+ super(PointNet, self).__init__()
+
+ print('PointNetSmall')
+
+ in_channel = point_channel
+ mlp_out_dim = 256
+ self.local_mlp = nn.Sequential(
+ nn.Linear(in_channel, 64),
+ nn.GELU(),
+ nn.Linear(64, mlp_out_dim),
+ )
+ self.reset_parameters_()
+
+ def reset_parameters_(self):
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def forward(self, x):
+ '''
+ x: [B, N, 3]
+ '''
+ # pc = x[0].cpu().detach().numpy()
+ # Local
+ x = self.local_mlp(x)
+ # global max pooling
+ x = torch.max(x, dim=1)[0]
+ return x
+
+
+class PointNetMedium(nn.Module): # actually pointnet
+ def __init__(self, point_channel=3, output_dim=256):
+ # NOTE: we require the output dim to be 256, in order to match the pretrained weights
+ super(PointNetMedium, self).__init__()
+
+ print('PointNetMedium')
+
+ in_channel = point_channel
+ mlp_out_dim = 256
+ self.local_mlp = nn.Sequential(
+ nn.Linear(in_channel, 64),
+ nn.GELU(),
+ nn.Linear(64, 64),
+ nn.GELU(),
+ nn.Linear(64, 128),
+ nn.GELU(),
+ nn.Linear(128, mlp_out_dim),
+ )
+ self.reset_parameters_()
+
+ def reset_parameters_(self):
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def forward(self, x):
+ '''
+ x: [B, N, 3]
+ '''
+ # Local
+ x = self.local_mlp(x)
+ # global max pooling
+ x = torch.max(x, dim=1)[0]
+ return x
+
+
+class PointNetLarge(nn.Module): # actually pointnet
+ def __init__(self, point_channel=3, output_dim=256):
+ # NOTE: we require the output dim to be 256, in order to match the pretrained weights
+ super(PointNetLarge, self).__init__()
+
+ print('PointNetLarge')
+
+ in_channel = point_channel
+ mlp_out_dim = 256
+ self.local_mlp = nn.Sequential(
+ nn.Linear(in_channel, 64),
+ nn.GELU(),
+ nn.Linear(64, 64),
+ nn.GELU(),
+ nn.Linear(64, 128),
+ nn.GELU(),
+ nn.Linear(128, 128),
+ nn.GELU(),
+ nn.Linear(128, 256),
+ nn.GELU(),
+ nn.Linear(256, mlp_out_dim),
+ )
+
+ self.reset_parameters_()
+
+ def reset_parameters_(self):
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def forward(self, x):
+ '''
+ x: [B, N, 3]
+ '''
+ # Local
+ x = self.local_mlp(x)
+ # global max pooling
+ x = torch.max(x, dim=1)[0]
+ return x
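
A quick shape check for the three variants defined above (a sketch, assuming this file's import path): each maps a batch of point clouds [B, N, 3] to a 256-dimensional global feature via a per-point MLP followed by max pooling:

```python
import torch

from stable_baselines3.networks.pretrain_nets import PointNet, PointNetLarge, PointNetMedium

points = torch.randn(4, 1024, 3)  # B=4 clouds of N=1024 xyz points
for net_cls in (PointNet, PointNetMedium, PointNetLarge):
    features = net_cls()(points)
    print(features.shape)  # torch.Size([4, 256])
```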
diff --git a/dexart-release/stable_baselines3/ppo/__init__.py b/dexart-release/stable_baselines3/ppo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5c23fc9c6d2f327befeb1b1b2df66847f3bff62
--- /dev/null
+++ b/dexart-release/stable_baselines3/ppo/__init__.py
@@ -0,0 +1,2 @@
+from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
+from stable_baselines3.ppo.ppo import PPO
diff --git a/dexart-release/stable_baselines3/ppo/policies.py b/dexart-release/stable_baselines3/ppo/policies.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb7afaef13be27ca853b7ef4087864f469a1e3dc
--- /dev/null
+++ b/dexart-release/stable_baselines3/ppo/policies.py
@@ -0,0 +1,7 @@
+# This file is here just to define MlpPolicy/CnnPolicy
+# that work for PPO
+from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
+
+MlpPolicy = ActorCriticPolicy
+CnnPolicy = ActorCriticCnnPolicy
+MultiInputPolicy = MultiInputActorCriticPolicy
diff --git a/dexart-release/stable_baselines3/ppo/ppo.py b/dexart-release/stable_baselines3/ppo/ppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fac910ccf3eba684015ec379a844fb4ffc3df09
--- /dev/null
+++ b/dexart-release/stable_baselines3/ppo/ppo.py
@@ -0,0 +1,358 @@
+import warnings
+from typing import Any, Dict, Optional, Type, Union
+
+import numpy as np
+import torch as th
+from gym import spaces
+from torch.nn import functional as F
+
+from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
+from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, \
+ MultiInputActorCriticPolicy
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
+from stable_baselines3.common.utils import explained_variance, get_schedule_fn, update_learning_rate
+
+
+class AdaptiveScheduler:
+ def __init__(self, kl_threshold, min_lr, max_lr, init_lr):
+ super().__init__()
+ self.min_lr = min_lr
+ self.max_lr = max_lr
+ self.kl_threshold = kl_threshold
+ self.current_lr = init_lr
+
+ def update(self, kl_dist):
+ lr = self.current_lr
+ if kl_dist > (2.0 * self.kl_threshold):
+ lr = max(self.current_lr / 1.5, self.min_lr)
+ if kl_dist < (0.5 * self.kl_threshold):
+ lr = min(self.current_lr * 1.5, self.max_lr)
+ self.current_lr = lr
+ return lr
+
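+# Illustrative note on the scheduler above (comment-only sketch, numbers assumed):
+# with kl_threshold=0.02, an observed KL of 0.05 (> 2 * threshold) shrinks the
+# learning rate by a factor of 1.5 (floored at min_lr), a KL of 0.005
+# (< 0.5 * threshold) grows it by 1.5 (capped at max_lr), and anything in
+# between leaves the learning rate unchanged.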
+
+class PPO(OnPolicyAlgorithm):
+ """
+ Proximal Policy Optimization algorithm (PPO) (clip version)
+
+ Paper: https://arxiv.org/abs/1707.06347
+ Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/)
+ https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and
+ Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)
+
+ Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
+
+ :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
+ :param env: The environment to learn from (if registered in Gym, can be str)
+ :param learning_rate: The learning rate, it can be a function
+ of the current progress remaining (from 1 to 0)
+ :param n_steps: The number of steps to run for each environment per update
+ (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel)
+ NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization)
+ See https://github.com/pytorch/pytorch/issues/29372
+ :param batch_size: Minibatch size
+ :param n_epochs: Number of epochs when optimizing the surrogate loss
+ :param gamma: Discount factor
+ :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
+ :param clip_range: Clipping parameter, it can be a function of the current progress
+ remaining (from 1 to 0).
+ :param clip_range_vf: Clipping parameter for the value function,
+ it can be a function of the current progress remaining (from 1 to 0).
+ This is a parameter specific to the OpenAI implementation. If None is passed (default),
+ no clipping will be done on the value function.
+ IMPORTANT: this clipping depends on the reward scaling.
+ :param normalize_advantage: Whether to normalize or not the advantage
+ :param ent_coef: Entropy coefficient for the loss calculation
+ :param vf_coef: Value function coefficient for the loss calculation
+ :param max_grad_norm: The maximum value for the gradient clipping
+ :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
+ instead of action noise exploration (default: False)
+ :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
+ Default: -1 (only sample at the beginning of the rollout)
+ :param target_kl: Limit the KL divergence between updates,
+ because the clipping is not enough to prevent large update
+ see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
+ By default, there is no limit on the kl div.
+ :param tensorboard_log: the log location for tensorboard (if None, no logging)
+ :param create_eval_env: Whether to create a second environment that will be
+ used for evaluating the agent periodically. (Only available when passing string for the environment)
+ :param policy_kwargs: additional arguments to be passed to the policy on creation
+ :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
+ :param seed: Seed for the pseudo random generators
+ :param device: Device (cpu, cuda, ...) on which the code should be run.
+ Setting it to auto, the code will be run on the GPU if possible.
+ :param _init_setup_model: Whether or not to build the network at the creation of the instance
+ """
+
+ policy_aliases: Dict[str, Type[BasePolicy]] = {
+ "MlpPolicy": ActorCriticPolicy,
+ "PointCloudPolicy": ActorCriticPolicy,
+ "CnnPolicy": ActorCriticCnnPolicy,
+ "MultiInputPolicy": MultiInputActorCriticPolicy,
+ }
+
+ def __init__(
+ self,
+ policy: Union[str, Type[ActorCriticPolicy]],
+ env: Union[GymEnv, str],
+ learning_rate: Union[float, Schedule] = 3e-4,
+ n_steps: int = 2048,
+ batch_size: int = 64,
+ n_epochs: int = 10,
+ gamma: float = 0.99,
+ gae_lambda: float = 0.95,
+ clip_range: Union[float, Schedule] = 0.2,
+ clip_range_vf: Union[None, float, Schedule] = None,
+ normalize_advantage: bool = True,
+ ent_coef: float = 0.0,
+ vf_coef: float = 0.5,
+ max_grad_norm: float = 0.5,
+ use_sde: bool = False,
+ sde_sample_freq: int = -1,
+ target_kl: Optional[float] = None,
+ tensorboard_log: Optional[str] = None,
+ create_eval_env: bool = False,
+ policy_kwargs: Optional[Dict[str, Any]] = None,
+ verbose: int = 0,
+ seed: Optional[int] = None,
+ device: Union[th.device, str] = "auto",
+ _init_setup_model: bool = True,
+ adaptive_kl: float = 0.02,
+ min_lr=1e-4,
+ max_lr=1e-3,
+ ):
+
+ super().__init__(
+ policy,
+ env,
+ learning_rate=learning_rate,
+ n_steps=n_steps,
+ gamma=gamma,
+ gae_lambda=gae_lambda,
+ ent_coef=ent_coef,
+ vf_coef=vf_coef,
+ max_grad_norm=max_grad_norm,
+ use_sde=use_sde,
+ sde_sample_freq=sde_sample_freq,
+ tensorboard_log=tensorboard_log,
+ policy_kwargs=policy_kwargs,
+ verbose=verbose,
+ device=device,
+ create_eval_env=create_eval_env,
+ seed=seed,
+ _init_setup_model=False,
+ supported_action_spaces=(
+ spaces.Box,
+ spaces.Discrete,
+ spaces.MultiDiscrete,
+ spaces.MultiBinary,
+ ),
+ )
+
+ # Sanity check, otherwise it will lead to noisy gradient and NaN
+ # because of the advantage normalization
+ if normalize_advantage:
+ assert (
+ batch_size > 1
+ ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"
+
+ if self.env is not None:
+ # Check that `n_steps * n_envs > 1` to avoid NaN
+ # when doing advantage normalization
+ buffer_size = self.env.num_envs * self.n_steps
+ assert (
+ buffer_size > 1
+ ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
+ # Check that the rollout buffer size is a multiple of the mini-batch size
+ untruncated_batches = buffer_size // batch_size
+ if buffer_size % batch_size > 0:
+ warnings.warn(
+ f"You have specified a mini-batch size of {batch_size},"
+ f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
+ f" after every {untruncated_batches} untruncated mini-batches,"
+ f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
+ f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
+ f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
+ )
+ self.batch_size = batch_size
+ self.n_epochs = n_epochs
+ self.clip_range = clip_range
+ self.clip_range_vf = clip_range_vf
+ self.normalize_advantage = normalize_advantage
+ self.target_kl = target_kl
+ self.kl_scheduler = AdaptiveScheduler(kl_threshold=adaptive_kl, min_lr=min_lr, max_lr=max_lr,
+ init_lr=learning_rate)
+
+ if _init_setup_model:
+ self._setup_model()
+
+ def _setup_model(self) -> None:
+ super()._setup_model()
+
+ # Initialize schedules for policy/value clipping
+ self.clip_range = get_schedule_fn(self.clip_range)
+ if self.clip_range_vf is not None:
+ if isinstance(self.clip_range_vf, (float, int)):
+ assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, pass `None` to deactivate vf clipping"
+
+ self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
+
+ def train(self) -> None:
+ """
+ Update policy using the currently gathered rollout buffer.
+ """
+ # Switch to train mode (this affects batch norm / dropout)
+ self.policy.set_training_mode(True)
+ # Log the current learning rate; the optimizer itself is updated per
+ # minibatch by the adaptive KL scheduler inside the training loop below
+ self.logger.record("train/learning_rate", self.kl_scheduler.current_lr)
+ # Compute current clip range
+ clip_range = self.clip_range(self._current_progress_remaining)
+ # Optional: clip range for the value function
+ if self.clip_range_vf is not None:
+ clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
+
+ entropy_losses = []
+ pg_losses, value_losses = [], []
+ clip_fractions = []
+
+ continue_training = True
+
+ # train for n_epochs epochs
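+ # Track how many minibatches the KL check skips vs. the total processed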
+ num_early_stopping = 0
+ num_batch_update = 0
+ for epoch in range(self.n_epochs):
+ approx_kl_divs = []
+ # Do a complete pass on the rollout buffer
+ for rollout_data in self.rollout_buffer.get(self.batch_size):
+ num_batch_update += 1
+ actions = rollout_data.actions
+ if isinstance(self.action_space, spaces.Discrete):
+ # Convert discrete action from float to long
+ actions = rollout_data.actions.long().flatten()
+
+ # Re-sample the noise matrix because the log_std has changed
+ if self.use_sde:
+ self.policy.reset_noise(self.batch_size)
+
+ values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
+ values = values.flatten()
+
+ # Calculate approximate form of reverse KL Divergence for early stopping
+ # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
+ # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
+ # and Schulman blog: http://joschu.net/blog/kl-approx.html
+ with th.no_grad():
+ log_ratio = log_prob - rollout_data.old_log_prob
+ approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
+ approx_kl_divs.append(approx_kl_div)
+
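+ # Decaying skip threshold: `scheduling` starts at 10 (a loose bound of
+ # 15x target_kl) and shrinks by 1% per completed training iteration
+ # (`_n_updates` grows by `n_epochs` per `train()` call), floored at 0.5
+ # (i.e. 0.75x target_kl).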
+ n_iter = self._n_updates // self.n_epochs
+ scheduling = max(0.99 ** n_iter, 0.05) * 10
+ if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl * scheduling:
+ # Unlike stock PPO, which would stop the whole update here
+ # (`continue_training = False`), only skip this minibatch
+ num_early_stopping += 1
+ continue
+
+ # Normalize advantage
+ advantages = rollout_data.advantages
+ if self.normalize_advantage:
+ advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+
+ # ratio between old and new policy, should be one at the first iteration
+ ratio = th.exp(log_prob - rollout_data.old_log_prob)
+
+ # clipped surrogate loss
+ policy_loss_1 = advantages * ratio
+ policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
+ policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
+
+ # Logging
+ pg_losses.append(policy_loss.item())
+ clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
+ clip_fractions.append(clip_fraction)
+
+ if self.clip_range_vf is None:
+ # No clipping
+ values_pred = values
+ else:
+ # Clip the difference between old and new value
+ # NOTE: this depends on the reward scaling
+ values_pred = rollout_data.old_values + th.clamp(
+ values - rollout_data.old_values, -clip_range_vf, clip_range_vf
+ )
+ # Value loss using the TD(gae_lambda) target
+ value_loss = F.mse_loss(rollout_data.returns, values_pred)
+ value_losses.append(value_loss.item())
+
+ # Entropy loss favors exploration
+ if entropy is None:
+ # Approximate entropy when no analytical form
+ entropy_loss = -th.mean(-log_prob)
+ else:
+ entropy_loss = -th.mean(entropy)
+
+ entropy_losses.append(entropy_loss.item())
+
+ loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
+
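+ # Adapt the learning rate from this minibatch's KL estimate before stepping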
+ lr = self.kl_scheduler.update(approx_kl_div)
+ update_learning_rate(self.policy.optimizer, lr)
+
+ # Optimization step
+ self.policy.optimizer.zero_grad()
+ loss.backward()
+ # Clip grad norm
+ th.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, self.policy.parameters()), self.max_grad_norm)
+ self.policy.optimizer.step()
+
+ if not continue_training:
+ break
+ if self.verbose >= 1:
+ print(f"Early stopping / mini batch update: {num_early_stopping} / {num_batch_update}")
+ self._n_updates += self.n_epochs
+ explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
+ # Logs
+ self.logger.record("train/entropy_loss", np.mean(entropy_losses))
+ self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
+ self.logger.record("train/value_loss", np.mean(value_losses))
+ self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
+ self.logger.record("train/clip_fraction", np.mean(clip_fractions))
+ self.logger.record("train/explained_variance", explained_var)
+ self.logger.record("rollout/skipped_minibatch", num_early_stopping / num_batch_update)
+ if hasattr(self.policy, "log_std"):
+ self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
+
+ self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
+ self.logger.record("train/clip_range", clip_range)
+ if self.clip_range_vf is not None:
+ self.logger.record("train/clip_range_vf", clip_range_vf)
+
+ def learn(
+ self,
+ total_timesteps: int,
+ callback: MaybeCallback = None,
+ log_interval: int = 1,
+ eval_env: Optional[GymEnv] = None,
+ eval_freq: int = -1,
+ n_eval_episodes: int = 5,
+ tb_log_name: str = "PPO",
+ eval_log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+ iter_start: int = 0,
+ **kwargs
+ ) -> "PPO":
+
+ return super().learn(
+ total_timesteps=total_timesteps,
+ callback=callback,
+ log_interval=log_interval,
+ eval_env=eval_env,
+ eval_freq=eval_freq,
+ n_eval_episodes=n_eval_episodes,
+ tb_log_name=tb_log_name,
+ eval_log_path=eval_log_path,
+ reset_num_timesteps=reset_num_timesteps,
+ iter_start=iter_start,
+ )