diff --git a/.gitattributes b/.gitattributes index 9ca4c9045e8b6ec32c3ba6519c9b01976ec0d99f..f67edbf743e0c002d40727db1eae7818ae9278f4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -43,3 +43,5 @@ Metaworld/zarr_path:[[:space:]]data/metaworld_disassemble_expert.zarr/data/point Metaworld/zarr_path:[[:space:]]/data/haojun/datasets/3d-dp/metaworld_hammer_expert.zarr/data/depth/0.0.0 filter=lfs diff=lfs merge=lfs -text Metaworld/zarr_path:[[:space:]]data/metaworld_door-lock_expert.zarr/data/point_cloud/16.0.0 filter=lfs diff=lfs merge=lfs -text Metaworld/zarr_path:[[:space:]]/data/haojun/datasets/3d-dp/metaworld_drawer-open_expert.zarr/data/point_cloud/5.0.0 filter=lfs diff=lfs merge=lfs -text +Metaworld/zarr_path:[[:space:]]data/metaworld_door-close_expert.zarr/data/point_cloud/12.0.0 filter=lfs diff=lfs merge=lfs -text +Metaworld/zarr_path:[[:space:]]data/metaworld_door-lock_expert.zarr/data/point_cloud/7.0.0 filter=lfs diff=lfs merge=lfs -text diff --git a/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_dial.xml b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_dial.xml new file mode 100644 index 0000000000000000000000000000000000000000..3b2d3dc3aa3e52aca2882f7b213dd3753732e15e --- /dev/null +++ b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_dial.xml @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + --> + + + + + + + + + diff --git a/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_lock.xml b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_lock.xml new file mode 100644 index 0000000000000000000000000000000000000000..baa86c3eb73294751e4efdf803ab280399cbedf2 --- /dev/null +++ b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_lock.xml @@ -0,0 +1,26 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_pull.xml b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_pull.xml new file mode 100644 index 0000000000000000000000000000000000000000..8c026691e0346e844de4aca455741d075e8e287e --- /dev/null +++ b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_door_pull.xml @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_faucet.xml b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_faucet.xml new file mode 100644 index 0000000000000000000000000000000000000000..a4c70d22783957612024bbe4e694618a04d28b47 --- /dev/null +++ b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_faucet.xml @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_laptop.xml b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_laptop.xml new file mode 100644 index 0000000000000000000000000000000000000000..4c12990f7bffeaf40b50e45aa90efca67a0ba0db --- /dev/null +++ b/Metaworld/metaworld/envs/assets_v2/sawyer_xyz/sawyer_laptop.xml @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/Metaworld/zarr_path: data/metaworld_door-close_expert.zarr/data/point_cloud/12.0.0 b/Metaworld/zarr_path: data/metaworld_door-close_expert.zarr/data/point_cloud/12.0.0 new file mode 100644 index 0000000000000000000000000000000000000000..a36a25cd9059075eb53d5862b6557fbfca81a929 --- /dev/null +++ b/Metaworld/zarr_path: data/metaworld_door-close_expert.zarr/data/point_cloud/12.0.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:902da15cc5607a827c4a6cf9b7c396bacd2bd244963aa4e36f400f543b23497b +size 1231019 diff --git 
a/Metaworld/zarr_path: data/metaworld_door-lock_expert.zarr/data/point_cloud/7.0.0 b/Metaworld/zarr_path: data/metaworld_door-lock_expert.zarr/data/point_cloud/7.0.0 new file mode 100644 index 0000000000000000000000000000000000000000..145c8b9a6b9a5dbae2337d32a8762931535418d6 --- /dev/null +++ b/Metaworld/zarr_path: data/metaworld_door-lock_expert.zarr/data/point_cloud/7.0.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:32e76433f75e66cccd7b4db7962d3fc1ae4fe8915409efc427235456347323c4 +size 1213234 diff --git a/dexart-release/assets/sapien/102697/cues.txt b/dexart-release/assets/sapien/102697/cues.txt new file mode 100644 index 0000000000000000000000000000000000000000..fc91228deb45f4f0076d8ee2a82dbf3490ff8e4c --- /dev/null +++ b/dexart-release/assets/sapien/102697/cues.txt @@ -0,0 +1,5 @@ +link_0 hinge lever button +link_1 slider pump_lid lid +link_2 hinge lid lid +link_3 hinge seat seat +link_4 static base_body base_body diff --git a/dexart-release/assets/sapien/102697/meta.json b/dexart-release/assets/sapien/102697/meta.json new file mode 100644 index 0000000000000000000000000000000000000000..73d99eeda0b2619fb673d5d2a2fac81a7ab6066b --- /dev/null +++ b/dexart-release/assets/sapien/102697/meta.json @@ -0,0 +1 @@ +{"user_id": "haozhu", "model_cat": "Toilet", "model_id": "db252ecd6286a334733badcb2e574996-0", "version": "1", "anno_id": "2697", "time_in_sec": "31"} \ No newline at end of file diff --git a/dexart-release/assets/sapien/102697/mobility.urdf b/dexart-release/assets/sapien/102697/mobility.urdf new file mode 100644 index 0000000000000000000000000000000000000000..66322627aea3fc2cef1f786c1aff0a34a0b14508 --- /dev/null +++ b/dexart-release/assets/sapien/102697/mobility.urdf @@ -0,0 +1,502 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/dexart-release/assets/sapien/102697/mobility_v2.json b/dexart-release/assets/sapien/102697/mobility_v2.json new file mode 100644 index 0000000000000000000000000000000000000000..425a87c4126d0dc8c85dca23d592be3dae58ba10 --- /dev/null +++ b/dexart-release/assets/sapien/102697/mobility_v2.json @@ -0,0 +1 @@ 
+[{"id":0,"parent":4,"joint":"hinge","name":"lever","parts":[{"id":1,"name":"button"}],"jointData":{"axis":{"origin":[-0.387179130380233,0.5229989412652956,-0.2586584187924479],"direction":[-0.91844856752952,-0.00006467369864908537,0.39554042096893877]},"limit":{"a":0,"b":30,"noLimit":false}}},{"id":1,"parent":4,"joint":"slider","name":"pump_lid","parts":[{"id":2,"name":"lid"}],"jointData":{"axis":{"origin":[0,0,0],"direction":[0,1,0]},"limit":{"a":0,"b":0.02,"noLimit":false,"rotates":false,"noRotationLimit":false,"rotationLimit":0}}},{"id":2,"parent":4,"joint":"hinge","name":"lid","parts":[{"id":3,"name":"lid"}],"jointData":{"axis":{"origin":[0,0.05794601786221419,-0.06205458525801076],"direction":[1,0,0]},"limit":{"a":0,"b":-94,"noLimit":false}}},{"id":3,"parent":4,"joint":"hinge","name":"seat","parts":[{"id":4,"name":"seat"}],"jointData":{"axis":{"origin":[0,0.05794601786221419,-0.06205458525801076],"direction":[1,0,0]},"limit":{"a":0,"b":-94,"noLimit":false}}},{"id":4,"parent":-1,"joint":"static","name":"base_body","parts":[{"id":5,"name":"base_body"}],"jointData":{}}] \ No newline at end of file diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_1_12.mtl b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_12.mtl new file mode 100644 index 0000000000000000000000000000000000000000..0a8d7de97ba5b99f083fcf82402f12b2d42e820b --- /dev/null +++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_12.mtl @@ -0,0 +1,12 @@ +# Blender MTL File: 'None' +# Material Count: 1 + +newmtl Shape.14067 +Ns 400.000000 +Ka 0.400000 0.400000 0.400000 +Kd 0.336000 0.232000 0.584000 +Ks 0.250000 0.250000 0.250000 +Ke 0.000000 0.000000 0.000000 +Ni 1.000000 +d 0.500000 +illum 2 diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.mtl b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.mtl new file mode 100644 index 0000000000000000000000000000000000000000..9b787bc663cb512a8d4bca7d408e8e384c3b8619 --- /dev/null +++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.mtl @@ -0,0 +1,12 @@ +# Blender MTL File: 'None' +# Material Count: 1 + +newmtl Shape.14069 +Ns 400.000000 +Ka 0.400000 0.400000 0.400000 +Kd 0.296000 0.784000 0.192000 +Ks 0.250000 0.250000 0.250000 +Ke 0.000000 0.000000 0.000000 +Ni 1.000000 +d 0.500000 +illum 2 diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.obj b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.obj new file mode 100644 index 0000000000000000000000000000000000000000..5815a3adc15fdcd79d4d8e4f3f64b8d93d911c3a --- /dev/null +++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_14.obj @@ -0,0 +1,109 @@ +# Blender v2.79 (sub 0) OBJ File: '' +# www.blender.org +mtllib 102697_link_1_14.mtl +o Shape_IndexedFaceSet.014_Shape_IndexedFaceSet.14061 +v 0.381286 0.716318 -0.278277 +v 0.467366 0.669121 -0.458755 +v 0.467366 0.666342 -0.458755 +v 0.314641 0.669121 -0.486500 +v 0.314641 0.666342 -0.275488 +v 0.314641 0.716318 -0.486500 +v 0.442356 0.716318 -0.486520 +v 0.406252 0.666342 -0.275488 +v 0.314641 0.716318 -0.275488 +v 0.453478 0.702433 -0.433757 +v 0.464367 0.666416 -0.486207 +v 0.397925 0.702433 -0.275488 +v 0.386833 0.666342 -0.486500 +v 0.447917 0.713539 -0.453197 +v 0.433456 0.670746 -0.350149 +v 0.411813 0.705207 -0.311600 +v 0.453478 0.710760 -0.486520 +v 0.459025 0.702433 -0.455986 +v 0.397925 0.713539 -0.300506 +v 0.409033 0.674673 -0.283855 +v 0.430634 0.693883 -0.355457 +v 0.461805 0.669121 -0.436525 +v 0.447917 0.710760 -0.439314 +vn 0.6201 
0.7564 0.2082 +vn -1.0000 0.0000 0.0000 +vn 0.0000 1.0000 0.0000 +vn -0.0002 0.0000 -1.0000 +vn 0.0000 -1.0000 0.0000 +vn 0.0000 0.0000 1.0000 +vn 0.9941 0.0000 -0.1086 +vn 0.0406 0.2433 0.9691 +vn -0.0385 -0.9992 -0.0132 +vn 0.0010 -1.0000 -0.0028 +vn 0.1781 0.9826 0.0522 +vn -0.0001 -0.0002 -1.0000 +vn 0.9633 0.2356 -0.1285 +vn -0.0000 -0.0004 -1.0000 +vn 0.0038 -0.0061 -1.0000 +vn 0.4470 0.8945 -0.0000 +vn 0.7133 0.6982 0.0608 +vn 0.9470 0.2175 0.2363 +vn 0.9624 0.2498 -0.1067 +vn 0.5705 0.7506 0.3332 +vn 0.2834 0.9545 0.0928 +vn 0.6574 0.6887 0.3057 +vn 0.8546 0.1972 0.4804 +vn 0.9385 0.0321 0.3438 +vn 0.8973 0.2493 0.3642 +vn 0.9155 0.2217 0.3357 +vn 0.8943 0.3344 0.2974 +vn 0.9251 0.1885 0.3296 +vn 0.9350 0.1837 0.3034 +vn 0.9701 0.0000 0.2427 +vn 0.8018 -0.5344 0.2674 +vn 0.9470 0.2170 0.2369 +vn 0.8672 -0.4031 0.2922 +vn 0.9325 0.2086 0.2948 +vn 0.7291 0.6431 0.2341 +vn 0.7174 0.6831 0.1368 +vn 0.7541 0.6292 0.1882 +vn 0.5142 0.8410 0.1683 +usemtl Shape.14069 +s off +f 19//1 16//1 23//1 +f 4//2 5//2 6//2 +f 6//3 1//3 7//3 +f 4//4 6//4 7//4 +f 5//5 3//5 8//5 +f 5//6 8//6 9//6 +f 1//3 6//3 9//3 +f 6//2 5//2 9//2 +f 2//7 3//7 11//7 +f 9//6 8//6 12//6 +f 1//8 9//8 12//8 +f 5//9 4//9 13//9 +f 3//5 5//5 13//5 +f 11//10 3//10 13//10 +f 7//11 1//11 14//11 +f 4//12 7//12 17//12 +f 2//13 11//13 17//13 +f 13//14 4//14 17//14 +f 11//15 13//15 17//15 +f 7//16 14//16 17//16 +f 17//17 14//17 18//17 +f 10//18 2//18 18//18 +f 2//19 17//19 18//19 +f 1//20 12//20 19//20 +f 14//21 1//21 19//21 +f 12//22 16//22 19//22 +f 12//23 8//23 20//23 +f 8//24 15//24 20//24 +f 16//25 12//25 20//25 +f 16//26 20//26 21//26 +f 10//27 16//27 21//27 +f 20//28 15//28 21//28 +f 21//29 15//29 22//29 +f 3//30 2//30 22//30 +f 8//31 3//31 22//31 +f 2//32 10//32 22//32 +f 15//33 8//33 22//33 +f 10//34 21//34 22//34 +f 16//35 10//35 23//35 +f 18//36 14//36 23//36 +f 10//37 18//37 23//37 +f 14//38 19//38 23//38 diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_1_5.mtl b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_5.mtl new file mode 100644 index 0000000000000000000000000000000000000000..1d543567c45ae9dae1065722296e1ef27ad60a92 --- /dev/null +++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_1_5.mtl @@ -0,0 +1,12 @@ +# Blender MTL File: 'None' +# Material Count: 1 + +newmtl Shape.14060 +Ns 400.000000 +Ka 0.400000 0.400000 0.400000 +Kd 0.576000 0.288000 0.088000 +Ks 0.250000 0.250000 0.250000 +Ke 0.000000 0.000000 0.000000 +Ni 1.000000 +d 0.500000 +illum 2 diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_3_0.obj b/dexart-release/assets/sapien/102697/new_objs/102697_link_3_0.obj new file mode 100644 index 0000000000000000000000000000000000000000..02ce7be22a24a10eeb20da679572050bce75f8d6 --- /dev/null +++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_3_0.obj @@ -0,0 +1,217 @@ +# Blender v2.79 (sub 0) OBJ File: '' +# www.blender.org +mtllib 102697_link_3_0.mtl +o Shape_IndexedFaceSet_Shape_IndexedFaceSet.16309 +v -0.318776 0.060655 0.107693 +v -0.063870 0.047615 -0.076335 +v -0.006959 0.042434 0.076198 +v -0.006959 0.088979 -0.019441 +v -0.309458 0.088979 0.114981 +v -0.182789 0.086387 -0.042715 +v -0.257765 0.042434 0.006414 +v -0.312052 0.042434 0.114981 +v -0.167256 0.073461 0.114981 +v -0.006959 0.042434 -0.068570 +v -0.006959 0.073461 0.060669 +v -0.006959 0.088979 -0.065989 +v -0.262922 0.086387 0.008996 +v -0.151785 0.042434 0.114981 +v -0.154809 0.063235 -0.062488 +v -0.255171 0.088979 0.114981 +v -0.265516 0.065707 -0.001350 +v -0.058683 
0.083799 -0.073734 +v -0.155045 0.046362 -0.055734 +v -0.006959 0.055369 0.078780 +v -0.317239 0.052781 0.091707 +v -0.309458 0.083799 0.086544 +v -0.069027 0.088979 -0.065989 +v -0.006959 0.070873 -0.078916 +v -0.159536 0.083799 -0.058243 +v -0.107812 0.042434 -0.063407 +v -0.257765 0.045027 -0.001350 +v -0.159811 0.051020 -0.059069 +v -0.306895 0.088979 0.096890 +v -0.149191 0.060536 0.114981 +v -0.056120 0.078632 -0.076335 +v -0.314646 0.045027 0.096890 +v -0.006959 0.047615 -0.076335 +v -0.319802 0.076049 0.114981 +v -0.260328 0.078632 -0.001350 +v -0.312052 0.068285 0.081361 +v -0.250915 0.087266 0.009661 +v -0.006959 0.070873 0.065852 +v -0.304301 0.050193 0.068434 +v -0.268110 0.070873 0.003833 +v -0.006959 0.060536 0.076198 +v -0.006959 0.086387 -0.071152 +v -0.006959 0.078632 0.034833 +v -0.167107 0.053534 -0.056233 +v -0.162264 0.065707 -0.059140 +v -0.317239 0.045027 0.114981 +v -0.198290 0.081220 -0.037551 +v -0.247420 0.042434 -0.001350 +vn -0.2189 -0.8734 -0.4350 +vn 0.0000 -1.0000 0.0000 +vn 0.0000 0.0000 1.0000 +vn 1.0000 0.0000 0.0000 +vn -0.0000 1.0000 0.0000 +vn 0.2540 -0.1893 0.9485 +vn -0.0785 0.9576 -0.2772 +vn -0.0984 0.7614 -0.6407 +vn -0.0223 -0.8995 -0.4364 +vn -0.1804 -0.1959 -0.9639 +vn -0.1633 -0.5893 -0.7912 +vn -0.1721 -0.6838 -0.7091 +vn -0.3065 -0.7412 -0.5972 +vn -0.5117 0.8123 -0.2799 +vn 0.2435 0.3404 0.9082 +vn 0.2453 -0.0352 0.9688 +vn -0.1444 0.0361 -0.9889 +vn -0.0504 0.0126 -0.9986 +vn -0.1621 0.1636 -0.9731 +vn -0.1399 0.3890 -0.9106 +vn -0.2162 -0.9704 -0.1081 +vn -0.4800 -0.8321 -0.2779 +vn 0.0000 -0.8318 -0.5550 +vn 0.0000 -0.1103 -0.9939 +vn -0.9962 -0.0274 -0.0823 +vn -0.7761 0.6209 -0.1100 +vn -0.7791 0.6162 -0.1155 +vn -0.3952 0.6848 -0.6123 +vn -0.9554 0.1496 -0.2548 +vn -0.9305 0.2462 -0.2713 +vn -0.0669 0.9924 -0.1037 +vn -0.0356 0.9974 -0.0631 +vn -0.0693 0.9955 -0.0640 +vn -0.0240 0.9991 -0.0350 +vn 0.1545 0.8754 0.4580 +vn 0.1519 0.8843 0.4415 +vn -0.8035 -0.3012 -0.5135 +vn -0.7750 -0.5093 -0.3742 +vn -0.6037 -0.7165 -0.3495 +vn -0.8716 -0.0235 -0.4897 +vn -0.8750 -0.0297 -0.4832 +vn -0.7861 0.4152 -0.4579 +vn -0.7546 0.4201 -0.5041 +vn -0.7121 0.2858 -0.6413 +vn -0.8694 0.0560 -0.4909 +vn -0.8355 0.2946 -0.4637 +vn 0.2448 0.4334 0.8673 +vn 0.2223 0.6897 0.6891 +vn 0.0000 0.8937 -0.4487 +vn -0.0128 0.8231 -0.5677 +vn 0.0237 0.4474 -0.8940 +vn 0.0214 0.4580 -0.8887 +vn 0.1009 0.9773 0.1863 +vn 0.1037 0.9753 0.1952 +vn -0.4960 -0.1859 -0.8482 +vn -0.3894 -0.0969 -0.9159 +vn -0.4480 -0.3961 -0.8015 +vn -0.3790 0.1027 -0.9197 +vn -0.4224 -0.0481 -0.9051 +vn -0.4884 -0.0141 -0.8725 +vn -0.9926 -0.1156 -0.0385 +vn -0.4462 -0.8926 -0.0640 +vn -0.9107 -0.3918 -0.1305 +vn -0.9960 -0.0823 0.0336 +vn -0.4229 0.5449 -0.7240 +vn -0.4253 0.5896 -0.6867 +vn -0.5000 0.2007 -0.8425 +vn -0.4868 0.0799 -0.8698 +vn -0.4738 0.1147 -0.8731 +vn -0.1247 -0.9517 -0.2806 +vn -0.2313 -0.9228 -0.3082 +usemtl Shape.16319 +s off +f 27//1 19//1 48//1 +f 7//2 3//2 8//2 +f 5//3 8//3 9//3 +f 4//4 3//4 10//4 +f 3//2 7//2 10//2 +f 3//4 4//4 11//4 +f 5//5 4//5 12//5 +f 4//4 10//4 12//4 +f 8//2 3//2 14//2 +f 9//3 8//3 14//3 +f 4//5 5//5 16//5 +f 5//3 9//3 16//3 +f 3//4 11//4 20//4 +f 14//6 3//6 20//6 +f 5//5 12//5 23//5 +f 12//4 10//4 24//4 +f 6//7 23//7 25//7 +f 23//8 18//8 25//8 +f 2//9 10//9 26//9 +f 10//2 7//2 26//2 +f 15//10 2//10 28//10 +f 2//11 26//11 28//11 +f 26//12 19//12 28//12 +f 19//13 27//13 28//13 +f 13//14 22//14 29//14 +f 5//5 23//5 29//5 +f 9//3 14//3 30//3 +f 20//15 9//15 30//15 +f 14//16 20//16 30//16 +f 2//17 15//17 31//17 +f 24//18 2//18 31//18 +f 
15//19 25//19 31//19 +f 25//20 18//20 31//20 +f 7//21 8//21 32//21 +f 27//22 7//22 32//22 +f 10//23 2//23 33//23 +f 2//24 24//24 33//24 +f 24//4 10//4 33//4 +f 8//3 5//3 34//3 +f 21//25 1//25 34//25 +f 5//26 29//26 34//26 +f 29//27 22//27 34//27 +f 13//28 6//28 35//28 +f 21//29 34//29 36//29 +f 34//30 22//30 36//30 +f 6//31 13//31 37//31 +f 23//32 6//32 37//32 +f 13//33 29//33 37//33 +f 29//34 23//34 37//34 +f 16//35 9//35 38//35 +f 11//36 16//36 38//36 +f 20//4 11//4 38//4 +f 17//37 27//37 39//37 +f 32//38 21//38 39//38 +f 27//39 32//39 39//39 +f 36//40 17//40 39//40 +f 21//41 36//41 39//41 +f 22//42 13//42 40//42 +f 13//43 35//43 40//43 +f 35//44 17//44 40//44 +f 17//45 36//45 40//45 +f 36//46 22//46 40//46 +f 9//47 20//47 41//47 +f 38//48 9//48 41//48 +f 20//4 38//4 41//4 +f 23//49 12//49 42//49 +f 18//50 23//50 42//50 +f 12//4 24//4 42//4 +f 24//51 31//51 42//51 +f 31//52 18//52 42//52 +f 11//4 4//4 43//4 +f 4//53 16//53 43//53 +f 16//54 11//54 43//54 +f 27//55 17//55 44//55 +f 15//56 28//56 44//56 +f 28//57 27//57 44//57 +f 25//58 15//58 45//58 +f 15//59 44//59 45//59 +f 44//60 17//60 45//60 +f 1//61 21//61 46//61 +f 32//62 8//62 46//62 +f 21//63 32//63 46//63 +f 34//64 1//64 46//64 +f 8//3 34//3 46//3 +f 6//65 25//65 47//65 +f 35//66 6//66 47//66 +f 17//67 35//67 47//67 +f 45//68 17//68 47//68 +f 25//69 45//69 47//69 +f 26//2 7//2 48//2 +f 19//70 26//70 48//70 +f 7//71 27//71 48//71 diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_3_5.obj b/dexart-release/assets/sapien/102697/new_objs/102697_link_3_5.obj new file mode 100644 index 0000000000000000000000000000000000000000..074a59b319d2140a08c5e2802a1d71a0ca212920 --- /dev/null +++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_3_5.obj @@ -0,0 +1,250 @@ +# Blender v2.79 (sub 0) OBJ File: '' +# www.blender.org +mtllib 102697_link_3_5.mtl +o Shape_IndexedFaceSet.005_Shape_IndexedFaceSet.16314 +v -0.006959 0.055369 0.663225 +v -0.035413 0.042434 0.789889 +v -0.006985 0.088979 0.766613 +v -0.224145 0.088979 0.575274 +v -0.112985 0.042434 0.575296 +v -0.244834 0.042434 0.624424 +v -0.216380 0.083799 0.676129 +v -0.128515 0.070873 0.575274 +v -0.154338 0.047615 0.743337 +v -0.061288 0.088979 0.779539 +v -0.006959 0.042434 0.792466 +v -0.006985 0.073461 0.683903 +v -0.006959 0.042434 0.668378 +v -0.260364 0.088979 0.590820 +v -0.030254 0.068285 0.800239 +v -0.221565 0.065707 0.678706 +v -0.268103 0.042434 0.575274 +v -0.146599 0.081220 0.745913 +v -0.081951 0.055369 0.784714 +v -0.006985 0.088979 0.789889 +v -0.270683 0.060536 0.598550 +v -0.102641 0.042434 0.764037 +v -0.206061 0.055369 0.696785 +v -0.150775 0.086767 0.728688 +v -0.087137 0.076049 0.779539 +v -0.154338 0.065707 0.745913 +v -0.262944 0.081220 0.603725 +v -0.273262 0.081220 0.575274 +v -0.262944 0.047615 0.606323 +v -0.032834 0.047615 0.797640 +v -0.006985 0.068294 0.800239 +v -0.244834 0.088979 0.619249 +v -0.030254 0.083799 0.795064 +v -0.006985 0.070878 0.676129 +v -0.087137 0.050193 0.782138 +v -0.006985 0.081215 0.722659 +v -0.213800 0.078632 0.683881 +v -0.242254 0.052781 0.645102 +v -0.180186 0.081220 0.575274 +v -0.143993 0.083799 0.745913 +v -0.079372 0.083799 0.779539 +v -0.209848 0.046828 0.679494 +v -0.262944 0.088979 0.575274 +v -0.009565 0.047615 0.660627 +v -0.092296 0.068285 0.779539 +v -0.159523 0.070873 0.740739 +v -0.275868 0.068285 0.583047 +v -0.273262 0.045027 0.575274 +v -0.032834 0.055369 0.800239 +v -0.125910 0.068285 0.575274 +v -0.006985 0.063123 0.668378 +v -0.074186 0.045027 0.782138 +v -0.204443 0.044694 0.679584 
+vn -0.3528 -0.8802 0.3176 +vn 0.0000 -1.0000 -0.0000 +vn -0.0000 1.0000 -0.0000 +vn 1.0000 0.0000 0.0000 +vn 0.0000 0.0000 -1.0000 +vn 1.0000 0.0008 0.0000 +vn 1.0000 0.0006 0.0001 +vn -0.6739 -0.1041 0.7314 +vn -0.8426 0.1891 0.5042 +vn -0.7059 0.6605 0.2560 +vn -0.0513 -0.8223 0.5667 +vn 1.0000 0.0006 0.0013 +vn -0.1903 0.9645 0.1830 +vn -0.0704 0.9943 0.0806 +vn -0.5935 0.7366 0.3242 +vn -0.6160 0.6947 0.3714 +vn -0.0901 0.8767 0.4726 +vn 0.0988 0.4453 0.8899 +vn -0.0001 0.3164 0.9486 +vn 1.0000 0.0023 -0.0008 +vn 0.2531 0.9181 -0.3050 +vn -0.3097 -0.7486 0.5862 +vn -0.4927 -0.1227 0.8615 +vn 1.0000 0.0017 -0.0003 +vn 0.1515 0.9734 -0.1719 +vn 1.0000 0.0019 -0.0004 +vn 0.1659 0.9670 -0.1935 +vn -0.7073 0.1481 0.6912 +vn -0.7933 0.3506 0.4977 +vn -0.8130 0.2852 0.5077 +vn -0.8528 0.0077 0.5221 +vn -0.7975 -0.2017 0.5686 +vn -0.6081 -0.6770 0.4146 +vn -0.8486 -0.2184 0.4819 +vn 0.1702 0.9644 -0.2025 +vn 0.1908 0.9528 -0.2362 +vn -0.2498 0.9330 0.2591 +vn -0.1520 0.9623 0.2257 +vn -0.5596 0.5915 0.5805 +vn -0.5675 0.5734 0.5909 +vn -0.2916 0.2921 0.9108 +vn -0.4189 0.4197 0.8052 +vn -0.2076 0.7248 0.6569 +vn -0.2872 0.3031 0.9086 +vn -0.2434 0.8497 0.4677 +vn -0.4183 0.4227 0.8039 +vn -0.5264 -0.7109 0.4664 +vn -0.5764 -0.6997 0.4220 +vn -0.6036 -0.6544 0.4554 +vn -0.5981 0.7953 0.0993 +vn 0.6344 0.0453 -0.7717 +vn 0.8968 -0.1637 -0.4110 +vn 0.5171 -0.6211 -0.5890 +vn -0.3140 0.1256 0.9411 +vn -0.3097 0.2058 0.9283 +vn -0.4499 0.2990 0.8416 +vn -0.4707 0.2348 0.8505 +vn -0.4460 0.0014 0.8950 +vn -0.4762 -0.0095 0.8793 +vn -0.5346 0.2667 0.8020 +vn -0.6914 0.0290 0.7219 +vn -0.6165 0.4454 0.6493 +vn -0.7048 0.1501 0.6933 +vn -0.8834 0.2281 0.4094 +vn -0.8746 0.3668 0.3172 +vn -0.9481 0.0000 -0.3179 +vn -0.4393 -0.8740 0.2080 +vn -0.4713 -0.8521 0.2276 +vn -0.8850 -0.3363 0.3221 +vn -0.9562 -0.1834 0.2281 +vn -0.3008 0.0601 0.9518 +vn 0.1250 -0.3153 0.9407 +vn 0.1424 -0.2848 0.9480 +vn 0.0000 0.0000 1.0000 +vn -0.2970 -0.1701 0.9396 +vn -0.2748 -0.3056 0.9117 +vn 0.5887 0.2937 -0.7531 +vn 0.0001 -0.0008 -1.0000 +vn 0.9999 0.0100 -0.0100 +vn 0.5062 0.6097 -0.6100 +vn 0.5596 0.4600 -0.6894 +vn 0.5341 0.5376 -0.6524 +vn -0.1298 -0.9323 0.3377 +vn -0.1701 -0.7922 0.5861 +vn -0.3013 -0.7554 0.5819 +vn -0.2439 -0.6115 0.7527 +vn -0.1402 -0.9798 0.1428 +vn -0.1678 -0.9699 0.1763 +vn -0.3550 -0.8867 0.2963 +usemtl Shape.16324 +s off +f 9//1 42//1 53//1 +f 5//2 2//2 6//2 +f 3//3 4//3 10//3 +f 2//2 5//2 11//2 +f 1//4 11//4 13//4 +f 11//2 5//2 13//2 +f 10//3 4//3 14//3 +f 5//2 6//2 17//2 +f 4//5 8//5 17//5 +f 1//6 3//6 20//6 +f 3//3 10//3 20//3 +f 11//7 1//7 20//7 +f 6//2 2//2 22//2 +f 23//8 9//8 26//8 +f 21//9 16//9 27//9 +f 4//5 17//5 28//5 +f 27//10 14//10 28//10 +f 2//11 11//11 30//11 +f 11//12 20//12 31//12 +f 10//3 14//3 32//3 +f 7//13 24//13 32//13 +f 24//14 10//14 32//14 +f 14//15 27//15 32//15 +f 27//16 7//16 32//16 +f 20//17 10//17 33//17 +f 31//18 20//18 33//18 +f 15//19 31//19 33//19 +f 12//20 1//20 34//20 +f 8//21 12//21 34//21 +f 9//22 22//22 35//22 +f 26//23 9//23 35//23 +f 3//24 1//24 36//24 +f 4//25 3//25 36//25 +f 1//26 12//26 36//26 +f 12//27 4//27 36//27 +f 16//28 23//28 37//28 +f 7//29 27//29 37//29 +f 27//30 16//30 37//30 +f 16//31 21//31 38//31 +f 23//32 16//32 38//32 +f 29//33 6//33 38//33 +f 21//34 29//34 38//34 +f 8//5 4//5 39//5 +f 4//35 12//35 39//35 +f 12//36 8//36 39//36 +f 24//37 7//37 40//37 +f 10//38 24//38 40//38 +f 7//39 37//39 40//39 +f 37//40 18//40 40//40 +f 25//41 15//41 41//41 +f 18//42 25//42 41//42 +f 33//43 10//43 41//43 +f 15//44 33//44 41//44 +f 10//45 
40//45 41//45 +f 40//46 18//46 41//46 +f 9//47 23//47 42//47 +f 38//48 6//48 42//48 +f 23//49 38//49 42//49 +f 14//3 4//3 43//3 +f 4//5 28//5 43//5 +f 28//50 14//50 43//50 +f 5//51 1//51 44//51 +f 1//52 13//52 44//52 +f 13//53 5//53 44//53 +f 19//54 15//54 45//54 +f 15//55 25//55 45//55 +f 25//56 18//56 45//56 +f 18//57 26//57 45//57 +f 35//58 19//58 45//58 +f 26//59 35//59 45//59 +f 26//60 18//60 46//60 +f 23//61 26//61 46//61 +f 18//62 37//62 46//62 +f 37//63 23//63 46//63 +f 21//64 27//64 47//64 +f 27//65 28//65 47//65 +f 47//66 28//66 48//66 +f 17//67 6//67 48//67 +f 28//5 17//5 48//5 +f 6//68 29//68 48//68 +f 29//69 21//69 48//69 +f 21//70 47//70 48//70 +f 15//71 19//71 49//71 +f 30//72 11//72 49//72 +f 11//73 31//73 49//73 +f 31//74 15//74 49//74 +f 19//75 35//75 49//75 +f 35//76 30//76 49//76 +f 1//77 5//77 50//77 +f 5//78 17//78 50//78 +f 17//5 8//5 50//5 +f 34//79 1//79 51//79 +f 8//80 34//80 51//80 +f 1//81 50//81 51//81 +f 50//82 8//82 51//82 +f 22//83 2//83 52//83 +f 2//84 30//84 52//84 +f 35//85 22//85 52//85 +f 30//86 35//86 52//86 +f 6//87 22//87 53//87 +f 22//88 9//88 53//88 +f 42//89 6//89 53//89 diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_3_6.obj b/dexart-release/assets/sapien/102697/new_objs/102697_link_3_6.obj new file mode 100644 index 0000000000000000000000000000000000000000..cf34a50e2465905761098e362a44b1b53d414349 --- /dev/null +++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_3_6.obj @@ -0,0 +1,169 @@ +# Blender v2.79 (sub 0) OBJ File: '' +# www.blender.org +mtllib 102697_link_3_6.mtl +o Shape_IndexedFaceSet.006_Shape_IndexedFaceSet.16315 +v 0.267139 0.042434 0.557173 +v 0.174056 0.055369 0.316720 +v 0.220597 0.088979 0.557173 +v 0.117181 0.047615 0.557173 +v 0.321419 0.088979 0.316720 +v 0.176650 0.042434 0.316720 +v 0.192149 0.073461 0.316720 +v 0.272305 0.086387 0.557173 +v 0.324013 0.042434 0.316720 +v 0.132681 0.070873 0.549382 +v 0.311086 0.052781 0.466641 +v 0.272305 0.088979 0.316720 +v 0.124942 0.042434 0.557173 +v 0.308492 0.081220 0.461463 +v 0.137869 0.073461 0.557173 +v 0.282638 0.068285 0.549382 +v 0.298159 0.088979 0.469231 +v 0.334346 0.068290 0.319286 +v 0.308492 0.045027 0.458873 +v 0.124942 0.065707 0.554560 +v 0.334346 0.052781 0.316697 +v 0.184389 0.070873 0.321875 +v 0.285232 0.050193 0.539048 +v 0.324013 0.083799 0.358079 +v 0.267139 0.088979 0.557173 +v 0.137869 0.073461 0.551971 +v 0.119775 0.057957 0.554560 +v 0.174056 0.045027 0.316720 +v 0.316253 0.042434 0.383948 +v 0.331774 0.047610 0.321875 +v 0.329180 0.083803 0.316697 +v 0.272305 0.045027 0.557173 +v 0.220597 0.088979 0.551971 +v 0.119775 0.045027 0.551971 +v 0.313680 0.088979 0.386514 +v 0.282638 0.042434 0.520946 +v 0.179222 0.063128 0.316720 +vn -0.0002 0.0002 -1.0000 +vn 0.0000 0.0000 1.0000 +vn 0.0000 -1.0000 -0.0000 +vn 0.0000 1.0000 0.0000 +vn 0.8767 0.3663 0.3117 +vn 0.9441 0.1404 0.2983 +vn 0.6868 0.6920 0.2223 +vn 0.9787 0.1197 0.1671 +vn -0.3480 0.2786 0.8951 +vn -0.5230 0.8497 0.0660 +vn 0.0000 -0.0022 -1.0000 +vn 0.9879 -0.0256 0.1532 +vn -0.6125 0.7781 -0.1392 +vn 0.9409 -0.0559 0.3340 +vn 0.8012 -0.5355 0.2670 +vn 0.9537 0.2610 0.1497 +vn 0.4429 0.8828 0.1562 +vn -0.1899 0.9808 -0.0438 +vn -0.1844 0.9829 0.0000 +vn -0.4464 0.8948 0.0000 +vn -0.3652 0.9271 -0.0843 +vn -0.4071 0.9087 -0.0925 +vn -0.9578 0.1845 -0.2206 +vn -0.4906 0.3271 0.8076 +vn -0.9731 0.0000 -0.2302 +vn -0.0001 -0.0001 -1.0000 +vn 0.8869 -0.4394 0.1424 +vn 0.6665 -0.6663 -0.3344 +vn 0.9361 -0.3202 0.1452 +vn 0.5256 -0.8486 0.0607 +vn 0.6247 -0.7755 0.0915 +vn 
0.0000 0.0044 -1.0000 +vn -0.0003 0.0014 -1.0000 +vn 0.7031 0.1171 -0.7014 +vn 0.9363 0.3313 0.1169 +vn -0.0001 0.0000 -1.0000 +vn 0.6020 0.0000 0.7985 +vn 0.8242 -0.1871 0.5345 +vn -0.1842 0.9821 -0.0405 +vn -0.5501 -0.8240 0.1356 +vn -0.3801 -0.9213 -0.0817 +vn -0.8628 -0.4647 -0.1991 +vn -0.6977 -0.6980 -0.1610 +vn 0.6532 0.7472 0.1226 +vn 0.6915 0.7120 0.1216 +vn 0.5539 0.8303 0.0614 +vn 0.6054 0.7923 0.0757 +vn 0.6261 -0.7451 0.2297 +vn 0.2367 -0.9698 0.0581 +vn 0.4406 -0.8777 0.1885 +vn 0.6228 -0.7475 0.2311 +vn -0.5476 0.6851 -0.4804 +vn -0.7588 0.6260 -0.1800 +vn -0.8168 0.5439 -0.1923 +vn -0.8165 0.5444 -0.1922 +vn -0.0002 0.0001 -1.0000 +usemtl Shape.16325 +s off +f 7//1 31//1 37//1 +f 1//2 3//2 4//2 +f 3//2 1//2 8//2 +f 1//3 6//3 9//3 +f 3//4 5//4 12//4 +f 1//2 4//2 13//2 +f 6//3 1//3 13//3 +f 4//2 3//2 15//2 +f 14//5 8//5 16//5 +f 11//6 14//6 16//6 +f 5//4 3//4 17//4 +f 8//7 14//7 17//7 +f 14//8 11//8 18//8 +f 4//9 15//9 20//9 +f 15//10 10//10 20//10 +f 9//11 6//11 21//11 +f 18//12 11//12 21//12 +f 20//13 10//13 22//13 +f 11//14 16//14 23//14 +f 19//15 11//15 23//15 +f 14//16 18//16 24//16 +f 3//2 8//2 25//2 +f 17//4 3//4 25//4 +f 8//17 17//17 25//17 +f 12//18 7//18 26//18 +f 15//19 3//19 26//19 +f 10//20 15//20 26//20 +f 7//21 22//21 26//21 +f 22//22 10//22 26//22 +f 2//23 4//23 27//23 +f 4//24 20//24 27//24 +f 4//25 2//25 28//25 +f 21//26 6//26 28//26 +f 1//3 9//3 29//3 +f 11//27 19//27 30//27 +f 9//28 21//28 30//28 +f 21//29 11//29 30//29 +f 29//30 9//30 30//30 +f 19//31 29//31 30//31 +f 12//32 5//32 31//32 +f 7//33 12//33 31//33 +f 18//34 21//34 31//34 +f 24//35 18//35 31//35 +f 28//36 2//36 31//36 +f 21//36 28//36 31//36 +f 8//2 1//2 32//2 +f 16//37 8//37 32//37 +f 23//38 16//38 32//38 +f 3//4 12//4 33//4 +f 26//19 3//19 33//19 +f 12//39 26//39 33//39 +f 13//40 4//40 34//40 +f 6//41 13//41 34//41 +f 4//42 28//42 34//42 +f 28//43 6//43 34//43 +f 5//4 17//4 35//4 +f 17//44 14//44 35//44 +f 14//45 24//45 35//45 +f 31//46 5//46 35//46 +f 24//47 31//47 35//47 +f 19//48 23//48 36//48 +f 1//3 29//3 36//3 +f 29//49 19//49 36//49 +f 32//50 1//50 36//50 +f 23//51 32//51 36//51 +f 22//52 7//52 37//52 +f 20//53 22//53 37//53 +f 2//54 27//54 37//54 +f 27//55 20//55 37//55 +f 31//56 2//56 37//56 diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_4_11.obj b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_11.obj new file mode 100644 index 0000000000000000000000000000000000000000..240ae654bc988106c6fe7646f940f1bd7dfc5537 --- /dev/null +++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_11.obj @@ -0,0 +1,113 @@ +# Blender v2.79 (sub 0) OBJ File: '' +# www.blender.org +mtllib 102697_link_4_11.mtl +o Shape_IndexedFaceSet.011_Shape_IndexedFaceSet.7497 +v -0.323117 -0.045453 0.089563 +v -0.106232 0.036714 0.031262 +v -0.106232 -0.014976 0.031262 +v -0.207356 0.064630 -0.188472 +v -0.342601 0.058882 0.201139 +v -0.106232 -0.022362 -0.168161 +v -0.269912 -0.053693 0.229603 +v -0.099529 0.061840 -0.178611 +v -0.268728 0.058882 0.223327 +v -0.217054 -0.014976 -0.168161 +v -0.114215 -0.048395 -0.086531 +v -0.354324 -0.050570 0.218569 +v -0.106232 -0.051916 0.009150 +v -0.349976 0.051485 0.178989 +v -0.239154 0.021942 0.215931 +v -0.224454 0.036714 -0.168161 +v -0.170368 -0.047738 -0.078248 +v -0.331607 -0.019798 0.104076 +v -0.246554 -0.022362 0.223327 +v -0.202304 -0.022362 -0.168161 +v -0.126343 0.050010 0.012102 +v -0.246554 0.044099 0.223327 +v -0.224454 -0.000205 -0.168161 +vn -0.9305 0.0000 -0.3663 +vn 0.9995 0.0000 0.0319 +vn 0.0870 -0.1295 -0.9878 
+vn -0.0040 0.9999 0.0134 +vn 0.0243 0.9996 0.0176 +vn -0.2829 0.1803 0.9420 +vn -0.1273 0.0565 0.9902 +vn 0.9995 -0.0157 0.0262 +vn 0.9966 -0.0810 -0.0135 +vn 0.8380 -0.5383 -0.0897 +vn -0.4515 0.8806 -0.1437 +vn -0.9523 0.1448 0.2687 +vn 0.8116 0.0000 0.5842 +vn -0.8922 0.3024 -0.3355 +vn -0.0793 -0.9951 -0.0587 +vn -0.0330 -0.9990 -0.0300 +vn -0.0269 -0.9992 -0.0280 +vn -0.0169 -0.9992 -0.0354 +vn -0.9542 -0.1811 -0.2380 +vn -0.9782 -0.0375 -0.2042 +vn -0.9321 0.1193 -0.3421 +vn 0.7252 -0.4335 0.5349 +vn 0.7686 -0.3286 0.5489 +vn 0.8076 -0.0366 0.5886 +vn -0.0000 -0.2274 -0.9738 +vn -0.4300 -0.8588 -0.2785 +vn -0.1163 -0.2322 -0.9657 +vn 0.0000 -0.9527 -0.3038 +vn -0.2235 -0.9559 -0.1904 +vn -0.0489 -0.9657 -0.2552 +vn 0.4655 0.8769 0.1198 +vn 0.3830 0.8970 0.2205 +vn 0.1907 0.9778 0.0875 +vn 0.5197 0.7795 0.3497 +vn 0.0368 0.0552 0.9978 +vn 0.8066 0.0736 0.5865 +vn 0.2595 0.0000 0.9657 +vn 0.7069 0.0000 0.7073 +vn -0.8241 -0.4128 -0.3879 +vn -0.3724 -0.1866 -0.9091 +vn -0.7650 0.0000 -0.6440 +vn -0.9238 -0.0961 -0.3705 +usemtl Shape.7503 +s off +f 18//1 16//1 23//1 +f 2//2 3//2 8//2 +f 6//3 4//3 8//3 +f 4//4 5//4 9//4 +f 8//5 4//5 9//5 +f 9//6 5//6 12//6 +f 7//7 9//7 12//7 +f 8//8 3//8 13//8 +f 6//9 8//9 13//9 +f 11//10 6//10 13//10 +f 5//11 4//11 14//11 +f 12//12 5//12 14//12 +f 3//13 2//13 15//13 +f 14//14 4//14 16//14 +f 12//15 1//15 17//15 +f 7//16 12//16 17//16 +f 13//17 7//17 17//17 +f 11//18 13//18 17//18 +f 1//19 12//19 18//19 +f 12//20 14//20 18//20 +f 14//21 16//21 18//21 +f 7//22 13//22 19//22 +f 13//23 3//23 19//23 +f 3//24 15//24 19//24 +f 4//25 6//25 20//25 +f 1//26 10//26 20//26 +f 10//27 4//27 20//27 +f 6//28 11//28 20//28 +f 17//29 1//29 20//29 +f 11//30 17//30 20//30 +f 2//31 8//31 21//31 +f 9//32 2//32 21//32 +f 8//33 9//33 21//33 +f 2//34 9//34 22//34 +f 9//35 7//35 22//35 +f 15//36 2//36 22//36 +f 7//37 19//37 22//37 +f 19//38 15//38 22//38 +f 10//39 1//39 23//39 +f 4//40 10//40 23//40 +f 16//41 4//41 23//41 +f 1//42 18//42 23//42 diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_4_19.obj b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_19.obj new file mode 100644 index 0000000000000000000000000000000000000000..59dbb889a91f621e683a90a6ac304eb46adb8789 --- /dev/null +++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_19.obj @@ -0,0 +1,105 @@ +# Blender v2.79 (sub 0) OBJ File: '' +# www.blender.org +mtllib 102697_link_4_19.mtl +o Shape_IndexedFaceSet.019_Shape_IndexedFaceSet.7505 +v 0.418781 0.627924 -0.508767 +v 0.428328 0.622978 -0.508489 +v -0.056475 0.145332 -0.517066 +v -0.076686 0.620168 -0.567029 +v 0.366468 0.620168 -0.567029 +v -0.076686 0.125411 -0.567029 +v -0.054501 0.599577 -0.517235 +v 0.320187 0.115689 -0.504922 +v 0.322183 0.125411 -0.567029 +v 0.418075 0.269311 -0.508853 +v 0.411683 0.362817 -0.537486 +v -0.067659 0.626136 -0.530364 +v 0.388636 0.635034 -0.507943 +v 0.389069 0.133686 -0.508902 +v 0.351689 0.214056 -0.567029 +v 0.314606 0.110688 -0.506236 +v 0.413343 0.608120 -0.537486 +v 0.390810 0.209582 -0.537486 +v 0.432230 0.376647 -0.508872 +v 0.373857 0.605402 -0.567029 +v 0.359079 0.620168 -0.507943 +v -0.067415 0.625895 -0.551722 +vn 0.0041 0.9694 -0.2455 +vn 0.4154 0.7751 -0.4762 +vn 0.0000 0.0000 -1.0000 +vn -0.0322 0.0005 0.9995 +vn -0.9254 -0.0111 0.3788 +vn -0.9710 -0.0000 0.2391 +vn -0.7030 0.0033 0.7112 +vn -0.0527 0.4214 0.9053 +vn 0.0729 0.1956 0.9780 +vn 0.2100 0.9266 -0.3119 +vn 0.0149 0.0038 0.9999 +vn 0.0612 -0.0134 0.9980 +vn -0.0967 -0.9105 0.4021 +vn -0.0476 -0.2038 0.9778 +vn 
0.0000 -0.9719 -0.2354 +vn 0.2833 -0.9396 -0.1922 +vn 0.1599 -0.4139 0.8961 +vn 0.5278 0.6137 -0.5872 +vn 0.8262 -0.1125 -0.5520 +vn 0.6491 -0.2811 -0.7069 +vn 0.8763 -0.1873 -0.4438 +vn 0.5947 -0.0810 -0.7998 +vn 0.5773 -0.1922 -0.7936 +vn 0.0375 -0.0010 0.9993 +vn 0.0502 -0.0064 0.9987 +vn 0.8317 -0.1098 -0.5442 +vn 0.8852 0.0147 -0.4650 +vn 0.8135 -0.0055 -0.5815 +vn 0.4969 -0.0281 -0.8673 +vn 0.5624 0.2814 -0.7775 +vn 0.5992 -0.0041 -0.8006 +vn -0.0228 0.0077 0.9997 +vn -0.0249 0.0495 0.9985 +vn -0.0031 0.0062 1.0000 +vn 0.0000 0.9366 -0.3504 +vn -0.5068 0.8619 -0.0155 +vn -0.0189 0.9998 -0.0115 +usemtl Shape.7511 +s off +f 13//1 5//1 22//1 +f 1//2 2//2 5//2 +f 4//3 5//3 6//3 +f 7//4 3//4 8//4 +f 6//3 5//3 9//3 +f 6//5 3//5 12//5 +f 4//6 6//6 12//6 +f 3//7 7//7 12//7 +f 12//8 7//8 13//8 +f 2//9 1//9 13//9 +f 1//10 5//10 13//10 +f 8//11 2//11 13//11 +f 10//12 8//12 14//12 +f 9//3 5//3 15//3 +f 3//13 6//13 16//13 +f 8//14 3//14 16//14 +f 6//15 9//15 16//15 +f 9//16 14//16 16//16 +f 14//17 8//17 16//17 +f 5//18 2//18 17//18 +f 11//19 10//19 18//19 +f 14//20 9//20 18//20 +f 10//21 14//21 18//21 +f 15//22 11//22 18//22 +f 9//23 15//23 18//23 +f 2//24 8//24 19//24 +f 8//25 10//25 19//25 +f 10//26 11//26 19//26 +f 17//27 2//27 19//27 +f 11//28 17//28 19//28 +f 11//29 15//29 20//29 +f 15//3 5//3 20//3 +f 5//30 17//30 20//30 +f 17//31 11//31 20//31 +f 7//32 8//32 21//32 +f 13//33 7//33 21//33 +f 8//34 13//34 21//34 +f 5//35 4//35 22//35 +f 4//36 12//36 22//36 +f 12//37 13//37 22//37 diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_4_3.mtl b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_3.mtl new file mode 100644 index 0000000000000000000000000000000000000000..4813824b1d555853c436a16aa53832e48462157e --- /dev/null +++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_3.mtl @@ -0,0 +1,12 @@ +# Blender MTL File: 'None' +# Material Count: 1 + +newmtl Shape.7495 +Ns 400.000000 +Ka 0.400000 0.400000 0.400000 +Kd 0.168000 0.496000 0.216000 +Ks 0.250000 0.250000 0.250000 +Ke 0.000000 0.000000 0.000000 +Ni 1.000000 +d 0.500000 +illum 2 diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_4_33.mtl b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_33.mtl new file mode 100644 index 0000000000000000000000000000000000000000..eae730d995bafc1fc4d40dae64eacd7c06c8a02f --- /dev/null +++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_33.mtl @@ -0,0 +1,12 @@ +# Blender MTL File: 'None' +# Material Count: 1 + +newmtl Shape.7525 +Ns 400.000000 +Ka 0.400000 0.400000 0.400000 +Kd 0.272000 0.624000 0.536000 +Ks 0.250000 0.250000 0.250000 +Ke 0.000000 0.000000 0.000000 +Ni 1.000000 +d 0.500000 +illum 2 diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_4_4.mtl b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_4.mtl new file mode 100644 index 0000000000000000000000000000000000000000..396e5bcf8d40c1142017984f4aad04c4a616b409 --- /dev/null +++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_4.mtl @@ -0,0 +1,12 @@ +# Blender MTL File: 'None' +# Material Count: 1 + +newmtl Shape.7496 +Ns 400.000000 +Ka 0.400000 0.400000 0.400000 +Kd 0.720000 0.472000 0.504000 +Ks 0.250000 0.250000 0.250000 +Ke 0.000000 0.000000 0.000000 +Ni 1.000000 +d 0.500000 +illum 2 diff --git a/dexart-release/assets/sapien/102697/new_objs/102697_link_4_8.mtl b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_8.mtl new file mode 100644 index 
0000000000000000000000000000000000000000..26d5f949ad7bc15b162251a226af5e1cd2305870 --- /dev/null +++ b/dexart-release/assets/sapien/102697/new_objs/102697_link_4_8.mtl @@ -0,0 +1,12 @@ +# Blender MTL File: 'None' +# Material Count: 1 + +newmtl Shape.7500 +Ns 400.000000 +Ka 0.400000 0.400000 0.400000 +Kd 0.184000 0.536000 0.280000 +Ks 0.250000 0.250000 0.250000 +Ke 0.000000 0.000000 0.000000 +Ni 1.000000 +d 0.500000 +illum 2 diff --git a/dexart-release/assets/sapien/102697/result.json b/dexart-release/assets/sapien/102697/result.json new file mode 100644 index 0000000000000000000000000000000000000000..176894da5199288fa5953af8e86e86f548d373fc --- /dev/null +++ b/dexart-release/assets/sapien/102697/result.json @@ -0,0 +1 @@ +[{"text": "Toilet", "name": "Toilet", "id": 0, "children": [{"text": "button", "name": "button", "id": 1, "objs": ["original-8"]}, {"text": "lid", "name": "lid", "id": 2, "objs": ["original-4", "original-3"]}, {"text": "lid", "name": "lid", "id": 3, "objs": ["original-21", "original-20"]}, {"text": "seat", "name": "seat", "id": 4, "objs": ["original-18", "original-17"]}, {"text": "base_body", "name": "base_body", "id": 5, "objs": ["original-13", "original-6", "original-15", "original-11", "original-7", "original-10", "original-14"]}]}] \ No newline at end of file diff --git a/dexart-release/assets/sapien/102697/result_original.json b/dexart-release/assets/sapien/102697/result_original.json new file mode 100644 index 0000000000000000000000000000000000000000..ab433e563f0bc74457d8b63bac00218605d84ed8 --- /dev/null +++ b/dexart-release/assets/sapien/102697/result_original.json @@ -0,0 +1 @@ +[{"id": 0, "name": "Toilet", "text": "Toilet", "children": [{"id": 1, "name": "button", "text": "button", "objs": ["original-8"]}, {"id": 2, "name": "lid", "text": "lid", "objs": ["original-4", "original-3"]}, {"id": 3, "name": "lid", "text": "lid", "objs": ["original-21", "original-20"]}, {"id": 4, "name": "seat", "text": "seat", "objs": ["original-18", "original-17"]}]}] \ No newline at end of file diff --git a/dexart-release/assets/sapien/102697/semantics.txt b/dexart-release/assets/sapien/102697/semantics.txt new file mode 100644 index 0000000000000000000000000000000000000000..78e04a0f6d97e6d0c1c1436663e811d6b6d19b6a --- /dev/null +++ b/dexart-release/assets/sapien/102697/semantics.txt @@ -0,0 +1,5 @@ +link_0 hinge lever +link_1 slider pump_lid +link_2 hinge lid +link_3 hinge seat +link_4 static toilet_body diff --git a/dexart-release/dexart.egg-info/PKG-INFO b/dexart-release/dexart.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..71b72913fe316b9a164c464ffc857ae04b7fdb21 --- /dev/null +++ b/dexart-release/dexart.egg-info/PKG-INFO @@ -0,0 +1,12 @@ +Metadata-Version: 2.1 +Name: dexart +Version: 0.1.0 +Summary: UNKNOWN +Home-page: https://github.com/Kami-code/dexart-release +Author: Xiaolong Wang's Lab +License: UNKNOWN +Platform: UNKNOWN +License-File: LICENSE + +UNKNOWN + diff --git a/dexart-release/dexart.egg-info/SOURCES.txt b/dexart-release/dexart.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..688f57b67ad7ec4a33139851809bf05ee3492d42 --- /dev/null +++ b/dexart-release/dexart.egg-info/SOURCES.txt @@ -0,0 +1,51 @@ +LICENSE +README.md +setup.py +dexart.egg-info/PKG-INFO +dexart.egg-info/SOURCES.txt +dexart.egg-info/dependency_links.txt +dexart.egg-info/requires.txt +dexart.egg-info/top_level.txt +stable_baselines3/__init__.py +stable_baselines3/pickle_utils.py +stable_baselines3/a2c/__init__.py 
+stable_baselines3/a2c/a2c.py +stable_baselines3/a2c/policies.py +stable_baselines3/common/__init__.py +stable_baselines3/common/base_class.py +stable_baselines3/common/buffers.py +stable_baselines3/common/callbacks.py +stable_baselines3/common/distributions.py +stable_baselines3/common/env_util.py +stable_baselines3/common/evaluation.py +stable_baselines3/common/logger.py +stable_baselines3/common/monitor.py +stable_baselines3/common/noise.py +stable_baselines3/common/on_policy_algorithm.py +stable_baselines3/common/policies.py +stable_baselines3/common/preprocessing.py +stable_baselines3/common/running_mean_std.py +stable_baselines3/common/save_util.py +stable_baselines3/common/torch_layers.py +stable_baselines3/common/type_aliases.py +stable_baselines3/common/utils.py +stable_baselines3/common/vec_env/__init__.py +stable_baselines3/common/vec_env/base_vec_env.py +stable_baselines3/common/vec_env/dummy_vec_env.py +stable_baselines3/common/vec_env/maniskill2_utils_common.py +stable_baselines3/common/vec_env/maniskill2_utils_wrappers_obs.py +stable_baselines3/common/vec_env/maniskill2_vec_env.py +stable_baselines3/common/vec_env/maniskill2_wrapper_obs.py +stable_baselines3/common/vec_env/stacked_observations.py +stable_baselines3/common/vec_env/subproc_vec_env.py +stable_baselines3/common/vec_env/util.py +stable_baselines3/common/vec_env/vec_check_nan.py +stable_baselines3/common/vec_env/vec_extract_dict_obs.py +stable_baselines3/common/vec_env/vec_frame_stack.py +stable_baselines3/common/vec_env/vec_monitor.py +stable_baselines3/common/vec_env/vec_normalize.py +stable_baselines3/common/vec_env/vec_transpose.py +stable_baselines3/common/vec_env/vec_video_recorder.py +stable_baselines3/ppo/__init__.py +stable_baselines3/ppo/policies.py +stable_baselines3/ppo/ppo.py \ No newline at end of file diff --git a/dexart-release/dexart.egg-info/dependency_links.txt b/dexart-release/dexart.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/dexart-release/dexart.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/dexart-release/dexart.egg-info/requires.txt b/dexart-release/dexart.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..3df94ffddf65c396377019387eec478fce9480df --- /dev/null +++ b/dexart-release/dexart.egg-info/requires.txt @@ -0,0 +1,4 @@ +transforms3d +sapien==2.2.1 +numpy +open3d>=0.15.1 diff --git a/dexart-release/dexart.egg-info/top_level.txt b/dexart-release/dexart.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..108131853cb412de9a5f163e52796658bcef9fb8 --- /dev/null +++ b/dexart-release/dexart.egg-info/top_level.txt @@ -0,0 +1 @@ +stable_baselines3 diff --git a/dexart-release/examples/gen_demonstration_expert.py b/dexart-release/examples/gen_demonstration_expert.py new file mode 100644 index 0000000000000000000000000000000000000000..6ae7c8e49ddbd765f3572d19c0ce210096e50488 --- /dev/null +++ b/dexart-release/examples/gen_demonstration_expert.py @@ -0,0 +1,238 @@ +import os +import sys + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +import argparse +import zarr +import torch +import numpy as np +import torch.nn.functional as F +import pytorch3d.ops as torch3d_ops +from dexart.env.task_setting import TRAIN_CONFIG, RANDOM_CONFIG +from dexart.env.create_env import create_env +from stable_baselines3 import PPO +# from examples.train import 
get_3d_policy_kwargs +from train import get_3d_policy_kwargs +from tqdm import tqdm +from termcolor import cprint + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task_name', type=str, required=True) + parser.add_argument('--checkpoint_path', type=str, required=True) + parser.add_argument('--num_episodes', type=int, default=10, help='number of total episodes') + parser.add_argument('--use_test_set', dest='use_test_set', action='store_true', default=False) + parser.add_argument('--root_dir', type=str, default='data', help='directory to save data') + parser.add_argument('--img_size', type=int, default=84, help='image size') + parser.add_argument('--num_points', type=int, default=1024, help='number of points in point cloud') + args = parser.parse_args() + return args + +def downsample_with_fps(points: np.ndarray, num_points: int = 512): + # fast point cloud sampling using torch3d + points = torch.from_numpy(points).unsqueeze(0).cuda() + num_points = torch.tensor([num_points]).cuda() + # remember to only use coord to sample + _, sampled_indices = torch3d_ops.sample_farthest_points(points=points[...,:3], K=num_points) + points = points.squeeze(0).cpu().numpy() + points = points[sampled_indices.squeeze(0).cpu().numpy()] + return points + +def main(): + args = parse_args() + task_name = args.task_name + use_test_set = args.use_test_set + checkpoint_path = args.checkpoint_path + + + save_dir = os.path.join(args.root_dir, 'dexart_'+args.task_name+'_expert.zarr') + if os.path.exists(save_dir): + cprint('Data already exists at {}'.format(save_dir), 'red') + cprint("If you want to overwrite, delete the existing directory first.", "red") + cprint("Do you want to overwrite? (y/n)", "red") + # user_input = input() + user_input = 'y' + if user_input == 'y': + cprint('Overwriting {}'.format(save_dir), 'red') + os.system('rm -rf {}'.format(save_dir)) + else: + cprint('Exiting', 'red') + return + os.makedirs(save_dir, exist_ok=True) + + + if use_test_set: + indeces = TRAIN_CONFIG[task_name]['unseen'] + cprint(f"using unseen instances {indeces}", 'yellow') + else: + indeces = TRAIN_CONFIG[task_name]['seen'] + cprint(f"using seen instances {indeces}", 'yellow') + + rand_pos = RANDOM_CONFIG[task_name]['rand_pos'] + rand_degree = RANDOM_CONFIG[task_name]['rand_degree'] + env = create_env(task_name=task_name, + use_visual_obs=True, + use_gui=False, + is_eval=True, + pc_noise=True, + pc_seg=True, + index=indeces, + img_type='robot', + rand_pos=rand_pos, + rand_degree=rand_degree) + + policy = PPO.load(checkpoint_path, env, 'cuda:0', + policy_kwargs=get_3d_policy_kwargs(extractor_name='smallpn'), + check_obs_space=False, force_load=True) + + eval_instances = len(env.instance_list) + num_episodes = args.num_episodes + cprint(f"generate {num_episodes} episodes in total", 'yellow') + + success_list = [] + reward_list = [] + + total_count = 0 + img_arrays = [] + point_cloud_arrays = [] + depth_arrays = [] + state_arrays = [] + imagin_robot_arrays = [] + action_arrays = [] + episode_ends_arrays = [] + + + with tqdm(total=num_episodes) as pbar: + num_success = 0 + while num_success < num_episodes: + + # obs dict keys: 'instance_1-seg_gt', 'instance_1-point_cloud', + # 'instance_1-rgb', 'imagination_robot', 'state', 'oracle_state' + obs = env.reset() + eval_success = False + reward_sum = 0 + + img_arrays_sub = [] + point_cloud_arrays_sub = [] + depth_arrays_sub = [] + state_arrays_sub = [] + imagin_robot_arrays_sub = [] + action_arrays_sub = [] + total_count_sub = 0 + for j in 
range(env.horizon): + + if isinstance(obs, dict): + for key, value in obs.items(): + obs[key] = value[np.newaxis, :] + else: + obs = obs[np.newaxis, :] + action = policy.predict(observation=obs, deterministic=True)[0] + + # fetch data + total_count_sub += 1 + obs_state = obs['state'][0] # (32) + obs_imagin_robot = obs['imagination_robot'][0] # (96,7) + obs_point_cloud = obs['instance_1-point_cloud'][0] # (1024,3) + obs_depth = obs['instance_1-depth'][0] # (84,84) + + if obs_point_cloud.shape[0] > args.num_points: + obs_point_cloud = downsample_with_fps(obs_point_cloud, num_points=args.num_points) + obs_image = obs['instance_1-rgb'][0] # (84,84,3), [0,1] + + + # to 0-255 + obs_image = (obs_image*255).astype(np.uint8) + + # interpolate to target image size + if obs_image.shape[0] != args.img_size: + obs_image = F.interpolate(torch.from_numpy(obs_image).permute(2,0,1).unsqueeze(0), + size=args.img_size).squeeze().permute(1,2,0).numpy() + # save data + img_arrays_sub.append(obs_image) + imagin_robot_arrays_sub.append(obs_imagin_robot) + point_cloud_arrays_sub.append(obs_point_cloud) + depth_arrays_sub.append(obs_depth) + state_arrays_sub.append(obs_state) + action_arrays_sub.append(action) + + # step + obs, reward, done, _ = env.step(action) + reward_sum += reward + if env.is_eval_done: + eval_success = True + if done: + break + + if eval_success: + total_count += total_count_sub + episode_ends_arrays.append(total_count) # the index of the last step of the episode + reward_list.append(reward_sum) + success_list.append(int(eval_success)) + + img_arrays.extend(img_arrays_sub) + imagin_robot_arrays.extend(imagin_robot_arrays_sub) + point_cloud_arrays.extend(point_cloud_arrays_sub) + depth_arrays.extend(depth_arrays_sub) + state_arrays.extend(state_arrays_sub) + action_arrays.extend(action_arrays_sub) + + num_success += 1 + + pbar.update(1) + pbar.set_description(f"reward = {reward_sum}, success = {eval_success}") + else: + print("episode failed. 
continue.") + continue + + + cprint(f"reward_mean = {np.mean(reward_list)}, success rate = {np.mean(success_list)}", 'yellow') + + ############################### + # save data + ############################### + # create zarr file + zarr_root = zarr.group(save_dir) + zarr_data = zarr_root.create_group('data') + zarr_meta = zarr_root.create_group('meta') + # save img, state, action arrays into data, and episode ends arrays into meta + img_arrays = np.stack(img_arrays, axis=0) + if img_arrays.shape[1] == 3: # make channel last + img_arrays = np.transpose(img_arrays, (0,2,3,1)) + state_arrays = np.stack(state_arrays, axis=0) + imagin_robot_arrays = np.stack(imagin_robot_arrays, axis=0) + point_cloud_arrays = np.stack(point_cloud_arrays, axis=0) + depth_arrays = np.stack(depth_arrays, axis=0) + action_arrays = np.stack(action_arrays, axis=0) + episode_ends_arrays = np.array(episode_ends_arrays) + + compressor = zarr.Blosc(cname='zstd', clevel=3, shuffle=1) + img_chunk_size = (env.horizon, img_arrays.shape[1], img_arrays.shape[2], img_arrays.shape[3]) + imagin_robot_chunk_size = (env.horizon, imagin_robot_arrays.shape[1], imagin_robot_arrays.shape[2]) + point_cloud_chunk_size = (env.horizon, point_cloud_arrays.shape[1], point_cloud_arrays.shape[2]) + depth_chunk_size = (env.horizon, depth_arrays.shape[1], depth_arrays.shape[2]) + state_chunk_size = (env.horizon, state_arrays.shape[1]) + action_chunk_size = (env.horizon, action_arrays.shape[1]) + zarr_data.create_dataset('img', data=img_arrays, chunks=img_chunk_size, dtype='uint8', overwrite=True, compressor=compressor) + zarr_data.create_dataset('state', data=state_arrays, chunks=state_chunk_size, dtype='float32', overwrite=True, compressor=compressor) + zarr_data.create_dataset('imagin_robot', data=imagin_robot_arrays, chunks=imagin_robot_chunk_size, dtype='float32', overwrite=True, compressor=compressor) + zarr_data.create_dataset('point_cloud', data=point_cloud_arrays, chunks=point_cloud_chunk_size, dtype='float32', overwrite=True, compressor=compressor) + zarr_data.create_dataset('depth', data=depth_arrays, chunks=depth_chunk_size, dtype='float32', overwrite=True, compressor=compressor) + zarr_data.create_dataset('action', data=action_arrays, chunks=action_chunk_size, dtype='float32', overwrite=True, compressor=compressor) + zarr_meta.create_dataset('episode_ends', data=episode_ends_arrays, dtype='int64', overwrite=True, compressor=compressor) + + # print shape + cprint(f'img shape: {img_arrays.shape}, range: [{np.min(img_arrays)}, {np.max(img_arrays)}]', 'green') + cprint(f'imagin_robot shape: {imagin_robot_arrays.shape}, range: [{np.min(imagin_robot_arrays)}, {np.max(imagin_robot_arrays)}]', 'green') + cprint(f'point_cloud shape: {point_cloud_arrays.shape}, range: [{np.min(point_cloud_arrays)}, {np.max(point_cloud_arrays)}]', 'green') + cprint(f'depth shape: {depth_arrays.shape}, range: [{np.min(depth_arrays)}, {np.max(depth_arrays)}]', 'green') + cprint(f'state shape: {state_arrays.shape}, range: [{np.min(state_arrays)}, {np.max(state_arrays)}]', 'green') + cprint(f'action shape: {action_arrays.shape}, range: [{np.min(action_arrays)}, {np.max(action_arrays)}]', 'green') + cprint(f'Saved zarr file to {save_dir}', 'green') + + +if __name__ == "__main__": + main() + + diff --git a/dexart-release/examples/train.py b/dexart-release/examples/train.py new file mode 100644 index 0000000000000000000000000000000000000000..2f78b55bcf88c0353b79e4024e4ce2a86c57fdda --- /dev/null +++ b/dexart-release/examples/train.py @@ -0,0 +1,124 @@ +import os 
+import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +import random +from collections import OrderedDict +import torch.nn as nn +import argparse +from dexart.env.create_env import create_env +from dexart.env.task_setting import TRAIN_CONFIG, IMG_CONFIG, RANDOM_CONFIG +from stable_baselines3.common.torch_layers import PointNetImaginationExtractorGP +from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv +from stable_baselines3.ppo import PPO +import torch + +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def get_3d_policy_kwargs(extractor_name): + feature_extractor_class = PointNetImaginationExtractorGP + feature_extractor_kwargs = {"pc_key": "instance_1-point_cloud", "gt_key": "instance_1-seg_gt", + "extractor_name": extractor_name, + "imagination_keys": [f'imagination_{key}' for key in IMG_CONFIG['robot'].keys()], + "state_key": "state"} + + policy_kwargs = { + "features_extractor_class": feature_extractor_class, + "features_extractor_kwargs": feature_extractor_kwargs, + "net_arch": [dict(pi=[64, 64], vf=[64, 64])], + "activation_fn": nn.ReLU, + } + return policy_kwargs + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--n', type=int, default=10) + parser.add_argument('--workers', type=int, default=1) + parser.add_argument('--lr', type=float, default=3e-4) + parser.add_argument('--ep', type=int, default=10) + parser.add_argument('--bs', type=int, default=10) + parser.add_argument('--seed', type=int, default=100) + parser.add_argument('--iter', type=int, default=1000) + parser.add_argument('--freeze', dest='freeze', action='store_true', default=False) + parser.add_argument('--task_name', type=str, default="laptop") + parser.add_argument('--extractor_name', type=str, default="smallpn") + parser.add_argument('--pretrain_path', type=str, default=None) + args = parser.parse_args() + + task_name = args.task_name + extractor_name = args.extractor_name + seed = args.seed if args.seed >= 0 else random.randint(0, 100000) + pretrain_path = args.pretrain_path + horizon = 200 + env_iter = args.iter * horizon * args.n + print(f"freeze: {args.freeze}") + + rand_pos = RANDOM_CONFIG[task_name]['rand_pos'] + rand_degree = RANDOM_CONFIG[task_name]['rand_degree'] + + + def create_env_fn(): + seen_indeces = TRAIN_CONFIG[task_name]['seen'] + environment = create_env(task_name=task_name, + use_visual_obs=True, + use_gui=False, + is_eval=False, + pc_noise=True, + index=seen_indeces, + img_type='robot', + rand_pos=rand_pos, + rand_degree=rand_degree + ) + return environment + + + def create_eval_env_fn(): + unseen_indeces = TRAIN_CONFIG[task_name]['unseen'] + environment = create_env(task_name=task_name, + use_visual_obs=True, + use_gui=False, + is_eval=True, + pc_noise=True, + index=unseen_indeces, + img_type='robot', + rand_pos=rand_pos, + rand_degree=rand_degree) + return environment + + + env = SubprocVecEnv([create_env_fn] * args.workers, "spawn") # train on a list of envs. 
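+    # Rollout bookkeeping (comment added for clarity; derived from the arguments above):
+    # each of the `args.workers` subprocess envs collects n_steps =
+    # (args.n // args.workers) * horizon transitions per rollout, so one PPO update
+    # consumes roughly args.n * horizon transitions, and env_iter = args.iter * horizon * args.n
+    # corresponds to about args.iter updates in total.
+    # Note: the min_lr / max_lr / adaptive_kl keyword arguments passed below are not part of
+    # upstream stable-baselines3's PPO; they are extensions provided by the modified SB3 copy
+    # vendored under dexart-release/stable_baselines3 in this same diff.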
+ + model = PPO("PointCloudPolicy", env, verbose=1, + n_epochs=args.ep, + n_steps=(args.n // args.workers) * horizon, + learning_rate=args.lr, + batch_size=args.bs, + seed=seed, + policy_kwargs=get_3d_policy_kwargs(extractor_name=extractor_name), + min_lr=args.lr, + max_lr=args.lr, + adaptive_kl=0.02, + target_kl=0.2, + ) + + if pretrain_path is not None: + state_dict: OrderedDict = torch.load(pretrain_path) + model.policy.features_extractor.extractor.load_state_dict(state_dict, strict=False) + print("load pretrained model: ", pretrain_path) + + rollout = int(model.num_timesteps / (horizon * args.n)) + + # after loading or init the model, then freeze it if needed + if args.freeze: + model.policy.features_extractor.extractor.eval() + for param in model.policy.features_extractor.extractor.parameters(): + param.requires_grad = False + print("freeze model!") + + model.learn( + total_timesteps=int(env_iter), + reset_num_timesteps=False, + iter_start=rollout, + callback=None + ) diff --git a/dexart-release/examples/utils.py b/dexart-release/examples/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe23a9a962a904ca60660165af003f226fb8b09 --- /dev/null +++ b/dexart-release/examples/utils.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +import numpy as np +from dexart.env.task_setting import ROBUSTNESS_INIT_CAMERA_CONFIG +import open3d as o3d + +def visualize_observation(obs, use_seg=False, img_type=None): + def visualize_pc_with_seg_label(cloud): + pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(cloud[:, :3])) + + def map(feature): + color = np.zeros((feature.shape[0], 3)) + COLOR20 = np.array( + [[230, 25, 75], [60, 180, 75], [255, 225, 25], [0, 130, 200], [245, 130, 48], + [145, 30, 180], [70, 240, 240], [240, 50, 230], [210, 245, 60], [250, 190, 190], + [0, 128, 128], [230, 190, 255], [170, 110, 40], [255, 250, 200], [128, 0, 0], + [170, 255, 195], [128, 128, 0], [255, 215, 180], [0, 0, 128], [128, 128, 128]]) / 255 + for i in range(feature.shape[0]): + for j in range(feature.shape[1]): + if feature[i, j] == 1: + color[i, :] = COLOR20[j, :] + return color + + color = map(cloud[:, 3:]) + pc.colors = o3d.utility.Vector3dVector(color) + return pc + + pc = obs["instance_1-point_cloud"] + if use_seg: + gt_seg = obs["instance_1-seg_gt"] + pc = np.concatenate([pc, gt_seg], axis=1) + pc = visualize_pc_with_seg_label(pc) + if img_type == "robot": + robot_pc = obs["imagination_robot"] + pc += visualize_pc_with_seg_label(robot_pc) + else: + raise NotImplementedError + return pc + + +def get_viewpoint_camera_parameter(): + robustness_init_camera_config = ROBUSTNESS_INIT_CAMERA_CONFIG['laptop'] + r = robustness_init_camera_config['r'] + phi = robustness_init_camera_config['phi'] + theta = robustness_init_camera_config['theta'] + center = robustness_init_camera_config['center'] + + x0, y0, z0 = center + # phi in [0, pi/2] + # theta in [0, 2 * pi] + x = x0 + r * np.sin(phi) * np.cos(theta) + y = y0 + r * np.sin(phi) * np.sin(theta) + z = z0 + r * np.cos(phi) + + cam_pos = np.array([x, y, z]) + forward = np.array([x0 - x, y0 - y, z0 - z]) + forward /= np.linalg.norm(forward) + + left = np.cross([0, 0, 1], forward) + left = left / np.linalg.norm(left) + + up = np.cross(forward, left) + mat44 = np.eye(4) + mat44[:3, :3] = np.stack([forward, left, up], axis=1) + mat44[:3, 3] = cam_pos + return cam_pos, center, up, mat44 \ No newline at end of file diff --git a/dexart-release/stable_baselines3/a2c/__init__.py 
b/dexart-release/stable_baselines3/a2c/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9996494b658e9bef6efccfc9cdad269da4609d --- /dev/null +++ b/dexart-release/stable_baselines3/a2c/__init__.py @@ -0,0 +1,2 @@ +from stable_baselines3.a2c.a2c import A2C +from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPolicy diff --git a/dexart-release/stable_baselines3/a2c/a2c.py b/dexart-release/stable_baselines3/a2c/a2c.py new file mode 100644 index 0000000000000000000000000000000000000000..13adf680001deaaae21659d09b22cdeaffd7a775 --- /dev/null +++ b/dexart-release/stable_baselines3/a2c/a2c.py @@ -0,0 +1,207 @@ +from typing import Any, Dict, Optional, Type, Union + +import torch as th +from gym import spaces +from torch.nn import functional as F + +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import explained_variance + + +class A2C(OnPolicyAlgorithm): + """ + Advantage Actor Critic (A2C) + + Paper: https://arxiv.org/abs/1602.01783 + Code: This implementation borrows code from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail + and Stable Baselines (https://github.com/hill-a/stable-baselines) + + Introduction to A2C: https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752 + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate; it can be a function + of the current progress remaining (from 1 to 0) + :param n_steps: The number of steps to run for each environment per update + (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) + :param gamma: Discount factor + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. + :param ent_coef: Entropy coefficient for the loss calculation + :param vf_coef: Value function coefficient for the loss calculation + :param max_grad_norm: The maximum value for the gradient clipping + :param rms_prop_eps: RMSProp epsilon. It stabilizes square root computation in the denominator + of the RMSProp update + :param use_rms_prop: Whether to use RMSprop (default) or Adam as optimizer + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) + instead of action noise exploration (default: False) + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE + Default: -1 (only sample at the beginning of the rollout) + :param normalize_advantage: Whether to normalize or not the advantage + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param create_eval_env: Whether to create a second environment that will be + used for evaluating the agent periodically. (Only available when passing string for the environment) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible.
+ :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": ActorCriticPolicy, + "CnnPolicy": ActorCriticCnnPolicy, + "MultiInputPolicy": MultiInputActorCriticPolicy, + } + + def __init__( + self, + policy: Union[str, Type[ActorCriticPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 7e-4, + n_steps: int = 5, + gamma: float = 0.99, + gae_lambda: float = 1.0, + ent_coef: float = 0.0, + vf_coef: float = 0.5, + max_grad_norm: float = 0.5, + rms_prop_eps: float = 1e-5, + use_rms_prop: bool = True, + use_sde: bool = False, + sde_sample_freq: int = -1, + normalize_advantage: bool = False, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + + super().__init__( + policy, + env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + tensorboard_log=tensorboard_log, + policy_kwargs=policy_kwargs, + verbose=verbose, + device=device, + create_eval_env=create_eval_env, + seed=seed, + _init_setup_model=False, + supported_action_spaces=( + spaces.Box, + spaces.Discrete, + spaces.MultiDiscrete, + spaces.MultiBinary, + ), + ) + + self.normalize_advantage = normalize_advantage + + # Update optimizer inside the policy if we want to use RMSProp + # (original implementation) rather than Adam + if use_rms_prop and "optimizer_class" not in self.policy_kwargs: + self.policy_kwargs["optimizer_class"] = th.optim.RMSprop + self.policy_kwargs["optimizer_kwargs"] = dict(alpha=0.99, eps=rms_prop_eps, weight_decay=0) + + if _init_setup_model: + self._setup_model() + + def train(self) -> None: + """ + Update policy using the currently gathered + rollout buffer (one gradient step over whole data). 
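The constructor above swaps in RMSprop (matching the original A2C implementation) only when the caller has not already set `optimizer_class`. A sketch of making the same choice explicitly, with `CartPole-v1` as a stand-in environment:

```python
# Equivalent to the automatic RMSprop default above, spelled out explicitly;
# supplying optimizer_class yourself disables the substitution.
import torch as th
from stable_baselines3.a2c import A2C

model = A2C(
    "MlpPolicy",
    "CartPole-v1",
    policy_kwargs=dict(
        optimizer_class=th.optim.RMSprop,
        optimizer_kwargs=dict(alpha=0.99, eps=1e-5, weight_decay=0),
    ),
)
```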
+ """ + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + + # Update optimizer learning rate + self._update_learning_rate(self.policy.optimizer) + + # This will only loop once (get all data in one go) + for rollout_data in self.rollout_buffer.get(batch_size=None): + + actions = rollout_data.actions + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action from float to long + actions = actions.long().flatten() + + values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions) + values = values.flatten() + + # Normalize advantage (not present in the original implementation) + advantages = rollout_data.advantages + if self.normalize_advantage: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + # Policy gradient loss + policy_loss = -(advantages * log_prob).mean() + + # Value loss using the TD(gae_lambda) target + value_loss = F.mse_loss(rollout_data.returns, values) + + # Entropy loss favor exploration + if entropy is None: + # Approximate entropy when no analytical form + entropy_loss = -th.mean(-log_prob) + else: + entropy_loss = -th.mean(entropy) + + loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + + # Optimization step + self.policy.optimizer.zero_grad() + loss.backward() + + # Clip grad norm + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() + + explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) + + self._n_updates += 1 + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/explained_variance", explained_var) + self.logger.record("train/entropy_loss", entropy_loss.item()) + self.logger.record("train/policy_loss", policy_loss.item()) + self.logger.record("train/value_loss", value_loss.item()) + if hasattr(self.policy, "log_std"): + self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 100, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "A2C", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> "A2C": + + return super().learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + eval_env=eval_env, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, + eval_log_path=eval_log_path, + reset_num_timesteps=reset_num_timesteps, + ) diff --git a/dexart-release/stable_baselines3/a2c/policies.py b/dexart-release/stable_baselines3/a2c/policies.py new file mode 100644 index 0000000000000000000000000000000000000000..7299b34df478f6392cf5182842dacd83373aed0f --- /dev/null +++ b/dexart-release/stable_baselines3/a2c/policies.py @@ -0,0 +1,7 @@ +# This file is here just to define MlpPolicy/CnnPolicy +# that work for A2C +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy + +MlpPolicy = ActorCriticPolicy +CnnPolicy = ActorCriticCnnPolicy +MultiInputPolicy = MultiInputActorCriticPolicy diff --git a/dexart-release/stable_baselines3/common/__init__.py b/dexart-release/stable_baselines3/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/dexart-release/stable_baselines3/common/base_class.py b/dexart-release/stable_baselines3/common/base_class.py new file mode 100644 index 0000000000000000000000000000000000000000..31266e76b9f00b3c5c2e52e3add465aec4f3e9e4 --- /dev/null +++ b/dexart-release/stable_baselines3/common/base_class.py @@ -0,0 +1,835 @@ +"""Abstract base classes for RL algorithms.""" + +import io +import pathlib +import time +from abc import ABC, abstractmethod +from collections import deque +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union + +import gym +import numpy as np +import torch as th + +from stable_baselines3.common import utils +from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback +from stable_baselines3.common.env_util import is_wrapped +from stable_baselines3.common.logger import Logger +from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.noise import ActionNoise +from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first +from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import ( + check_for_correct_spaces, + get_device, + get_schedule_fn, + get_system_info, + set_random_seed, + update_learning_rate, +) +from stable_baselines3.common.vec_env import ( + DummyVecEnv, + VecEnv, + VecNormalize, + VecTransposeImage, + is_vecenv_wrapped, + unwrap_vec_normalize, +) + + +def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymEnv]: + """If env is a string, make the environment; otherwise, return env. + + :param env: The environment to learn from. + :param verbose: logging verbosity + :return A Gym (vector) environment. + """ + if isinstance(env, str): + if verbose >= 1: + print(f"Creating environment from the given name '{env}'") + env = gym.make(env) + return env + + +class BaseAlgorithm(ABC): + """ + The base of RL algorithms + + :param policy: Policy object + :param env: The environment to learn from + (if registered in Gym, can be str. Can be None for loading trained models) + :param learning_rate: learning rate for the optimizer, + it can be a function of the current progress remaining (from 1 to 0) + :param policy_kwargs: Additional arguments to be passed to the policy on creation + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param verbose: The verbosity level: 0 none, 1 training information, 2 debug + :param device: Device on which the code should run. + By default, it will try to use a Cuda compatible device and fallback to cpu + if it is not possible. + :param support_multi_env: Whether the algorithm supports training + with multiple environments (as in A2C) + :param create_eval_env: Whether to create a second environment that will be + used for evaluating the agent periodically. (Only available when passing string for the environment) + :param monitor_wrapper: When creating an environment, whether to wrap it + or not in a Monitor wrapper. 
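Taken together, `maybe_make_env` above and the `_wrap_env` helper defined below mean a plain Gym id is accepted and automatically vectorized; a minimal sketch:

```python
# A Gym id string is turned into an env at construction time, then
# Monitor-wrapped inside a DummyVecEnv before training starts.
from stable_baselines3.a2c import A2C

model = A2C("MlpPolicy", "CartPole-v1", verbose=1)
print(type(model.get_env()).__name__)   # DummyVecEnv
```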
+ :param seed: Seed for the pseudo random generators + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) + instead of action noise exploration (default: False) + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE + Default: -1 (only sample at the beginning of the rollout) + :param supported_action_spaces: The action spaces supported by the algorithm. + """ + + # Policy aliases (see _get_policy_from_name()) + policy_aliases: Dict[str, Type[BasePolicy]] = {} + + def __init__( + self, + policy: Type[BasePolicy], + env: Union[GymEnv, str, None], + learning_rate: Union[float, Schedule], + policy_kwargs: Optional[Dict[str, Any]] = None, + tensorboard_log: Optional[str] = None, + verbose: int = 0, + device: Union[th.device, str] = "auto", + support_multi_env: bool = False, + create_eval_env: bool = False, + monitor_wrapper: bool = True, + seed: Optional[int] = None, + use_sde: bool = False, + sde_sample_freq: int = -1, + supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None, + ): + if isinstance(policy, str): + self.policy_class = self._get_policy_from_name(policy) + else: + self.policy_class = policy + + self.device = get_device(device) + if verbose > 0: + print(f"Using {self.device} device") + + self.env = None # type: Optional[GymEnv] + # get VecNormalize object if needed + self._vec_normalize_env = unwrap_vec_normalize(env) + self.verbose = verbose + self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs + self.observation_space = None # type: Optional[gym.spaces.Space] + self.action_space = None # type: Optional[gym.spaces.Space] + self.n_envs = None + self.num_timesteps = 0 + # Used for updating schedules + self._total_timesteps = 0 + # Used for computing fps, it is updated at each call of learn() + self._num_timesteps_at_start = 0 + self.eval_env = None + self.seed = seed + self.action_noise = None # type: Optional[ActionNoise] + self.start_time = None + self.policy = None + self.learning_rate = learning_rate + self.tensorboard_log = tensorboard_log + self.lr_schedule = None # type: Optional[Schedule] + self._last_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]] + self._last_episode_starts = None # type: Optional[np.ndarray] + # When using VecNormalize: + self._last_original_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]] + self._episode_num = 0 + # Used for gSDE only + self.use_sde = use_sde + self.sde_sample_freq = sde_sample_freq + # Track the training progress remaining (from 1 to 0) + # this is used to update the learning rate + self._current_progress_remaining = 1 + # Buffers for logging + self.ep_info_buffer = None # type: Optional[deque] + self.ep_success_buffer = None # type: Optional[deque] + # For logging (and TD3 delayed updates) + self._n_updates = 0 # type: int + # The logger object + self._logger = None # type: Logger + # Whether the user passed a custom logger or not + self._custom_logger = False + + # Create and wrap the env if needed + if env is not None: + if isinstance(env, str): + if create_eval_env: + self.eval_env = maybe_make_env(env, self.verbose) + + env = maybe_make_env(env, self.verbose) + env = self._wrap_env(env, self.verbose, monitor_wrapper) + + self.observation_space = env.observation_space + self.action_space = env.action_space + self.n_envs = env.num_envs + self.env = env + + if supported_action_spaces is not None: + assert isinstance(self.action_space, supported_action_spaces), ( + f"The algorithm only supports 
{supported_action_spaces} as action spaces " + f"but {self.action_space} was provided" + ) + + if not support_multi_env and self.n_envs > 1: + raise ValueError( + "Error: the model does not support multiple envs; it requires " "a single vectorized environment." + ) + + # Catch common mistake: using MlpPolicy/CnnPolicy instead of MultiInputPolicy + if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, gym.spaces.Dict): + raise ValueError(f"You must use `MultiInputPolicy` when working with dict observation space, not {policy}") + + if self.use_sde and not isinstance(self.action_space, gym.spaces.Box): + raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.") + + @staticmethod + def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> VecEnv: + """ + Wrap environment with the appropriate wrappers if needed. + For instance, to have a vectorized environment + or to re-order the image channels. + + :param env: + :param verbose: + :param monitor_wrapper: Whether to wrap the env in a ``Monitor`` when possible. + :return: The wrapped environment. + """ + if not isinstance(env, VecEnv): + if not is_wrapped(env, Monitor) and monitor_wrapper: + if verbose >= 1: + print("Wrapping the env with a `Monitor` wrapper") + env = Monitor(env) + if verbose >= 1: + print("Wrapping the env in a DummyVecEnv.") + env = DummyVecEnv([lambda: env]) + + # Make sure that dict-spaces are not nested (not supported) + check_for_nested_spaces(env.observation_space) + + if not is_vecenv_wrapped(env, VecTransposeImage): + wrap_with_vectranspose = False + if isinstance(env.observation_space, gym.spaces.Dict): + # If even one of the keys is an image-space in need of transpose, apply transpose + # If the image spaces are not consistent (for instance one is channel first, + # the other channel last), VecTransposeImage will throw an error + for space in env.observation_space.spaces.values(): + wrap_with_vectranspose = wrap_with_vectranspose or ( + is_image_space(space) and not is_image_space_channels_first(space) + ) + else: + wrap_with_vectranspose = is_image_space(env.observation_space) and not is_image_space_channels_first( + env.observation_space + ) + + if wrap_with_vectranspose: + if verbose >= 1: + print("Wrapping the env in a VecTransposeImage.") + env = VecTransposeImage(env) + + return env + + @abstractmethod + def _setup_model(self) -> None: + """Create networks, buffer and optimizers.""" + + def set_logger(self, logger: Logger) -> None: + """ + Setter for the logger object. + + .. warning:: + + When passing a custom logger object, + this will overwrite ``tensorboard_log`` and ``verbose`` settings + passed to the constructor. + """ + self._logger = logger + # User defined logger + self._custom_logger = True + + @property + def logger(self) -> Logger: + """Getter for the logger object.""" + return self._logger + + def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]: + """ + Return the environment that will be used for evaluation.
+ + :param eval_env: + :return: + """ + if eval_env is None: + eval_env = self.eval_env + + if eval_env is not None: + eval_env = self._wrap_env(eval_env, self.verbose) + assert eval_env.num_envs == 1 + return eval_env + + def _setup_lr_schedule(self) -> None: + """Transform to callable if needed.""" + self.lr_schedule = get_schedule_fn(self.learning_rate) + + def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None: + """ + Compute current progress remaining (starts at 1 and ends at 0) + + :param num_timesteps: current number of timesteps + :param total_timesteps: + """ + self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps) + + def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None: + """ + Update the optimizers' learning rate using the current learning rate schedule + and the current progress remaining (from 1 to 0). + + :param optimizers: + An optimizer or a list of optimizers. + """ + # Log the current learning rate + self.logger.record("train/learning_rate", self.lr_schedule(self._current_progress_remaining)) + + if not isinstance(optimizers, list): + optimizers = [optimizers] + for optimizer in optimizers: + update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining)) + + def _excluded_save_params(self) -> List[str]: + """ + Returns the names of the parameters that should be excluded from being + saved by pickling. E.g. replay buffers are skipped by default + as they take up a lot of space. PyTorch variables should be excluded + with this so they can be stored with ``th.save``. + + :return: List of parameters that should be excluded from being saved with pickle. + """ + return [ + "policy", + "device", + "env", + "eval_env", + "replay_buffer", + "rollout_buffer", + "_vec_normalize_env", + "_episode_storage", + "_logger", + "_custom_logger", + ] + + def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]: + """ + Get a policy class from its name representation. + + The goal here is to standardize policy naming, e.g. + all algorithms can call upon "MlpPolicy" or "CnnPolicy", + and they receive respective policies that work for them. + + :param policy_name: Alias of the policy + :return: A policy class (type) + """ + + if policy_name in self.policy_aliases: + return self.policy_aliases[policy_name] + else: + raise ValueError(f"Policy {policy_name} unknown") + + def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: + """ + Get the name of the torch variables that will be saved with + PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default + pickling strategy. This is to handle device placement correctly. + + Names can point to specific variables under classes, e.g. + "policy.optimizer" would point to the ``optimizer`` object of ``self.policy``, + if that object exists. + + :return: + List of Torch variables whose state dicts to save (e.g. th.nn.Modules), + and list of other Torch variables to store with ``th.save``. + """ + state_dicts = ["policy"] + + return state_dicts, []
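`_setup_lr_schedule`/`get_schedule_fn` above are what allow `learning_rate` to be a callable of the progress remaining, which `_update_current_progress_remaining` drives from 1.0 down to 0.0 over training. A sketch of a linear schedule, with `CartPole-v1` as a stand-in environment:

```python
# A Schedule is just a function of progress_remaining in [1.0, 0.0].
from stable_baselines3.a2c import A2C

def linear_schedule(initial_lr: float):
    def schedule(progress_remaining: float) -> float:
        return progress_remaining * initial_lr   # decays linearly to 0
    return schedule

model = A2C("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(7e-4))
```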
+ + def _init_callback( + self, + callback: MaybeCallback, + eval_env: Optional[VecEnv] = None, + eval_freq: int = 10000, + n_eval_episodes: int = 5, + log_path: Optional[str] = None, + ) -> BaseCallback: + """ + :param callback: Callback(s) called at every step with state of the algorithm. + :param eval_freq: How many steps between evaluations; if None, do not evaluate. + :param n_eval_episodes: Number of episodes to rollout during evaluation. + :param log_path: Path to a folder where the evaluations will be saved + :return: A hybrid callback calling `callback` and performing evaluation. + """ + # Convert a list of callbacks into a callback + if isinstance(callback, list): + callback = CallbackList(callback) + + # Convert functional callback to object + if not isinstance(callback, BaseCallback): + callback = ConvertCallback(callback) + + # Create eval callback in charge of the evaluation + if eval_env is not None: + eval_callback = EvalCallback( + eval_env, + best_model_save_path=log_path, + log_path=log_path, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + ) + callback = CallbackList([callback, eval_callback]) + + callback.init_callback(self) + return callback + + def _setup_learn( + self, + total_timesteps: int, + eval_env: Optional[GymEnv], + callback: MaybeCallback = None, + eval_freq: int = 10000, + n_eval_episodes: int = 5, + log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + tb_log_name: str = "run", + ) -> Tuple[int, BaseCallback]: + """ + Initialize different variables needed for training. + + :param total_timesteps: The total number of samples (env steps) to train on + :param eval_env: Environment to use for evaluation. + :param callback: Callback(s) called at every step with state of the algorithm. + :param eval_freq: How many steps between evaluations + :param n_eval_episodes: How many episodes to play per evaluation + :param log_path: Path to a folder where the evaluations will be saved + :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute + :param tb_log_name: the name of the run for tensorboard log + :return: + """ + self.start_time = time.time() + + if self.ep_info_buffer is None or reset_num_timesteps: + # Initialize buffers if they don't exist, or reinitialize if resetting counters + self.ep_info_buffer = deque(maxlen=100) + self.ep_success_buffer = deque(maxlen=100) + + if self.action_noise is not None: + self.action_noise.reset() + + if reset_num_timesteps: + self.num_timesteps = 0 + self._episode_num = 0 + else: + # Make sure training timesteps are ahead of the internal counter + total_timesteps += self.num_timesteps + self._total_timesteps = total_timesteps + self._num_timesteps_at_start = self.num_timesteps + + # Avoid resetting the environment when calling ``.learn()`` consecutive times + if reset_num_timesteps or self._last_obs is None: + self._last_obs = self.env.reset() # pytype: disable=annotation-type-mismatch + self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool) + # Retrieve unnormalized observation for saving into the buffer + if self._vec_normalize_env is not None: + self._last_original_obs = self._vec_normalize_env.get_original_obs() + + if eval_env is not None and self.seed is not None: + eval_env.seed(self.seed) + + eval_env = self._get_eval_env(eval_env) + + # Configure logger's outputs if no logger was passed + if not self._custom_logger: + self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps) + + # Create eval callback if needed + callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path) + + return total_timesteps, callback + + def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None: + """ + Retrieve reward, episode length, episode success and
update the buffer + if using Monitor wrapper or a GoalEnv. + + :param infos: List of additional information about the transition. + :param dones: Termination signals + """ + if dones is None: + dones = np.array([False] * len(infos)) + for idx, info in enumerate(infos): + maybe_ep_info = info.get("episode") + maybe_is_success = info.get("is_success") + if maybe_ep_info is not None: + self.ep_info_buffer.extend([maybe_ep_info]) + if maybe_is_success is not None and dones[idx]: + self.ep_success_buffer.append(maybe_is_success) + + def get_env(self) -> Optional[VecEnv]: + """ + Returns the current environment (can be None if not defined). + + :return: The current environment + """ + return self.env + + def get_vec_normalize_env(self) -> Optional[VecNormalize]: + """ + Return the ``VecNormalize`` wrapper of the training env + if it exists. + + :return: The ``VecNormalize`` env. + """ + return self._vec_normalize_env + + def set_env(self, env: GymEnv, force_reset: bool = True) -> None: + """ + Checks the validity of the environment, and if it is coherent, set it as the current environment. + Furthermore wrap any non vectorized env into a vectorized + checked parameters: + - observation_space + - action_space + + :param env: The environment for learning a policy + :param force_reset: Force call to ``reset()`` before training + to avoid unexpected behavior. + See issue https://github.com/DLR-RM/stable-baselines3/issues/597 + """ + # if it is not a VecEnv, make it a VecEnv + # and do other transformations (dict obs, image transpose) if needed + env = self._wrap_env(env, self.verbose) + # Check that the observation spaces match + check_for_correct_spaces(env, self.observation_space, self.action_space) + # Update VecNormalize object + # otherwise the wrong env may be used, see https://github.com/DLR-RM/stable-baselines3/issues/637 + self._vec_normalize_env = unwrap_vec_normalize(env) + + # Discard `_last_obs`, this will force the env to reset before training + # See issue https://github.com/DLR-RM/stable-baselines3/issues/597 + if force_reset: + self._last_obs = None + + self.n_envs = env.num_envs + self.env = env + + @abstractmethod + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 100, + tb_log_name: str = "run", + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> "BaseAlgorithm": + """ + Return a trained model. + + :param total_timesteps: The total number of samples (env steps) to train on + :param callback: callback(s) called at every step with state of the algorithm. + :param log_interval: The number of timesteps before logging. 
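A sketch of the `set_env` workflow described above: the new env must match the stored observation/action spaces, and `force_reset` (the default) discards the cached observation so the next `learn()` call starts from a fresh reset.

```python
# Swapping environments between learn() calls (CartPole-v1 as a stand-in).
import gym
from stable_baselines3.a2c import A2C

model = A2C("MlpPolicy", gym.make("CartPole-v1"))
model.learn(total_timesteps=1_000)
model.set_env(gym.make("CartPole-v1"))   # re-wrapped into a VecEnv internally
model.learn(total_timesteps=1_000, reset_num_timesteps=False)
```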
+ :param tb_log_name: the name of the run for TensorBoard logging + :param eval_env: Environment that will be used to evaluate the agent + :param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little) + :param n_eval_episodes: Number of episode to evaluate the agent + :param eval_log_path: Path to a folder where the evaluations will be saved + :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging) + :return: the trained model + """ + + def predict( + self, + observation: np.ndarray, + state: Optional[Tuple[np.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + """ + Get the policy action from an observation (and optional hidden state). + Includes sugar-coating to handle different observations (e.g. normalizing images). + + :param observation: the input observation + :param state: The last hidden states (can be None, used in recurrent policies) + :param episode_start: The last masks (can be None, used in recurrent policies) + this correspond to beginning of episodes, + where the hidden states of the RNN must be reset. + :param deterministic: Whether or not to return deterministic actions. + :return: the model's action and the next hidden state + (used in recurrent policies) + """ + return self.policy.predict(observation, state, episode_start, deterministic) + + def set_random_seed(self, seed: Optional[int] = None) -> None: + """ + Set the seed of the pseudo-random generators + (python, numpy, pytorch, gym, action_space) + + :param seed: + """ + if seed is None: + return + set_random_seed(seed, using_cuda=self.device.type == th.device("cuda").type) + self.action_space.seed(seed) + if self.env is not None: + self.env.seed(seed) + if self.eval_env is not None: + self.eval_env.seed(seed) + + def set_parameters( + self, + load_path_or_dict: Union[str, Dict[str, Dict]], + exact_match: bool = True, + device: Union[th.device, str] = "auto", + ) -> None: + """ + Load parameters from a given zip-file or a nested dictionary containing parameters for + different modules (see ``get_parameters``). + + :param load_path_or_iter: Location of the saved data (path or file-like, see ``save``), or a nested + dictionary containing nn.Module parameters used by the policy. The dictionary maps + object names to a state-dictionary returned by ``torch.nn.Module.state_dict()``. + :param exact_match: If True, the given parameters should include parameters for each + module and each of their parameters, otherwise raises an Exception. If set to False, this + can be used to update only specific parameters. + :param device: Device on which the code should run. + """ + params = None + if isinstance(load_path_or_dict, dict): + params = load_path_or_dict + else: + _, params, _ = load_from_zip_file(load_path_or_dict, device=device) + + # Keep track which objects were updated. + # `_get_torch_save_params` returns [params, other_pytorch_variables]. + # We are only interested in former here. + objects_needing_update = set(self._get_torch_save_params()[0]) + updated_objects = set() + + for name in params: + attr = None + try: + attr = recursive_getattr(self, name) + except Exception: + # What errors recursive_getattr could throw? KeyError, but + # possible something else too (e.g. if key is an int?). + # Catch anything for now. 
+ raise ValueError(f"Key {name} is an invalid object name.") + + if isinstance(attr, th.optim.Optimizer): + # Optimizers do not support "strict" keyword... + # Seems like they will just replace the whole + # optimizer state with the given one. + # On top of this, optimizer state-dict + # seems to change (e.g. first ``optim.step()``), + # which makes comparing state dictionary keys + # invalid (there is also a nesting of dictionaries + # with lists with dictionaries with ...), adding to the + # mess. + # + # TL;DR: We might not be able to reliably say + # if given state-dict is missing keys. + # + # Solution: Just load the state-dict as is, and trust + # the user has provided a sensible state dictionary. + attr.load_state_dict(params[name]) + else: + # Assume attr is th.nn.Module + attr.load_state_dict(params[name], strict=exact_match) + updated_objects.add(name) + + if exact_match and updated_objects != objects_needing_update: + raise ValueError( + "Names of parameters do not match agents' parameters: " + f"expected {objects_needing_update}, got {updated_objects}" + ) + + @classmethod + def load( + cls, + path: Union[str, pathlib.Path, io.BufferedIOBase], + env: Optional[GymEnv] = None, + device: Union[th.device, str] = "auto", + custom_objects: Optional[Dict[str, Any]] = None, + print_system_info: bool = False, + force_reset: bool = True, + check_obs_space: bool = True, + force_load: bool = False, + **kwargs, + ) -> "BaseAlgorithm": + """ + Load the model from a zip-file. + Warning: ``load`` re-creates the model from scratch, it does not update it in-place! + For an in-place load use ``set_parameters`` instead. + + :param path: path to the file (or a file-like) where to + load the agent from + :param env: the new environment to run the loaded model on + (can be None if you only need prediction from a trained model) has priority over any saved environment + :param device: Device on which the code should run. + :param custom_objects: Dictionary of objects to replace + upon loading. If a variable is present in this dictionary as a + key, it will not be deserialized and the corresponding item + will be used instead. Similar to custom_objects in + ``keras.models.load_model``. Useful when you have an object in + file that can not be deserialized. + :param print_system_info: Whether to print system info from the saved model + and the current system info (useful to debug loading issues) + :param force_reset: Force call to ``reset()`` before training + to avoid unexpected behavior. + See https://github.com/DLR-RM/stable-baselines3/issues/597 + :param kwargs: extra arguments to change the model when loading + :return: new model instance with loaded parameters + """ + if print_system_info: + print("== CURRENT SYSTEM INFO ==") + get_system_info() + + data, params, pytorch_variables = load_from_zip_file( + path, device=device, custom_objects=custom_objects, print_system_info=print_system_info + ) + + # Remove stored device information and replace with ours + if "policy_kwargs" in data: + if "device" in data["policy_kwargs"]: + del data["policy_kwargs"]["device"] + + if not force_load: + if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]: + raise ValueError( + f"The specified policy kwargs do not equal the stored policy kwargs." 
+ f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}" + ) + + if "observation_space" not in data or "action_space" not in data: + raise KeyError("The observation_space and action_space were not given, can't verify new environments") + + if env is not None: + # Wrap first if needed + env = cls._wrap_env(env, data["verbose"]) + # Check if given env is valid + if check_obs_space: + check_for_correct_spaces(env, data["observation_space"], data["action_space"]) + # Discard `_last_obs`, this will force the env to reset before training + # See issue https://github.com/DLR-RM/stable-baselines3/issues/597 + if force_reset and data is not None: + data["_last_obs"] = None + else: + # Use stored env, if one exists. If not, continue as is (can be used for predict) + if "env" in data: + env = data["env"] + + # noinspection PyArgumentList + model = cls( # pytype: disable=not-instantiable,wrong-keyword-args + policy=data["policy_class"], + env=env, + device=device, + _init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args + ) + + # load parameters + if not force_load: + model.__dict__.update(data) # TODO: Helin cancelled this + model.__dict__.update(kwargs) + model._setup_model() + + # put state_dicts back in place + model.set_parameters(params, exact_match=True, device=device) + + # put other pytorch variables back in place + if pytorch_variables is not None: + for name in pytorch_variables: + # Skip if PyTorch variable was not defined (to ensure backward compatibility). + # This happens when using SAC/TQC. + # SAC has an entropy coefficient which can be fixed or optimized. + # If it is optimized, an additional PyTorch variable `log_ent_coef` is defined, + # otherwise it is initialized to `None`. + if pytorch_variables[name] is None: + continue + # Set the data attribute directly to avoid issue when using optimizers + # See https://github.com/DLR-RM/stable-baselines3/issues/391 + recursive_setattr(model, name + ".data", pytorch_variables[name].data) + + # Sample gSDE exploration matrix, so it uses the right device + # see issue #44 + if model.use_sde: + model.policy.reset_noise() # pytype: disable=attribute-error + return model + + def get_parameters(self) -> Dict[str, Dict]: + """ + Return the parameters of the agent. This includes parameters from different networks, e.g. + critics (value functions) and policies (pi functions). + + :return: Mapping of from names of the objects to PyTorch state-dicts. + """ + state_dicts_names, _ = self._get_torch_save_params() + params = {} + for name in state_dicts_names: + attr = recursive_getattr(self, name) + # Retrieve state dict + params[name] = attr.state_dict() + return params + + def save( + self, + path: Union[str, pathlib.Path, io.BufferedIOBase], + exclude: Optional[Iterable[str]] = None, + include: Optional[Iterable[str]] = None, + ) -> None: + """ + Save all the attributes of the object and the model parameters in a zip-file. 
+ + :param path: path to the file where the rl agent should be saved + :param exclude: name of parameters that should be excluded in addition to the default ones + :param include: name of parameters that might be excluded but should be included anyway + """ + # Copy parameter list so we don't mutate the original dict + data = self.__dict__.copy() + + # Exclude is union of specified parameters (if any) and standard exclusions + if exclude is None: + exclude = [] + exclude = set(exclude).union(self._excluded_save_params()) + + # Do not exclude params if they are specifically included + if include is not None: + exclude = exclude.difference(include) + + state_dicts_names, torch_variable_names = self._get_torch_save_params() + all_pytorch_variables = state_dicts_names + torch_variable_names + for torch_var in all_pytorch_variables: + # We need to get only the name of the top most module as we'll remove that + var_name = torch_var.split(".")[0] + # Any params that are in the save vars must not be saved by data + exclude.add(var_name) + + # Remove parameter entries of parameters which are to be excluded + for param_name in exclude: + data.pop(param_name, None) + + # Build dict of torch variables + pytorch_variables = None + if torch_variable_names is not None: + pytorch_variables = {} + for name in torch_variable_names: + attr = recursive_getattr(self, name) + pytorch_variables[name] = attr + + # Build dict of state_dicts + params_to_save = self.get_parameters() + + save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables) diff --git a/dexart-release/stable_baselines3/common/buffers.py b/dexart-release/stable_baselines3/common/buffers.py new file mode 100644 index 0000000000000000000000000000000000000000..ff420f04a986cf7178d38d3b493125db9a3764e6 --- /dev/null +++ b/dexart-release/stable_baselines3/common/buffers.py @@ -0,0 +1,1010 @@ +import warnings +from abc import ABC, abstractmethod +from typing import Any, Dict, Generator, List, Optional, Union +import stable_baselines3.pickle_utils as pickle_utils +import numpy as np +import torch as th +from gym import spaces + +from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape +from stable_baselines3.common.type_aliases import ( + DictReplayBufferSamples, + DictRolloutBufferSamples, + DictSSLRolloutBufferSamples, + ReplayBufferSamples, + RolloutBufferSamples, +) + +from stable_baselines3.common.vec_env import VecNormalize + +try: + # Check memory used by replay buffer when possible + import psutil +except ImportError: + psutil = None + + +class BaseBuffer(ABC): + """ + Base class that represent a buffer (rollout or replay) + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: PyTorch device + to which the values will be converted + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "cpu", + n_envs: int = 1, + ): + super().__init__() + self.buffer_size = buffer_size + self.observation_space = observation_space + self.action_space = action_space + self.obs_shape = get_obs_shape(observation_space) + + self.action_dim = get_action_dim(action_space) + self.pos = 0 + self.full = False + self.device = device + self.n_envs = n_envs + + @staticmethod + def swap_and_flatten(arr: np.ndarray) -> np.ndarray: + """ + Swap and then flatten axes 0 
(buffer_size) and 1 (n_envs) + to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features) + to [n_steps * n_envs, ...] (which maintain the order) + + :param arr: + :return: + """ + shape = arr.shape + if len(shape) < 3: + shape = shape + (1,) + return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:]) + + def size(self) -> int: + """ + :return: The current size of the buffer + """ + if self.full: + return self.buffer_size + return self.pos + + def add(self, *args, **kwargs) -> None: + """ + Add elements to the buffer. + """ + raise NotImplementedError() + + def extend(self, *args, **kwargs) -> None: + """ + Add a new batch of transitions to the buffer + """ + # Do a for loop along the batch axis + for data in zip(*args): + self.add(*data) + + def reset(self) -> None: + """ + Reset the buffer. + """ + self.pos = 0 + self.full = False + + def sample(self, batch_size: int, env: Optional[VecNormalize] = None): + """ + :param batch_size: Number of element to sample + :param env: associated gym VecEnv + to normalize the observations/rewards when sampling + :return: + """ + upper_bound = self.buffer_size if self.full else self.pos + batch_inds = np.random.randint(0, upper_bound, size=batch_size) + return self._get_samples(batch_inds, env=env) + + @abstractmethod + def _get_samples( + self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None + ) -> Union[ReplayBufferSamples, RolloutBufferSamples]: + """ + :param batch_inds: + :param env: + :return: + """ + raise NotImplementedError() + + def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor: + """ + Convert a numpy array to a PyTorch tensor. + Note: it copies the data by default + + :param array: + :param copy: Whether to copy or not the data + (may be useful to avoid changing things be reference) + :return: + """ + if copy: + return th.tensor(array).to(self.device) + return th.as_tensor(array).to(self.device) + + @staticmethod + def _normalize_obs( + obs: Union[np.ndarray, Dict[str, np.ndarray]], + env: Optional[VecNormalize] = None, + ) -> Union[np.ndarray, Dict[str, np.ndarray]]: + if env is not None: + return env.normalize_obs(obs) + return obs + + @staticmethod + def _normalize_reward(reward: np.ndarray, env: Optional[VecNormalize] = None) -> np.ndarray: + if env is not None: + return env.normalize_reward(reward).astype(np.float32) + return reward + + +class ExpertBuffer(BaseBuffer): + def __init__(self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "cpu", + n_envs: int = 1, + dataset_path='' + ): + super(ExpertBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + data = pickle_utils.load_data(dataset_path) + + data_obs = [] + data_action = [] + + self.optimize_memory_usage = False + + # print(data[0]) + + for trajectory in data: + print(trajectory.keys()) + # for k, v in trajectory.items(): + data_obs.append(trajectory['observations']) + data_action.append(trajectory['actions']) + self.observations = np.concatenate(data_obs, axis=0) + self.actions = np.concatenate(data_action, axis=0) + + assert len(self.observations) == len(self.actions), "Demo Dataset Error: Obs num does not match Action num." 
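For intuition, the `swap_and_flatten` helper above is just a transpose of the step and env axes followed by a reshape; a minimal shape demo:

```python
# [n_steps, n_envs, ...] -> [n_steps * n_envs, ...]
import numpy as np

arr = np.zeros((5, 4, 3))                    # 5 steps, 4 envs, 3 features
flat = arr.swapaxes(0, 1).reshape(5 * 4, 3)
print(flat.shape)                            # (20, 3)
```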
+ print('Expert buffer info:', self.observations.shape, self.actions.shape) + self.buffer_size = len(self.observations) + self.full = True + + def add( + self, + obs: np.ndarray, + next_obs: np.ndarray, + action: np.ndarray, + reward: np.ndarray, + done: np.ndarray, + infos: List[Dict[str, Any]], + ) -> None: + assert False, "We do not expect user to use this method." + + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: + # Sample randomly the env idx + data = ( + self._normalize_obs(self.observations[batch_inds, :], env), + self.actions[batch_inds, :], + ) + return ReplayBufferSamples(*tuple(map(self.to_torch, data))) + + def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: + """ + Sample elements from the replay buffer. + Custom sampling when using memory efficient variant, + as we should not sample the element with index `self.pos` + See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274 + + :param batch_size: Number of element to sample + :param env: associated gym VecEnv + to normalize the observations/rewards when sampling + :return: + """ + if not self.optimize_memory_usage: + return super().sample(batch_size=batch_size, env=env) + # Do not sample the element with index `self.pos` as the transitions is invalid + # (we use only one array to store `obs` and `next_obs`) + if self.full: + batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size + else: + batch_inds = np.random.randint(0, self.pos, size=batch_size) + return self._get_samples(batch_inds, env=env) + + def get_all_samples(self, env=None): + data = ( + self._normalize_obs(self.observations, env), + self.actions, + ) + return ReplayBufferSamples(*tuple(map(self.to_torch, data)), None, None, None) + + +class ReplayBuffer(BaseBuffer): + """ + Replay buffer used in off-policy algorithms like SAC/TD3. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param n_envs: Number of parallel environments + :param optimize_memory_usage: Enable a memory efficient variant + of the replay buffer which reduces by almost a factor two the memory used, + at a cost of more complexity. + See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 + and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274 + :param handle_timeout_termination: Handle timeout termination (due to timelimit) + separately and treat the task as infinite horizon task. 
+ https://github.com/DLR-RM/stable-baselines3/issues/284 + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "cpu", + n_envs: int = 1, + optimize_memory_usage: bool = False, + handle_timeout_termination: bool = True, + ): + super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + + # Adjust buffer size + self.buffer_size = max(buffer_size // n_envs, 1) + + # Check that the replay buffer can fit into the memory + if psutil is not None: + mem_available = psutil.virtual_memory().available + + self.optimize_memory_usage = optimize_memory_usage + + self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype) + + if optimize_memory_usage: + # `observations` contains also the next observation + self.next_observations = None + else: + self.next_observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype) + + self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype) + + self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + # Handle timeouts termination properly if needed + # see https://github.com/DLR-RM/stable-baselines3/issues/284 + self.handle_timeout_termination = handle_timeout_termination + self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + + if psutil is not None: + total_memory_usage = self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes + + if self.next_observations is not None: + total_memory_usage += self.next_observations.nbytes + + if total_memory_usage > mem_available: + # Convert to GB + total_memory_usage /= 1e9 + mem_available /= 1e9 + warnings.warn( + "This system does not have apparently enough memory to store the complete " + f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB" + ) + + def add( + self, + obs: np.ndarray, + next_obs: np.ndarray, + action: np.ndarray, + reward: np.ndarray, + done: np.ndarray, + infos: List[Dict[str, Any]], + ) -> None: + + # Reshape needed when using multiple envs with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space, spaces.Discrete): + obs = obs.reshape((self.n_envs,) + self.obs_shape) + next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape) + + # Same, for actions + if isinstance(self.action_space, spaces.Discrete): + action = action.reshape((self.n_envs, self.action_dim)) + + # Copy to avoid modification by reference + self.observations[self.pos] = np.array(obs).copy() + + if self.optimize_memory_usage: + self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs).copy() + else: + self.next_observations[self.pos] = np.array(next_obs).copy() + + self.actions[self.pos] = np.array(action).copy() + self.rewards[self.pos] = np.array(reward).copy() + self.dones[self.pos] = np.array(done).copy() + + if self.handle_timeout_termination: + self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos]) + + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + self.pos = 0 + + def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: + """ + Sample elements from the replay buffer. 
+ Custom sampling when using memory efficient variant, + as we should not sample the element with index `self.pos` + See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274 + + :param batch_size: Number of element to sample + :param env: associated gym VecEnv + to normalize the observations/rewards when sampling + :return: + """ + if not self.optimize_memory_usage: + return super().sample(batch_size=batch_size, env=env) + # Do not sample the element with index `self.pos` as the transitions is invalid + # (we use only one array to store `obs` and `next_obs`) + if self.full: + batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size + else: + batch_inds = np.random.randint(0, self.pos, size=batch_size) + return self._get_samples(batch_inds, env=env) + + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: + # Sample randomly the env idx + env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),)) + + if self.optimize_memory_usage: + next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :], env) + else: + next_obs = self._normalize_obs(self.next_observations[batch_inds, env_indices, :], env) + + data = ( + self._normalize_obs(self.observations[batch_inds, env_indices, :], env), + self.actions[batch_inds, env_indices, :], + next_obs, + # Only use dones that are not due to timeouts + # deactivated by default (timeouts is initialized as an array of False) + (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1), + self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env), + ) + return ReplayBufferSamples(*tuple(map(self.to_torch, data))) + + +class RolloutBuffer(BaseBuffer): + """ + Rollout buffer used in on-policy algorithms like A2C/PPO. + It corresponds to ``buffer_size`` transitions collected + using the current policy. + This experience will be discarded after the policy update. + In order to use PPO objective, we also store the current value of each state + and the log probability of each taken action. + + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + Hence, it is only involved in policy and value function training but not action selection. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. 
+    :param gamma: Discount factor
+    :param n_envs: Number of parallel environments
+    """
+
+    def __init__(
+        self,
+        buffer_size: int,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
+        device: Union[th.device, str] = "cpu",
+        gae_lambda: float = 1,
+        gamma: float = 0.99,
+        n_envs: int = 1,
+    ):
+
+        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
+        self.gae_lambda = gae_lambda
+        self.gamma = gamma
+        self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
+        self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
+        self.generator_ready = False
+        self.reset()
+
+    def reset(self) -> None:
+
+        self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
+        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
+        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+        self.generator_ready = False
+        super().reset()
+
+    def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
+        """
+        Post-processing step: compute the lambda-return (TD(lambda) estimate)
+        and GAE(lambda) advantage.
+
+        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
+        to compute the advantage. To obtain the Monte-Carlo advantage estimate (A(s) = R - V(s)),
+        where R is the discounted sum of rewards with value bootstrap
+        (because we don't always have a full episode), set ``gae_lambda=1.0`` during initialization.
+
+        The TD(lambda) estimator also has two special cases:
+        - TD(1) is the Monte-Carlo estimate (sum of discounted rewards)
+        - TD(0) is the one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))
+
+        For more information, see the discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.
+
+        :param last_values: state value estimation for the last step (one for each env)
+        :param dones: if the last step was a terminal step (one bool for each env).
+        """
+        # Convert to numpy
+        last_values = last_values.clone().cpu().numpy().flatten()
+
+        last_gae_lam = 0
+        for step in reversed(range(self.buffer_size)):
+            if step == self.buffer_size - 1:
+                next_non_terminal = 1.0 - dones
+                next_values = last_values
+            else:
+                next_non_terminal = 1.0 - self.episode_starts[step + 1]
+                next_values = self.values[step + 1]
+            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
+            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
+            self.advantages[step] = last_gae_lam
+        # TD(lambda) estimator, see GitHub PR #375 or "Telescoping in TD(lambda)"
+        # in David Silver's Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
+        self.returns = self.advantages + self.values
+
+    def add(
+        self,
+        obs: np.ndarray,
+        action: np.ndarray,
+        reward: np.ndarray,
+        episode_start: np.ndarray,
+        value: th.Tensor,
+        log_prob: th.Tensor,
+    ) -> None:
+        """
+        :param obs: Observation
+        :param action: Action
+        :param reward:
+        :param episode_start: Start of episode signal.
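# Editorial sketch: the GAE recursion above as a standalone numpy function for
# a single environment, with episode-termination masking omitted for brevity
# (inputs are hypothetical).
import numpy as np

def gae(rewards, values, last_value, gamma=0.99, gae_lambda=0.95):
    advantages = np.zeros_like(rewards)
    last_gae_lam = 0.0
    for step in reversed(range(len(rewards))):
        next_value = last_value if step == len(rewards) - 1 else values[step + 1]
        delta = rewards[step] + gamma * next_value - values[step]  # one-step TD error
        last_gae_lam = delta + gamma * gae_lambda * last_gae_lam
        advantages[step] = last_gae_lam
    # TD(lambda) returns are recovered exactly as in the buffer: A + V
    return advantages, advantages + values

adv, ret = gae(np.array([1.0, 1.0, 1.0]), np.array([0.5, 0.5, 0.5]), last_value=0.5)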
+ :param value: estimated value of the current state + following the current policy. + :param log_prob: log probability of the action + following the current policy. + """ + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + # Reshape needed when using multiple envs with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space, spaces.Discrete): + obs = obs.reshape((self.n_envs,) + self.obs_shape) + + self.observations[self.pos] = np.array(obs).copy() + self.actions[self.pos] = np.array(action).copy() + self.rewards[self.pos] = np.array(reward).copy() + self.episode_starts[self.pos] = np.array(episode_start).copy() + self.values[self.pos] = value.clone().cpu().numpy().flatten() + self.log_probs[self.pos] = log_prob.clone().cpu().numpy() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + + def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + + _tensor_names = [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + ] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples: + data = ( + self.observations[batch_inds], + self.actions[batch_inds], + self.values[batch_inds].flatten(), + self.log_probs[batch_inds].flatten(), + self.advantages[batch_inds].flatten(), + self.returns[batch_inds].flatten(), + ) + return RolloutBufferSamples(*tuple(map(self.to_torch, data))) + + +class DictReplayBuffer(ReplayBuffer): + """ + Dict Replay buffer used in off-policy algorithms like SAC/TD3. + Extends the ReplayBuffer to use dictionary observations + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param n_envs: Number of parallel environments + :param optimize_memory_usage: Enable a memory efficient variant + Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702) + :param handle_timeout_termination: Handle timeout termination (due to timelimit) + separately and treat the task as infinite horizon task. 
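# Editorial sketch (numpy only): the ``swap_and_flatten`` helper used by
# ``get`` above turns (n_steps, n_envs, *shape) storage into
# (n_steps * n_envs, *shape) so minibatch indices address single transitions.
import numpy as np

n_steps, n_envs, obs_dim = 3, 2, 4
arr = np.arange(n_steps * n_envs * obs_dim).reshape(n_steps, n_envs, obs_dim)
flat = arr.swapaxes(0, 1).reshape(n_steps * n_envs, obs_dim)
assert flat.shape == (6, 4)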
+    https://github.com/DLR-RM/stable-baselines3/issues/284
+    """
+
+    def __init__(
+        self,
+        buffer_size: int,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
+        device: Union[th.device, str] = "cpu",
+        n_envs: int = 1,
+        optimize_memory_usage: bool = False,
+        handle_timeout_termination: bool = True,
+    ):
+        super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
+
+        assert isinstance(self.obs_shape, dict), "DictReplayBuffer must be used with Dict obs space only"
+        self.buffer_size = max(buffer_size // n_envs, 1)
+
+        # Check that the replay buffer can fit into the memory
+        if psutil is not None:
+            mem_available = psutil.virtual_memory().available
+
+        assert optimize_memory_usage is False, "DictReplayBuffer does not support optimize_memory_usage"
+        # disabling as this adds quite a bit of complexity
+        # https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702
+        self.optimize_memory_usage = optimize_memory_usage
+
+        self.observations = {
+            key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape, dtype=observation_space[key].dtype)
+            for key, _obs_shape in self.obs_shape.items()
+        }
+        self.next_observations = {
+            key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape, dtype=observation_space[key].dtype)
+            for key, _obs_shape in self.obs_shape.items()
+        }
+
+        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
+        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+
+        # Handle timeout termination properly if needed
+        # see https://github.com/DLR-RM/stable-baselines3/issues/284
+        self.handle_timeout_termination = handle_timeout_termination
+        self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+
+        if psutil is not None:
+            obs_nbytes = 0
+            for _, obs in self.observations.items():
+                obs_nbytes += obs.nbytes
+
+            total_memory_usage = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
+            if self.next_observations is not None:
+                next_obs_nbytes = 0
+                # Count the next observations as well (same shapes as the observations)
+                for _, obs in self.next_observations.items():
+                    next_obs_nbytes += obs.nbytes
+                total_memory_usage += next_obs_nbytes
+
+            if total_memory_usage > mem_available:
+                # Convert to GB
+                total_memory_usage /= 1e9
+                mem_available /= 1e9
+                warnings.warn(
+                    "This system does not appear to have enough memory to store the complete "
+                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
+                )
+
+    def add(
+        self,
+        obs: Dict[str, np.ndarray],
+        next_obs: Dict[str, np.ndarray],
+        action: np.ndarray,
+        reward: np.ndarray,
+        done: np.ndarray,
+        infos: List[Dict[str, Any]],
+    ) -> None:
+        # Copy to avoid modification by reference
+        for key in self.observations.keys():
+            # Reshape needed when using multiple envs with discrete observations
+            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
+            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
+                obs[key] = obs[key].reshape((self.n_envs,) + self.obs_shape[key])
+            self.observations[key][self.pos] = np.array(obs[key])
+
+        for key in self.next_observations.keys():
+            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
+                next_obs[key] = next_obs[key].reshape((self.n_envs,) + self.obs_shape[key])
+            self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()
+
+        # Same reshape, for actions
+        if isinstance(self.action_space, spaces.Discrete):
+            action = action.reshape((self.n_envs,
self.action_dim)) + + self.actions[self.pos] = np.array(action).copy() + self.rewards[self.pos] = np.array(reward).copy() + self.dones[self.pos] = np.array(done).copy() + + if self.handle_timeout_termination: + self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos]) + + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + self.pos = 0 + + def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: + """ + Sample elements from the replay buffer. + + :param batch_size: Number of element to sample + :param env: associated gym VecEnv + to normalize the observations/rewards when sampling + :return: + """ + return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env) + + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: + # Sample randomly the env idx + env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),)) + + # Normalize if needed and remove extra dimension (we are using only one env for now) + obs_ = self._normalize_obs({key: obs[batch_inds, env_indices, :] for key, obs in self.observations.items()}, env) + next_obs_ = self._normalize_obs( + {key: obs[batch_inds, env_indices, :] for key, obs in self.next_observations.items()}, env + ) + + # Convert to torch tensor + observations = {key: self.to_torch(obs) for key, obs in obs_.items()} + next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()} + + return DictReplayBufferSamples( + observations=observations, + actions=self.to_torch(self.actions[batch_inds, env_indices]), + next_observations=next_observations, + # Only use dones that are not due to timeouts + # deactivated by default (timeouts is initialized as an array of False) + dones=self.to_torch(self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape( + -1, 1 + ), + rewards=self.to_torch(self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env)), + ) + + +class DictRolloutBuffer(RolloutBuffer): + """ + Dict Rollout buffer used in on-policy algorithms like A2C/PPO. + Extends the RolloutBuffer to use dictionary observations + + It corresponds to ``buffer_size`` transitions collected + using the current policy. + This experience will be discarded after the policy update. + In order to use PPO objective, we also store the current value of each state + and the log probability of each taken action. + + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + Hence, it is only involved in policy and value function training but not action selection. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to Monte-Carlo advantage estimate when set to 1. 
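# Editorial usage sketch, assuming ``gym`` is installed and this vendored
# package is importable as ``stable_baselines3`` (the module path is an
# assumption): filling and sampling the dict replay buffer.
import numpy as np
from gym import spaces
from stable_baselines3.common.buffers import DictReplayBuffer

obs_space = spaces.Dict(
    {"image": spaces.Box(0, 1, shape=(4,)), "state": spaces.Box(-1, 1, shape=(2,))}
)
act_space = spaces.Box(-1, 1, shape=(2,))
buffer = DictReplayBuffer(100, obs_space, act_space, n_envs=1)
obs = {"image": np.zeros((1, 4)), "state": np.zeros((1, 2))}
buffer.add(obs, obs, np.zeros((1, 2)), np.zeros(1), np.zeros(1), infos=[{}])
batch = buffer.sample(batch_size=1)  # DictReplayBufferSamples of torch tensors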
+ :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "cpu", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + + super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + + assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" + + self.gae_lambda = gae_lambda + self.gamma = gamma + self.observations, self.actions, self.rewards, self.advantages = None, None, None, None + self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None + self.generator_ready = False + self.reset() + + def reset(self) -> None: + assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" + self.observations = {} + for key, obs_input_shape in self.obs_shape.items(): + self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32) + self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) + self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.generator_ready = False + super(RolloutBuffer, self).reset() + + def add( + self, + obs: Dict[str, np.ndarray], + action: np.ndarray, + reward: np.ndarray, + episode_start: np.ndarray, + value: th.Tensor, + log_prob: th.Tensor, + ) -> None: + """ + :param obs: Observation + :param action: Action + :param reward: + :param episode_start: Start of episode signal. + :param value: estimated value of the current state + following the current policy. + :param log_prob: log probability of the action + following the current policy. 
+ """ + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + for key in self.observations.keys(): + obs_ = np.array(obs[key]).copy() + # Reshape needed when using multiple envs with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space.spaces[key], spaces.Discrete): + obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key]) + self.observations[key][self.pos] = obs_ + + self.actions[self.pos] = np.array(action).copy() + self.rewards[self.pos] = np.array(reward).copy() + self.episode_starts[self.pos] = np.array(episode_start).copy() + self.values[self.pos] = value.clone().cpu().numpy().flatten() + self.log_probs[self.pos] = log_prob.clone().cpu().numpy() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + + def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + + for key, obs in self.observations.items(): + self.observations[key] = self.swap_and_flatten(obs) + + _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictRolloutBufferSamples: + + return DictRolloutBufferSamples( + observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, + actions=self.to_torch(self.actions[batch_inds]), + old_values=self.to_torch(self.values[batch_inds].flatten()), + old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), + advantages=self.to_torch(self.advantages[batch_inds].flatten()), + returns=self.to_torch(self.returns[batch_inds].flatten()), + ) + + +class DictSSLRolloutBuffer(RolloutBuffer): + """ + Dict Rollout buffer used in on-policy algorithms like A2C/PPO. + Extends the RolloutBuffer to use dictionary observations + + It corresponds to ``buffer_size`` transitions collected + using the current policy. + This experience will be discarded after the policy update. + In order to use PPO objective, we also store the current value of each state + and the log probability of each taken action. + + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + Hence, it is only involved in policy and value function training but not action selection. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to Monte-Carlo advantage estimate when set to 1. 
+ :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "cpu", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + + super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + + assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" + + self.gae_lambda = gae_lambda + self.gamma = gamma + self.observations, self.actions, self.rewards, self.advantages = None, None, None, None + self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None + self.generator_ready = False + self.reset() + + def reset(self) -> None: + assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" + self.observations = {} + for key, obs_input_shape in self.obs_shape.items(): + self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32) + self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) + self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.generator_ready = False + super(RolloutBuffer, self).reset() + + def add( + self, + obs: Dict[str, np.ndarray], + action: np.ndarray, + reward: np.ndarray, + episode_start: np.ndarray, + value: th.Tensor, + log_prob: th.Tensor, + ) -> None: + """ + :param obs: Observation + :param action: Action + :param reward: + :param episode_start: Start of episode signal. + :param value: estimated value of the current state + following the current policy. + :param log_prob: log probability of the action + following the current policy. 
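# Editorial sketch (numpy only) of the lookahead indexing used by
# ``_get_samples`` below: future indices are clipped at the buffer end, so the
# last transitions repeat their own data rather than indexing out of bounds.
import numpy as np

flat_len = 10
batch_inds = np.array([0, 7, 9])
lookahead = [np.clip(batch_inds + i, 0, flat_len - 1) for i in range(4)]
# lookahead[3] -> array([3, 9, 9]): index 9 cannot look past the buffer end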
+ """ + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + for key in self.observations.keys(): + obs_ = np.array(obs[key]).copy() + # Reshape needed when using multiple envs with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space.spaces[key], spaces.Discrete): + obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key]) + self.observations[key][self.pos] = obs_ + + self.actions[self.pos] = np.array(action).copy() + self.rewards[self.pos] = np.array(reward).copy() + self.episode_starts[self.pos] = np.array(episode_start).copy() + self.values[self.pos] = value.clone().cpu().numpy().flatten() + self.log_probs[self.pos] = log_prob.clone().cpu().numpy() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + + def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + + for key, obs in self.observations.items(): + self.observations[key] = self.swap_and_flatten(obs) + + _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictSSLRolloutBufferSamples: + + def get_next_obs(obs, ind, i): + result = {} + for key, obs in obs.items(): + future_batch_inds = np.clip(ind + i, 0, len(obs)-1) + next_obs = self.to_torch(obs[future_batch_inds]) + result[key] = next_obs + + return result + + def get_next_action(actions, ind, i): + future_batch_inds = np.clip(ind + i, 0, len(actions) - 1) + return self.to_torch(actions[future_batch_inds]) + + + next_observations = [get_next_obs(self.observations, batch_inds, i) for i in range(4)] + next_actions = [get_next_action(self.actions, batch_inds, i) for i in range(4)] + + return DictSSLRolloutBufferSamples( + observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, + next_observations=next_observations, + actions=self.to_torch(self.actions[batch_inds]), + next_actions=next_actions, + old_values=self.to_torch(self.values[batch_inds].flatten()), + old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), + advantages=self.to_torch(self.advantages[batch_inds].flatten()), + returns=self.to_torch(self.returns[batch_inds].flatten()), + ) + diff --git a/dexart-release/stable_baselines3/common/callbacks.py b/dexart-release/stable_baselines3/common/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..c5f297c7c7b11e00e4be1e03ac7ae4ede0b6ee45 --- /dev/null +++ b/dexart-release/stable_baselines3/common/callbacks.py @@ -0,0 +1,602 @@ +import os +import warnings +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Union + +import gym +import numpy as np + +from stable_baselines3.common import base_class # pytype: disable=pyi-error +from stable_baselines3.common.evaluation import evaluate_policy +from 
stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization + + +class BaseCallback(ABC): + """ + Base class for callback. + + :param verbose: + """ + + def __init__(self, verbose: int = 0): + super().__init__() + # The RL model + self.model = None # type: Optional[base_class.BaseAlgorithm] + # An alias for self.model.get_env(), the environment used for training + self.training_env = None # type: Union[gym.Env, VecEnv, None] + # Number of time the callback was called + self.n_calls = 0 # type: int + # n_envs * n times env.step() was called + self.num_timesteps = 0 # type: int + self.verbose = verbose + self.locals: Dict[str, Any] = {} + self.globals: Dict[str, Any] = {} + self.logger = None + # Sometimes, for event callback, it is useful + # to have access to the parent object + self.parent = None # type: Optional[BaseCallback] + + # Type hint as string to avoid circular import + def init_callback(self, model: "base_class.BaseAlgorithm") -> None: + """ + Initialize the callback by saving references to the + RL model and the training environment for convenience. + """ + self.model = model + self.training_env = model.get_env() + self.logger = model.logger + self._init_callback() + + def _init_callback(self) -> None: + pass + + def on_training_start(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None: + # Those are reference and will be updated automatically + self.locals = locals_ + self.globals = globals_ + self._on_training_start() + + def _on_training_start(self) -> None: + pass + + def on_rollout_start(self) -> None: + self._on_rollout_start() + + def _on_rollout_start(self) -> None: + pass + + @abstractmethod + def _on_step(self) -> bool: + """ + :return: If the callback returns False, training is aborted early. + """ + return True + + def on_step(self) -> bool: + """ + This method will be called by the model after each call to ``env.step()``. + + For child callback (of an ``EventCallback``), this will be called + when the event is triggered. + + :return: If the callback returns False, training is aborted early. + """ + self.n_calls += 1 + # timesteps start at zero + self.num_timesteps = self.model.num_timesteps + + return self._on_step() + + def on_training_end(self) -> None: + self._on_training_end() + + def _on_training_end(self) -> None: + pass + + def on_rollout_end(self) -> None: + self._on_rollout_end() + + def _on_rollout_end(self) -> None: + pass + + def update_locals(self, locals_: Dict[str, Any]) -> None: + """ + Update the references to the local variables. + + :param locals_: the local variables during rollout collection + """ + self.locals.update(locals_) + self.update_child_locals(locals_) + + def update_child_locals(self, locals_: Dict[str, Any]) -> None: + """ + Update the references to the local variables on sub callbacks. + + :param locals_: the local variables during rollout collection + """ + pass + + +class EventCallback(BaseCallback): + """ + Base class for triggering callback on event. + + :param callback: Callback that will be called + when an event is triggered. 
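# Editorial usage sketch: only ``_on_step`` is abstract, so a minimal custom
# callback is a few lines (class name and threshold are hypothetical).
from stable_baselines3.common.callbacks import BaseCallback

class StopAfterNSteps(BaseCallback):
    def __init__(self, max_steps: int, verbose: int = 0):
        super().__init__(verbose)
        self.max_steps = max_steps

    def _on_step(self) -> bool:
        # ``num_timesteps`` is synced with the model by ``on_step``;
        # returning False aborts training early
        return self.num_timesteps < self.max_steps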
+ :param verbose: + """ + + def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0): + super().__init__(verbose=verbose) + self.callback = callback + # Give access to the parent + if callback is not None: + self.callback.parent = self + + def init_callback(self, model: "base_class.BaseAlgorithm") -> None: + super().init_callback(model) + if self.callback is not None: + self.callback.init_callback(self.model) + + def _on_training_start(self) -> None: + if self.callback is not None: + self.callback.on_training_start(self.locals, self.globals) + + def _on_event(self) -> bool: + if self.callback is not None: + return self.callback.on_step() + return True + + def _on_step(self) -> bool: + return True + + def update_child_locals(self, locals_: Dict[str, Any]) -> None: + """ + Update the references to the local variables. + + :param locals_: the local variables during rollout collection + """ + if self.callback is not None: + self.callback.update_locals(locals_) + + +class CallbackList(BaseCallback): + """ + Class for chaining callbacks. + + :param callbacks: A list of callbacks that will be called + sequentially. + """ + + def __init__(self, callbacks: List[BaseCallback]): + super().__init__() + assert isinstance(callbacks, list) + self.callbacks = callbacks + + def _init_callback(self) -> None: + for callback in self.callbacks: + callback.init_callback(self.model) + + def _on_training_start(self) -> None: + for callback in self.callbacks: + callback.on_training_start(self.locals, self.globals) + + def _on_rollout_start(self) -> None: + for callback in self.callbacks: + callback.on_rollout_start() + + def _on_step(self) -> bool: + continue_training = True + for callback in self.callbacks: + # Return False (stop training) if at least one callback returns False + continue_training = callback.on_step() and continue_training + return continue_training + + def _on_rollout_end(self) -> None: + for callback in self.callbacks: + callback.on_rollout_end() + + def _on_training_end(self) -> None: + for callback in self.callbacks: + callback.on_training_end() + + def update_child_locals(self, locals_: Dict[str, Any]) -> None: + """ + Update the references to the local variables. + + :param locals_: the local variables during rollout collection + """ + for callback in self.callbacks: + callback.update_locals(locals_) + + +class CheckpointCallback(BaseCallback): + """ + Callback for saving a model every ``save_freq`` calls + to ``env.step()``. + + .. warning:: + + When using multiple environments, each call to ``env.step()`` + will effectively correspond to ``n_envs`` steps. + To account for that, you can use ``save_freq = max(save_freq // n_envs, 1)`` + + :param save_freq: + :param save_path: Path to the folder where the model will be saved. 
+ :param name_prefix: Common prefix to the saved models + :param verbose: + """ + + def __init__(self, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0): + super().__init__(verbose) + self.save_freq = save_freq + self.save_path = save_path + self.name_prefix = name_prefix + + def _init_callback(self) -> None: + # Create folder if needed + if self.save_path is not None: + os.makedirs(self.save_path, exist_ok=True) + + def _on_step(self) -> bool: + if self.n_calls % self.save_freq == 0: + path = os.path.join(self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps") + self.model.save(path) + if self.verbose > 1: + print(f"Saving model checkpoint to {path}") + return True + + +class ConvertCallback(BaseCallback): + """ + Convert functional callback (old-style) to object. + + :param callback: + :param verbose: + """ + + def __init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0): + super().__init__(verbose) + self.callback = callback + + def _on_step(self) -> bool: + if self.callback is not None: + return self.callback(self.locals, self.globals) + return True + + +class EvalCallback(EventCallback): + """ + Callback for evaluating an agent. + + .. warning:: + + When using multiple environments, each call to ``env.step()`` + will effectively correspond to ``n_envs`` steps. + To account for that, you can use ``eval_freq = max(eval_freq // n_envs, 1)`` + + :param eval_env: The environment used for initialization + :param callback_on_new_best: Callback to trigger + when there is a new best model according to the ``mean_reward`` + :param callback_after_eval: Callback to trigger after every evaluation + :param n_eval_episodes: The number of episodes to test the agent + :param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback. + :param log_path: Path to a folder where the evaluations (``evaluations.npz``) + will be saved. It will be updated at each evaluation. + :param best_model_save_path: Path to a folder where the best model + according to performance on the eval env will be saved. + :param deterministic: Whether the evaluation should + use a stochastic or deterministic actions. 
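# Editorial usage sketch following the ``n_envs`` warnings above (``model``,
# ``eval_env`` and the paths are assumptions):
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback

n_envs = 8
callback = CallbackList(
    [
        CheckpointCallback(save_freq=max(10_000 // n_envs, 1), save_path="./ckpts/"),
        EvalCallback(eval_env, eval_freq=max(5_000 // n_envs, 1), best_model_save_path="./best/"),
    ]
)
model.learn(total_timesteps=1_000_000, callback=callback)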
+ :param render: Whether to render or not the environment during evaluation + :param verbose: + :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been + wrapped with a Monitor wrapper) + """ + + def __init__( + self, + eval_env: Union[gym.Env, VecEnv], + callback_on_new_best: Optional[BaseCallback] = None, + callback_after_eval: Optional[BaseCallback] = None, + n_eval_episodes: int = 5, + eval_freq: int = 10000, + log_path: Optional[str] = None, + best_model_save_path: Optional[str] = None, + deterministic: bool = True, + render: bool = False, + verbose: int = 1, + warn: bool = True, + ): + super().__init__(callback_after_eval, verbose=verbose) + + self.callback_on_new_best = callback_on_new_best + if self.callback_on_new_best is not None: + # Give access to the parent + self.callback_on_new_best.parent = self + + self.n_eval_episodes = n_eval_episodes + self.eval_freq = eval_freq + self.best_mean_reward = -np.inf + self.last_mean_reward = -np.inf + self.deterministic = deterministic + self.render = render + self.warn = warn + + # Convert to VecEnv for consistency + if not isinstance(eval_env, VecEnv): + eval_env = DummyVecEnv([lambda: eval_env]) + + self.eval_env = eval_env + self.best_model_save_path = best_model_save_path + # Logs will be written in ``evaluations.npz`` + if log_path is not None: + log_path = os.path.join(log_path, "evaluations") + self.log_path = log_path + self.evaluations_results = [] + self.evaluations_timesteps = [] + self.evaluations_length = [] + # For computing success rate + self._is_success_buffer = [] + self.evaluations_successes = [] + + def _init_callback(self) -> None: + # Does not work in some corner cases, where the wrapper is not the same + if not isinstance(self.training_env, type(self.eval_env)): + warnings.warn("Training and eval env are not of the same type" f"{self.training_env} != {self.eval_env}") + + # Create folders if needed + if self.best_model_save_path is not None: + os.makedirs(self.best_model_save_path, exist_ok=True) + if self.log_path is not None: + os.makedirs(os.path.dirname(self.log_path), exist_ok=True) + + # Init callback called on new best model + if self.callback_on_new_best is not None: + self.callback_on_new_best.init_callback(self.model) + + def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None: + """ + Callback passed to the ``evaluate_policy`` function + in order to log the success rate (when applicable), + for instance when using HER. + + :param locals_: + :param globals_: + """ + info = locals_["info"] + + if locals_["done"]: + maybe_is_success = info.get("is_success") + if maybe_is_success is not None: + self._is_success_buffer.append(maybe_is_success) + + def _on_step(self) -> bool: + + continue_training = True + + if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: + + # Sync training and eval env if there is VecNormalize + if self.model.get_vec_normalize_env() is not None: + try: + sync_envs_normalization(self.training_env, self.eval_env) + except AttributeError: + raise AssertionError( + "Training and eval env are not wrapped the same way, " + "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback " + "and warning above." 
+ ) + + # Reset success rate buffer + self._is_success_buffer = [] + + episode_rewards, episode_lengths = evaluate_policy( + self.model, + self.eval_env, + n_eval_episodes=self.n_eval_episodes, + render=self.render, + deterministic=self.deterministic, + return_episode_rewards=True, + warn=self.warn, + callback=self._log_success_callback, + ) + + if self.log_path is not None: + self.evaluations_timesteps.append(self.num_timesteps) + self.evaluations_results.append(episode_rewards) + self.evaluations_length.append(episode_lengths) + + kwargs = {} + # Save success log if present + if len(self._is_success_buffer) > 0: + self.evaluations_successes.append(self._is_success_buffer) + kwargs = dict(successes=self.evaluations_successes) + + np.savez( + self.log_path, + timesteps=self.evaluations_timesteps, + results=self.evaluations_results, + ep_lengths=self.evaluations_length, + **kwargs, + ) + + mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards) + mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths) + self.last_mean_reward = mean_reward + + if self.verbose > 0: + print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}") + print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}") + # Add to current Logger + self.logger.record("eval/mean_reward", float(mean_reward)) + self.logger.record("eval/mean_ep_length", mean_ep_length) + + if len(self._is_success_buffer) > 0: + success_rate = np.mean(self._is_success_buffer) + if self.verbose > 0: + print(f"Success rate: {100 * success_rate:.2f}%") + self.logger.record("eval/success_rate", success_rate) + + # Dump log so the evaluation results are printed with the correct timestep + self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") + self.logger.dump(self.num_timesteps) + + if mean_reward > self.best_mean_reward: + if self.verbose > 0: + print("New best mean reward!") + if self.best_model_save_path is not None: + self.model.save(os.path.join(self.best_model_save_path, "best_model")) + self.best_mean_reward = mean_reward + # Trigger callback on new best model, if needed + if self.callback_on_new_best is not None: + continue_training = self.callback_on_new_best.on_step() + + # Trigger callback after every evaluation, if needed + if self.callback is not None: + continue_training = continue_training and self._on_event() + + return continue_training + + def update_child_locals(self, locals_: Dict[str, Any]) -> None: + """ + Update the references to the local variables. + + :param locals_: the local variables during rollout collection + """ + if self.callback: + self.callback.update_locals(locals_) + + +class StopTrainingOnRewardThreshold(BaseCallback): + """ + Stop the training once a threshold in episodic reward + has been reached (i.e. when the model is good enough). + + It must be used with the ``EvalCallback``. + + :param reward_threshold: Minimum expected reward per episode + to stop training. 
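# Editorial sketch: reading back the log written by ``np.savez`` above (the
# path assumes ``log_path="./logs"`` was passed to ``EvalCallback``).
import numpy as np

data = np.load("./logs/evaluations.npz")
# results has shape (n_evaluations, n_eval_episodes)
print(data["timesteps"], data["results"].mean(axis=1))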
+    :param verbose:
+    """
+
+    def __init__(self, reward_threshold: float, verbose: int = 0):
+        super().__init__(verbose=verbose)
+        self.reward_threshold = reward_threshold
+
+    def _on_step(self) -> bool:
+        assert self.parent is not None, "``StopTrainingOnRewardThreshold`` callback must be used " "with an ``EvalCallback``"
+        # Convert np.bool_ to bool, otherwise the ``is False`` check won't work
+        continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
+        if self.verbose > 0 and not continue_training:
+            print(
+                f"Stopping training because the mean reward {self.parent.best_mean_reward:.2f} "
+                f"is above the threshold {self.reward_threshold}"
+            )
+        return continue_training
+
+
+class EveryNTimesteps(EventCallback):
+    """
+    Trigger a callback every ``n_steps`` timesteps.
+
+    :param n_steps: Number of timesteps between two triggers.
+    :param callback: Callback that will be called
+        when the event is triggered.
+    """
+
+    def __init__(self, n_steps: int, callback: BaseCallback):
+        super().__init__(callback)
+        self.n_steps = n_steps
+        self.last_time_trigger = 0
+
+    def _on_step(self) -> bool:
+        if (self.num_timesteps - self.last_time_trigger) >= self.n_steps:
+            self.last_time_trigger = self.num_timesteps
+            return self._on_event()
+        return True
+
+
+class StopTrainingOnMaxEpisodes(BaseCallback):
+    """
+    Stop the training once a maximum number of episodes are played.
+
+    For multiple environments, this assumes that the desired behavior is for the agent
+    to train on each env for ``max_episodes``, i.e. for ``max_episodes * n_envs`` episodes in total.
+
+    :param max_episodes: Maximum number of episodes to stop training.
+    :param verbose: Select whether to print information about when training ended by reaching ``max_episodes``
+    """
+
+    def __init__(self, max_episodes: int, verbose: int = 0):
+        super().__init__(verbose=verbose)
+        self.max_episodes = max_episodes
+        self._total_max_episodes = max_episodes
+        self.n_episodes = 0
+
+    def _init_callback(self) -> None:
+        # At start, set the total max according to the number of environments
+        self._total_max_episodes = self.max_episodes * self.training_env.num_envs
+
+    def _on_step(self) -> bool:
+        # Check that the `dones` local variable is defined
+        assert "dones" in self.locals, "`dones` variable is not defined, please check your code next to `callback.on_step()`"
+        self.n_episodes += np.sum(self.locals["dones"]).item()
+
+        continue_training = self.n_episodes < self._total_max_episodes
+
+        if self.verbose > 0 and not continue_training:
+            mean_episodes_per_env = self.n_episodes / self.training_env.num_envs
+            mean_ep_str = (
+                f"with an average of {mean_episodes_per_env:.2f} episodes per env" if self.training_env.num_envs > 1 else ""
+            )
+
+            print(
+                f"Stopping training with a total of {self.num_timesteps} steps because the "
+                f"{self.locals.get('tb_log_name')} model reached max_episodes={self.max_episodes}, "
+                f"by playing for {self.n_episodes} episodes "
+                f"{mean_ep_str}"
+            )
+        return continue_training
+
+
+class StopTrainingOnNoModelImprovement(BaseCallback):
+    """
+    Stop the training early if there is no new best model (new best mean reward) after more than N consecutive evaluations.
+
+    It is possible to define a minimum number of evaluations before starting to count evaluations without improvement.
+
+    It must be used with the ``EvalCallback``.
+
+    :param max_no_improvement_evals: Maximum number of consecutive evaluations without a new best model.
+    :param min_evals: Number of evaluations before starting to count evaluations without improvement.
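# Editorial usage sketch: the stopping callbacks above attach to an
# ``EvalCallback`` via ``callback_on_new_best`` (``eval_env`` is an assumption).
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

stop_cb = StopTrainingOnRewardThreshold(reward_threshold=200.0, verbose=1)
eval_cb = EvalCallback(eval_env, callback_on_new_best=stop_cb, eval_freq=10_000)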
+ :param verbose: Verbosity of the output (set to 1 for info messages) + """ + + def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0): + super().__init__(verbose=verbose) + self.max_no_improvement_evals = max_no_improvement_evals + self.min_evals = min_evals + self.last_best_mean_reward = -np.inf + self.no_improvement_evals = 0 + + def _on_step(self) -> bool: + assert self.parent is not None, "``StopTrainingOnNoModelImprovement`` callback must be used with an ``EvalCallback``" + + continue_training = True + + if self.n_calls > self.min_evals: + if self.parent.best_mean_reward > self.last_best_mean_reward: + self.no_improvement_evals = 0 + else: + self.no_improvement_evals += 1 + if self.no_improvement_evals > self.max_no_improvement_evals: + continue_training = False + + self.last_best_mean_reward = self.parent.best_mean_reward + + if self.verbose > 0 and not continue_training: + print( + f"Stopping training because there was no new best model in the last {self.no_improvement_evals:d} evaluations" + ) + + return continue_training diff --git a/dexart-release/stable_baselines3/common/distributions.py b/dexart-release/stable_baselines3/common/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..3d1ff5aa0bb1730b24f460f4c479d30058365448 --- /dev/null +++ b/dexart-release/stable_baselines3/common/distributions.py @@ -0,0 +1,699 @@ +"""Probability distributions.""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Tuple, Union + +import gym +import torch as th +from gym import spaces +from torch import nn +from torch.distributions import Bernoulli, Categorical, Normal + +from stable_baselines3.common.preprocessing import get_action_dim + + +class Distribution(ABC): + """Abstract base class for distributions.""" + + def __init__(self): + super().__init__() + self.distribution = None + + @abstractmethod + def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]: + """Create the layers and parameters that represent the distribution. + + Subclasses must define this, but the arguments and return type vary between + concrete classes.""" + + @abstractmethod + def proba_distribution(self, *args, **kwargs) -> "Distribution": + """Set parameters of the distribution. + + :return: self + """ + + @abstractmethod + def log_prob(self, x: th.Tensor) -> th.Tensor: + """ + Returns the log likelihood + + :param x: the taken action + :return: The log likelihood of the distribution + """ + + @abstractmethod + def entropy(self) -> Optional[th.Tensor]: + """ + Returns Shannon's entropy of the probability + + :return: the entropy, or None if no analytical form is known + """ + + @abstractmethod + def sample(self) -> th.Tensor: + """ + Returns a sample from the probability distribution + + :return: the stochastic action + """ + + @abstractmethod + def mode(self) -> th.Tensor: + """ + Returns the most likely action (deterministic output) + from the probability distribution + + :return: the stochastic action + """ + + def get_actions(self, deterministic: bool = False) -> th.Tensor: + """ + Return actions according to the probability distribution. + + :param deterministic: + :return: + """ + if deterministic: + return self.mode() + return self.sample() + + @abstractmethod + def actions_from_params(self, *args, **kwargs) -> th.Tensor: + """ + Returns samples from the probability distribution + given its parameters. 
+ + :return: actions + """ + + @abstractmethod + def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]: + """ + Returns samples and the associated log probabilities + from the probability distribution given its parameters. + + :return: actions and log prob + """ + + +def sum_independent_dims(tensor: th.Tensor) -> th.Tensor: + """ + Continuous actions are usually considered to be independent, + so we can sum components of the ``log_prob`` or the entropy. + + :param tensor: shape: (n_batch, n_actions) or (n_batch,) + :return: shape: (n_batch,) + """ + if len(tensor.shape) > 1: + tensor = tensor.sum(dim=1) + else: + tensor = tensor.sum() + return tensor + + +class DiagGaussianDistribution(Distribution): + """ + Gaussian distribution with diagonal covariance matrix, for continuous actions. + + :param action_dim: Dimension of the action space. + """ + + def __init__(self, action_dim: int): + super().__init__() + self.action_dim = action_dim + self.mean_actions = None + self.log_std = None + + def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]: + """ + Create the layers and parameter that represent the distribution: + one output will be the mean of the Gaussian, the other parameter will be the + standard deviation (log std in fact to allow negative values) + + :param latent_dim: Dimension of the last layer of the policy (before the action layer) + :param log_std_init: Initial value for the log standard deviation + :return: + """ + mean_actions = nn.Linear(latent_dim, self.action_dim) + # TODO: allow action dependent std + log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True) + return mean_actions, log_std + + def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution": + """ + Create the distribution given its parameters (mean, std) + + :param mean_actions: + :param log_std: + :return: + """ + action_std = th.ones_like(mean_actions) * log_std.exp() + self.distribution = Normal(mean_actions, action_std) + return self + + def log_prob(self, actions: th.Tensor) -> th.Tensor: + """ + Get the log probabilities of actions according to the distribution. + Note that you must first call the ``proba_distribution()`` method. + + :param actions: + :return: + """ + log_prob = self.distribution.log_prob(actions) + return sum_independent_dims(log_prob) + + def entropy(self) -> th.Tensor: + return sum_independent_dims(self.distribution.entropy()) + + def sample(self) -> th.Tensor: + # Reparametrization trick to pass gradients + return self.distribution.rsample() + + def mode(self) -> th.Tensor: + return self.distribution.mean + + def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor: + # Update the proba distribution + self.proba_distribution(mean_actions, log_std) + return self.get_actions(deterministic=deterministic) + + def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + """ + Compute the log probability of taking an action + given the distribution parameters. + + :param mean_actions: + :param log_std: + :return: + """ + actions = self.actions_from_params(mean_actions, log_std) + log_prob = self.log_prob(actions) + return actions, log_prob + + +class SquashedDiagGaussianDistribution(DiagGaussianDistribution): + """ + Gaussian distribution with diagonal covariance matrix, followed by a squashing function (tanh) to ensure bounds. 
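# Editorial usage sketch of ``DiagGaussianDistribution`` above (the import
# path of this vendored module is an assumption):
import torch as th
from stable_baselines3.common.distributions import DiagGaussianDistribution

dist = DiagGaussianDistribution(action_dim=3)
mean_net, log_std = dist.proba_distribution_net(latent_dim=8)
actions = dist.actions_from_params(mean_net(th.zeros(5, 8)), log_std)
log_prob = dist.log_prob(actions)  # shape (5,): summed over the 3 action dims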
+ + :param action_dim: Dimension of the action space. + :param epsilon: small value to avoid NaN due to numerical imprecision. + """ + + def __init__(self, action_dim: int, epsilon: float = 1e-6): + super().__init__(action_dim) + # Avoid NaN (prevents division by zero or log of zero) + self.epsilon = epsilon + self.gaussian_actions = None + + def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution": + super().proba_distribution(mean_actions, log_std) + return self + + def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor: + # Inverse tanh + # Naive implementation (not stable): 0.5 * torch.log((1 + x) / (1 - x)) + # We use numpy to avoid numerical instability + if gaussian_actions is None: + # It will be clipped to avoid NaN when inversing tanh + gaussian_actions = TanhBijector.inverse(actions) + + # Log likelihood for a Gaussian distribution + log_prob = super().log_prob(gaussian_actions) + # Squash correction (from original SAC implementation) + # this comes from the fact that tanh is bijective and differentiable + log_prob -= th.sum(th.log(1 - actions**2 + self.epsilon), dim=1) + return log_prob + + def entropy(self) -> Optional[th.Tensor]: + # No analytical form, + # entropy needs to be estimated using -log_prob.mean() + return None + + def sample(self) -> th.Tensor: + # Reparametrization trick to pass gradients + self.gaussian_actions = super().sample() + return th.tanh(self.gaussian_actions) + + def mode(self) -> th.Tensor: + self.gaussian_actions = super().mode() + # Squash the output + return th.tanh(self.gaussian_actions) + + def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + action = self.actions_from_params(mean_actions, log_std) + log_prob = self.log_prob(action, self.gaussian_actions) + return action, log_prob + + +class CategoricalDistribution(Distribution): + """ + Categorical distribution for discrete actions. + + :param action_dim: Number of discrete actions + """ + + def __init__(self, action_dim: int): + super().__init__() + self.action_dim = action_dim + + def proba_distribution_net(self, latent_dim: int) -> nn.Module: + """ + Create the layer that represents the distribution: + it will be the logits of the Categorical distribution. + You can then get probabilities using a softmax. 
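# Editorial sketch (plain torch): the squash correction above is the
# change-of-variables term for a = tanh(u), i.e. the squashed log density is
# the Gaussian log density minus sum(log(1 - a**2 + eps)).
import torch as th
from torch.distributions import Normal

mean, std, eps = th.zeros(1, 2), th.ones(1, 2), 1e-6
gaussian = Normal(mean, std)
u = gaussian.rsample()   # pre-squash Gaussian sample
a = th.tanh(u)           # squashed action in (-1, 1)
log_prob = gaussian.log_prob(u).sum(dim=1) - th.sum(th.log(1 - a**2 + eps), dim=1)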
+ + :param latent_dim: Dimension of the last layer + of the policy network (before the action layer) + :return: + """ + action_logits = nn.Linear(latent_dim, self.action_dim) + return action_logits + + def proba_distribution(self, action_logits: th.Tensor) -> "CategoricalDistribution": + self.distribution = Categorical(logits=action_logits) + return self + + def log_prob(self, actions: th.Tensor) -> th.Tensor: + return self.distribution.log_prob(actions) + + def entropy(self) -> th.Tensor: + return self.distribution.entropy() + + def sample(self) -> th.Tensor: + return self.distribution.sample() + + def mode(self) -> th.Tensor: + return th.argmax(self.distribution.probs, dim=1) + + def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: + # Update the proba distribution + self.proba_distribution(action_logits) + return self.get_actions(deterministic=deterministic) + + def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + actions = self.actions_from_params(action_logits) + log_prob = self.log_prob(actions) + return actions, log_prob + + +class MultiCategoricalDistribution(Distribution): + """ + MultiCategorical distribution for multi discrete actions. + + :param action_dims: List of sizes of discrete action spaces + """ + + def __init__(self, action_dims: List[int]): + super().__init__() + self.action_dims = action_dims + + def proba_distribution_net(self, latent_dim: int) -> nn.Module: + """ + Create the layer that represents the distribution: + it will be the logits (flattened) of the MultiCategorical distribution. + You can then get probabilities using a softmax on each sub-space. + + :param latent_dim: Dimension of the last layer + of the policy network (before the action layer) + :return: + """ + + action_logits = nn.Linear(latent_dim, sum(self.action_dims)) + return action_logits + + def proba_distribution(self, action_logits: th.Tensor) -> "MultiCategoricalDistribution": + self.distribution = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)] + return self + + def log_prob(self, actions: th.Tensor) -> th.Tensor: + # Extract each discrete action and compute log prob for their respective distributions + return th.stack( + [dist.log_prob(action) for dist, action in zip(self.distribution, th.unbind(actions, dim=1))], dim=1 + ).sum(dim=1) + + def entropy(self) -> th.Tensor: + return th.stack([dist.entropy() for dist in self.distribution], dim=1).sum(dim=1) + + def sample(self) -> th.Tensor: + return th.stack([dist.sample() for dist in self.distribution], dim=1) + + def mode(self) -> th.Tensor: + return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distribution], dim=1) + + def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: + # Update the proba distribution + self.proba_distribution(action_logits) + return self.get_actions(deterministic=deterministic) + + def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + actions = self.actions_from_params(action_logits) + log_prob = self.log_prob(actions) + return actions, log_prob + + +class BernoulliDistribution(Distribution): + """ + Bernoulli distribution for MultiBinary action spaces. 
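# Editorial sketch (plain torch): how the flattened logits above map to one
# ``Categorical`` per sub-space, e.g. for MultiDiscrete([3, 2]):
import torch as th
from torch.distributions import Categorical

action_dims = [3, 2]
logits = th.randn(4, sum(action_dims))  # (batch, 5)
dists = [Categorical(logits=s) for s in th.split(logits, action_dims, dim=1)]
actions = th.stack([d.sample() for d in dists], dim=1)  # (batch, 2)
log_prob = th.stack(
    [d.log_prob(a) for d, a in zip(dists, th.unbind(actions, dim=1))], dim=1
).sum(dim=1)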
+ + :param action_dim: Number of binary actions + """ + + def __init__(self, action_dims: int): + super().__init__() + self.action_dims = action_dims + + def proba_distribution_net(self, latent_dim: int) -> nn.Module: + """ + Create the layer that represents the distribution: + it will be the logits of the Bernoulli distribution. + + :param latent_dim: Dimension of the last layer + of the policy network (before the action layer) + :return: + """ + action_logits = nn.Linear(latent_dim, self.action_dims) + return action_logits + + def proba_distribution(self, action_logits: th.Tensor) -> "BernoulliDistribution": + self.distribution = Bernoulli(logits=action_logits) + return self + + def log_prob(self, actions: th.Tensor) -> th.Tensor: + return self.distribution.log_prob(actions).sum(dim=1) + + def entropy(self) -> th.Tensor: + return self.distribution.entropy().sum(dim=1) + + def sample(self) -> th.Tensor: + return self.distribution.sample() + + def mode(self) -> th.Tensor: + return th.round(self.distribution.probs) + + def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: + # Update the proba distribution + self.proba_distribution(action_logits) + return self.get_actions(deterministic=deterministic) + + def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + actions = self.actions_from_params(action_logits) + log_prob = self.log_prob(actions) + return actions, log_prob + + +class StateDependentNoiseDistribution(Distribution): + """ + Distribution class for using generalized State Dependent Exploration (gSDE). + Paper: https://arxiv.org/abs/2005.05719 + + It is used to create the noise exploration matrix and + compute the log probability of an action with that noise. + + :param action_dim: Dimension of the action space. + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this ensures bounds are satisfied. + :param learn_features: Whether to learn features for gSDE or not. + This will enable gradients to be backpropagated through the features + ``latent_sde`` in the code. + :param epsilon: small value to avoid NaN due to numerical imprecision. + """ + + def __init__( + self, + action_dim: int, + full_std: bool = True, + use_expln: bool = False, + squash_output: bool = False, + learn_features: bool = False, + epsilon: float = 1e-6, + ): + super().__init__() + self.action_dim = action_dim + self.latent_sde_dim = None + self.mean_actions = None + self.log_std = None + self.weights_dist = None + self.exploration_mat = None + self.exploration_matrices = None + self._latent_sde = None + self.use_expln = use_expln + self.full_std = full_std + self.epsilon = epsilon + self.learn_features = learn_features + if squash_output: + self.bijector = TanhBijector(epsilon) + else: + self.bijector = None + + def get_std(self, log_std: th.Tensor) -> th.Tensor: + """ + Get the standard deviation from the learned parameter + (log of it by default). This ensures that the std is positive. 
+
+        :param log_std:
+        :return:
+        """
+        if self.use_expln:
+            # From the gSDE paper: this keeps the variance
+            # above zero and prevents it from growing too fast
+            below_threshold = th.exp(log_std) * (log_std <= 0)
+            # Avoid NaN: zero out the entries where log_std <= 0
+            safe_log_std = log_std * (log_std > 0) + self.epsilon
+            above_threshold = (th.log1p(safe_log_std) + 1.0) * (log_std > 0)
+            std = below_threshold + above_threshold
+        else:
+            # Use normal exponential
+            std = th.exp(log_std)
+
+        if self.full_std:
+            return std
+        # Reduce the number of parameters:
+        return th.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std
+
+    def sample_weights(self, log_std: th.Tensor, batch_size: int = 1) -> None:
+        """
+        Sample weights for the noise exploration matrix,
+        using a centered Gaussian distribution.
+
+        :param log_std:
+        :param batch_size:
+        """
+        std = self.get_std(log_std)
+        self.weights_dist = Normal(th.zeros_like(std), std)
+        # Reparametrization trick to pass gradients
+        self.exploration_mat = self.weights_dist.rsample()
+        # Pre-compute matrices in case of parallel exploration
+        self.exploration_matrices = self.weights_dist.rsample((batch_size,))
+
+    def proba_distribution_net(
+        self, latent_dim: int, log_std_init: float = -2.0, latent_sde_dim: Optional[int] = None
+    ) -> Tuple[nn.Module, nn.Parameter]:
+        """
+        Create the layers and parameter that represent the distribution:
+        one output will be the deterministic action, the other parameter will be the
+        standard deviation of the distribution that controls the weights of the noise matrix.
+
+        :param latent_dim: Dimension of the last layer of the policy (before the action layer)
+        :param log_std_init: Initial value for the log standard deviation
+        :param latent_sde_dim: Dimension of the last layer of the features extractor
+            for gSDE. By default, it is shared with the policy network.
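+            The returned ``log_std`` parameter then has shape
+            ``(latent_sde_dim, action_dim)`` when ``full_std=True``
+            and ``(latent_sde_dim, 1)`` otherwise (see the code below).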
+ :return: + """ + # Network for the deterministic action, it represents the mean of the distribution + mean_actions_net = nn.Linear(latent_dim, self.action_dim) + # When we learn features for the noise, the feature dimension + # can be different between the policy and the noise network + self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim + # Reduce the number of parameters if needed + log_std = th.ones(self.latent_sde_dim, self.action_dim) if self.full_std else th.ones(self.latent_sde_dim, 1) + # Transform it to a parameter so it can be optimized + log_std = nn.Parameter(log_std * log_std_init, requires_grad=True) + # Sample an exploration matrix + self.sample_weights(log_std) + return mean_actions_net, log_std + + def proba_distribution( + self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor + ) -> "StateDependentNoiseDistribution": + """ + Create the distribution given its parameters (mean, std) + + :param mean_actions: + :param log_std: + :param latent_sde: + :return: + """ + # Stop gradient if we don't want to influence the features + self._latent_sde = latent_sde if self.learn_features else latent_sde.detach() + variance = th.mm(self._latent_sde**2, self.get_std(log_std) ** 2) + self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon)) + return self + + def log_prob(self, actions: th.Tensor) -> th.Tensor: + if self.bijector is not None: + gaussian_actions = self.bijector.inverse(actions) + else: + gaussian_actions = actions + # log likelihood for a gaussian + log_prob = self.distribution.log_prob(gaussian_actions) + # Sum along action dim + log_prob = sum_independent_dims(log_prob) + + if self.bijector is not None: + # Squash correction (from original SAC implementation) + log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1) + return log_prob + + def entropy(self) -> Optional[th.Tensor]: + if self.bijector is not None: + # No analytical form, + # entropy needs to be estimated using -log_prob.mean() + return None + return sum_independent_dims(self.distribution.entropy()) + + def sample(self) -> th.Tensor: + noise = self.get_noise(self._latent_sde) + actions = self.distribution.mean + noise + if self.bijector is not None: + return self.bijector.forward(actions) + return actions + + def mode(self) -> th.Tensor: + actions = self.distribution.mean + if self.bijector is not None: + return self.bijector.forward(actions) + return actions + + def get_noise(self, latent_sde: th.Tensor) -> th.Tensor: + latent_sde = latent_sde if self.learn_features else latent_sde.detach() + # Default case: only one exploration matrix + if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices): + return th.mm(latent_sde, self.exploration_mat) + # Use batch matrix multiplication for efficient computation + # (batch_size, n_features) -> (batch_size, 1, n_features) + latent_sde = latent_sde.unsqueeze(1) + # (batch_size, 1, n_actions) + noise = th.bmm(latent_sde, self.exploration_matrices) + return noise.squeeze(1) + + def actions_from_params( + self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False + ) -> th.Tensor: + # Update the proba distribution + self.proba_distribution(mean_actions, log_std, latent_sde) + return self.get_actions(deterministic=deterministic) + + def log_prob_from_params( + self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor + ) -> Tuple[th.Tensor, th.Tensor]: + actions = self.actions_from_params(mean_actions, log_std, 
latent_sde)
+        log_prob = self.log_prob(actions)
+        return actions, log_prob
+
+
+class TanhBijector:
+    """
+    Bijective transformation of a probability distribution
+    using a squashing function (tanh)
+    TODO: use Pyro instead (https://pyro.ai/)
+
+    :param epsilon: small value to avoid NaN due to numerical imprecision.
+    """
+
+    def __init__(self, epsilon: float = 1e-6):
+        super().__init__()
+        self.epsilon = epsilon
+
+    @staticmethod
+    def forward(x: th.Tensor) -> th.Tensor:
+        return th.tanh(x)
+
+    @staticmethod
+    def atanh(x: th.Tensor) -> th.Tensor:
+        """
+        Inverse of Tanh
+
+        Taken from Pyro: https://github.com/pyro-ppl/pyro
+        0.5 * torch.log((1 + x) / (1 - x))
+        """
+        return 0.5 * (x.log1p() - (-x).log1p())
+
+    @staticmethod
+    def inverse(y: th.Tensor) -> th.Tensor:
+        """
+        Inverse tanh.
+
+        :param y:
+        :return:
+        """
+        eps = th.finfo(y.dtype).eps
+        # Clip the action to avoid NaN
+        return TanhBijector.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))
+
+    def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
+        # Squash correction (from original SAC implementation)
+        return th.log(1.0 - th.tanh(x) ** 2 + self.epsilon)
+
+
+def make_proba_distribution(
+    action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
+) -> Distribution:
+    """
+    Return an instance of Distribution for the correct type of action space
+
+    :param action_space: the input action space
+    :param use_sde: Force the use of StateDependentNoiseDistribution
+        instead of DiagGaussianDistribution
+    :param dist_kwargs: Keyword arguments to pass to the probability distribution
+    :return: the appropriate Distribution object
+    """
+    if dist_kwargs is None:
+        dist_kwargs = {}
+
+    if isinstance(action_space, spaces.Box):
+        assert len(action_space.shape) == 1, "Error: the action space must be a vector"
+        cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
+        return cls(get_action_dim(action_space), **dist_kwargs)
+    elif isinstance(action_space, spaces.Discrete):
+        return CategoricalDistribution(action_space.n, **dist_kwargs)
+    elif isinstance(action_space, spaces.MultiDiscrete):
+        return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
+    elif isinstance(action_space, spaces.MultiBinary):
+        return BernoulliDistribution(action_space.n, **dist_kwargs)
+    else:
+        raise NotImplementedError(
+            "Error: probability distribution, not implemented for action space "
+            f"of type {type(action_space)}."
+            " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
+        )
+
+
+def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor:
+    """
+    Wrapper for the PyTorch implementation of the full form KL Divergence
+
+    :param dist_true: the p distribution
+    :param dist_pred: the q distribution
+    :return: KL(dist_true||dist_pred)
+    """
+    # KL Divergence for different distribution types is out of scope
+    assert dist_true.__class__ == dist_pred.__class__, "Error: input distributions should be the same type"
+
+    # MultiCategoricalDistribution is not a PyTorch Distribution subclass
+    # so we need to implement it ourselves!
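+    # For factored distributions P = (P_1, ..., P_k) and Q = (Q_1, ..., Q_k)
+    # with independent components, KL(P||Q) = sum_i KL(P_i||Q_i), which is
+    # exactly the stacked per-component sum computed below.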
+ if isinstance(dist_pred, MultiCategoricalDistribution): + assert dist_pred.action_dims == dist_true.action_dims, "Error: distributions must have the same input space" + return th.stack( + [th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)], + dim=1, + ).sum(dim=1) + + # Use the PyTorch kl_divergence implementation + else: + return th.distributions.kl_divergence(dist_true.distribution, dist_pred.distribution) diff --git a/dexart-release/stable_baselines3/common/env_util.py b/dexart-release/stable_baselines3/common/env_util.py new file mode 100644 index 0000000000000000000000000000000000000000..297cce30404fc8811ca3371a01a7ad6824c5e274 --- /dev/null +++ b/dexart-release/stable_baselines3/common/env_util.py @@ -0,0 +1,104 @@ +import os +from typing import Any, Callable, Dict, Optional, Type, Union + +import gym + +from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv + + +def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]: + """ + Retrieve a ``VecEnvWrapper`` object by recursively searching. + + :param env: Environment to unwrap + :param wrapper_class: Wrapper to look for + :return: Environment unwrapped till ``wrapper_class`` if it has been wrapped with it + """ + env_tmp = env + while isinstance(env_tmp, gym.Wrapper): + if isinstance(env_tmp, wrapper_class): + return env_tmp + env_tmp = env_tmp.env + return None + + +def is_wrapped(env: Type[gym.Env], wrapper_class: Type[gym.Wrapper]) -> bool: + """ + Check if a given environment has been wrapped with a given wrapper. + + :param env: Environment to check + :param wrapper_class: Wrapper class to look for + :return: True if environment has been wrapped with ``wrapper_class``. + """ + return unwrap_wrapper(env, wrapper_class) is not None + + +def make_vec_env( + env_id: Union[str, Type[gym.Env]], + n_envs: int = 1, + seed: Optional[int] = None, + start_index: int = 0, + monitor_dir: Optional[str] = None, + wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None, + env_kwargs: Optional[Dict[str, Any]] = None, + vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None, + vec_env_kwargs: Optional[Dict[str, Any]] = None, + monitor_kwargs: Optional[Dict[str, Any]] = None, + wrapper_kwargs: Optional[Dict[str, Any]] = None, +) -> VecEnv: + """ + Create a wrapped, monitored ``VecEnv``. + By default it uses a ``DummyVecEnv`` which is usually faster + than a ``SubprocVecEnv``. + + :param env_id: the environment ID or the environment class + :param n_envs: the number of environments you wish to have in parallel + :param seed: the initial seed for the random number generator + :param start_index: start rank index + :param monitor_dir: Path to a folder where the monitor files will be saved. + If None, no file will be written, however, the env will still be wrapped + in a Monitor wrapper to provide additional information about training. + :param wrapper_class: Additional wrapper to use on the environment. + This can also be a function with single argument that wraps the environment in many things. + :param env_kwargs: Optional keyword argument to pass to the env constructor + :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. + :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. + :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor. 
+ :param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor. + :return: The wrapped environment + """ + env_kwargs = {} if env_kwargs is None else env_kwargs + vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs + monitor_kwargs = {} if monitor_kwargs is None else monitor_kwargs + wrapper_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs + + def make_env(rank): + def _init(): + if isinstance(env_id, str): + env = gym.make(env_id, **env_kwargs) + else: + env = env_id(**env_kwargs) + if seed is not None: + env.seed(seed + rank) + env.action_space.seed(seed + rank) + # Wrap the env in a Monitor wrapper + # to have additional training information + monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None + # Create the monitor folder if needed + if monitor_path is not None: + os.makedirs(monitor_dir, exist_ok=True) + env = Monitor(env, filename=monitor_path, **monitor_kwargs) + # Optionally, wrap the environment with the provided wrapper + if wrapper_class is not None: + env = wrapper_class(env, **wrapper_kwargs) + return env + + return _init + + # No custom VecEnv is passed + if vec_env_cls is None: + # Default: use a DummyVecEnv + vec_env_cls = DummyVecEnv + + return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs) \ No newline at end of file diff --git a/dexart-release/stable_baselines3/common/evaluation.py b/dexart-release/stable_baselines3/common/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f14d3f8083c0e4b98efb784a629c0b75141470 --- /dev/null +++ b/dexart-release/stable_baselines3/common/evaluation.py @@ -0,0 +1,131 @@ +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import gym +import numpy as np + +from stable_baselines3.common import base_class +from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped + + +def evaluate_policy( + model: "base_class.BaseAlgorithm", + env: Union[gym.Env, VecEnv], + n_eval_episodes: int = 10, + deterministic: bool = True, + render: bool = False, + callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None, + reward_threshold: Optional[float] = None, + return_episode_rewards: bool = False, + warn: bool = True, +) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]: + """ + Runs policy for ``n_eval_episodes`` episodes and returns average reward. + If a vector env is passed in, this divides the episodes to evaluate onto the + different elements of the vector env. This static division of work is done to + remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more + details and discussion. + + .. note:: + If environment has not been wrapped with ``Monitor`` wrapper, reward and + episode lengths are counted as it appears with ``env.step`` calls. If + the environment contains wrappers that modify rewards or episode lengths + (e.g. reward scaling, early episode reset), these will affect the evaluation + results as well. You can avoid this by wrapping environment with ``Monitor`` + wrapper before anything else. + + :param model: The RL agent you want to evaluate. + :param env: The gym environment or ``VecEnv`` environment. 
+ :param n_eval_episodes: Number of episode to evaluate the agent + :param deterministic: Whether to use deterministic or stochastic actions + :param render: Whether to render the environment or not + :param callback: callback function to do additional checks, + called after each step. Gets locals() and globals() passed as parameters. + :param reward_threshold: Minimum expected reward per episode, + this will raise an error if the performance is not met + :param return_episode_rewards: If True, a list of rewards and episode lengths + per episode will be returned instead of the mean. + :param warn: If True (default), warns user about lack of a Monitor wrapper in the + evaluation environment. + :return: Mean reward per episode, std of reward per episode. + Returns ([float], [int]) when ``return_episode_rewards`` is True, first + list containing per-episode rewards and second containing per-episode lengths + (in number of steps). + """ + is_monitor_wrapped = False + # Avoid circular import + from stable_baselines3.common.monitor import Monitor + + if not isinstance(env, VecEnv): + env = DummyVecEnv([lambda: env]) + + is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0] + + if not is_monitor_wrapped and warn: + warnings.warn( + "Evaluation environment is not wrapped with a ``Monitor`` wrapper. " + "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. " + "Consider wrapping environment first with ``Monitor`` wrapper.", + UserWarning, + ) + + n_envs = env.num_envs + episode_rewards = [] + episode_lengths = [] + + episode_counts = np.zeros(n_envs, dtype="int") + # Divides episodes among different sub environments in the vector as evenly as possible + episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int") + + current_rewards = np.zeros(n_envs) + current_lengths = np.zeros(n_envs, dtype="int") + observations = env.reset() + states = None + episode_starts = np.ones((env.num_envs,), dtype=bool) + while (episode_counts < episode_count_targets).any(): + actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic) + observations, rewards, dones, infos = env.step(actions) + current_rewards += rewards + current_lengths += 1 + for i in range(n_envs): + if episode_counts[i] < episode_count_targets[i]: + + # unpack values so that the callback can access the local variables + reward = rewards[i] + done = dones[i] + info = infos[i] + episode_starts[i] = done + + if callback is not None: + callback(locals(), globals()) + + if dones[i]: + if is_monitor_wrapped: + # Atari wrapper can send a "done" signal when + # the agent loses a life, but it does not correspond + # to the true end of episode + if "episode" in info.keys(): + # Do not trust "done" with episode endings. + # Monitor wrapper includes "episode" key in info if environment + # has been wrapped with it. Use those rewards instead. 
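+                                # info["episode"] is only present on the step where
+                                # the Monitor-tracked episode truly ends, making these
+                                # statistics robust to wrappers that emit extra "done" signals.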
+ episode_rewards.append(info["episode"]["r"]) + episode_lengths.append(info["episode"]["l"]) + # Only increment at the real end of an episode + episode_counts[i] += 1 + else: + episode_rewards.append(current_rewards[i]) + episode_lengths.append(current_lengths[i]) + episode_counts[i] += 1 + current_rewards[i] = 0 + current_lengths[i] = 0 + + if render: + env.render() + + mean_reward = np.mean(episode_rewards) + std_reward = np.std(episode_rewards) + if reward_threshold is not None: + assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}" + if return_episode_rewards: + return episode_rewards, episode_lengths + return mean_reward, std_reward diff --git a/dexart-release/stable_baselines3/common/logger.py b/dexart-release/stable_baselines3/common/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..1be6945c5bbfc1ef5905c5cdec1851182faea7bb --- /dev/null +++ b/dexart-release/stable_baselines3/common/logger.py @@ -0,0 +1,644 @@ +import datetime +import json +import os +import sys +import tempfile +import warnings +from collections import defaultdict +from typing import Any, Dict, List, Optional, Sequence, TextIO, Tuple, Union + +import numpy as np +import pandas +import torch as th +from matplotlib import pyplot as plt + +try: + from torch.utils.tensorboard import SummaryWriter +except ImportError: + SummaryWriter = None + +DEBUG = 10 +INFO = 20 +WARN = 30 +ERROR = 40 +DISABLED = 50 + + +class Video: + """ + Video data class storing the video frames and the frame per seconds + + :param frames: frames to create the video from + :param fps: frames per second + """ + + def __init__(self, frames: th.Tensor, fps: Union[float, int]): + self.frames = frames + self.fps = fps + + +class Figure: + """ + Figure data class storing a matplotlib figure and whether to close the figure after logging it + + :param figure: figure to log + :param close: if true, close the figure after logging it + """ + + def __init__(self, figure: plt.figure, close: bool): + self.figure = figure + self.close = close + + +class Image: + """ + Image data class storing an image and data format + + :param image: image to log + :param dataformats: Image data format specification of the form NCHW, NHWC, CHW, HWC, HW, WH, etc. + More info in add_image method doc at https://pytorch.org/docs/stable/tensorboard.html + Gym envs normally use 'HWC' (channel last) + """ + + def __init__(self, image: Union[th.Tensor, np.ndarray, str], dataformats: str): + self.image = image + self.dataformats = dataformats + + +class FormatUnsupportedError(NotImplementedError): + """ + Custom error to display informative message when + a value is not supported by some formats. + + :param unsupported_formats: A sequence of unsupported formats, + for instance ``["stdout"]``. + :param value_description: Description of the value that cannot be logged by this format. + """ + + def __init__(self, unsupported_formats: Sequence[str], value_description: str): + if len(unsupported_formats) > 1: + format_str = f"formats {', '.join(unsupported_formats)} are" + else: + format_str = f"format {unsupported_formats[0]} is" + super().__init__( + f"The {format_str} not supported for the {value_description} value logged.\n" + f"You can exclude formats via the `exclude` parameter of the logger's `record` function." 
+ ) + + +class KVWriter: + """ + Key Value writer + """ + + def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], + step: int = 0) -> None: + """ + Write a dictionary to file + + :param key_values: + :param key_excluded: + :param step: + """ + raise NotImplementedError + + def close(self) -> None: + """ + Close owned resources + """ + raise NotImplementedError + + +class SeqWriter: + """ + sequence writer + """ + + def write_sequence(self, sequence: List) -> None: + """ + write_sequence an array to file + + :param sequence: + """ + raise NotImplementedError + + +class HumanOutputFormat(KVWriter, SeqWriter): + """A human-readable output format producing ASCII tables of key-value pairs. + + Set attribute ``max_length`` to change the maximum length of keys and values + to write to output (or specify it when calling ``__init__``). + + :param filename_or_file: the file to write the log to + :param max_length: the maximum length of keys and values to write to output. + Outputs longer than this will be truncated. An error will be raised + if multiple keys are truncated to the same value. The maximum output + width will be ``2*max_length + 7``. The default of 36 produces output + no longer than 79 characters wide. + """ + + def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36): + self.max_length = max_length + if isinstance(filename_or_file, str): + self.file = open(filename_or_file, "wt") + self.own_file = True + else: + assert hasattr(filename_or_file, "write"), f"Expected file or str, got {filename_or_file}" + self.file = filename_or_file + self.own_file = False + + def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None: + # Create strings for printing + key2str = {} + tag = None + for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())): + + if excluded is not None and ("stdout" in excluded or "log" in excluded): + continue + + elif isinstance(value, Video): + raise FormatUnsupportedError(["stdout", "log"], "video") + + elif isinstance(value, Figure): + raise FormatUnsupportedError(["stdout", "log"], "figure") + + elif isinstance(value, Image): + raise FormatUnsupportedError(["stdout", "log"], "image") + + elif isinstance(value, float): + # Align left + value_str = f"{value:<8.3g}" + else: + value_str = str(value) + + if key.find("/") > 0: # Find tag and add it to the dict + tag = key[: key.find("/") + 1] + key2str[self._truncate(tag)] = "" + # Remove tag from key + if tag is not None and tag in key: + key = str(" " + key[len(tag):]) + + truncated_key = self._truncate(key) + if truncated_key in key2str: + raise ValueError( + f"Key '{key}' truncated to '{truncated_key}' that already exists. Consider increasing `max_length`." 
+ ) + key2str[truncated_key] = self._truncate(value_str) + + # Find max widths + if len(key2str) == 0: + warnings.warn("Tried to write empty key-value dict") + return + else: + key_width = max(map(len, key2str.keys())) + val_width = max(map(len, key2str.values())) + + # Write out the data + dashes = "-" * (key_width + val_width + 7) + lines = [dashes] + for key, value in key2str.items(): + key_space = " " * (key_width - len(key)) + val_space = " " * (val_width - len(value)) + lines.append(f"| {key}{key_space} | {value}{val_space} |") + lines.append(dashes) + self.file.write("\n".join(lines) + "\n") + + # Flush the output to the file + self.file.flush() + + def _truncate(self, string: str) -> str: + if len(string) > self.max_length: + string = string[: self.max_length - 3] + "..." + return string + + def write_sequence(self, sequence: List) -> None: + sequence = list(sequence) + for i, elem in enumerate(sequence): + self.file.write(elem) + if i < len(sequence) - 1: # add space unless this is the last one + self.file.write(" ") + self.file.write("\n") + self.file.flush() + + def close(self) -> None: + """ + closes the file + """ + if self.own_file: + self.file.close() + + +def filter_excluded_keys( + key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], _format: str +) -> Dict[str, Any]: + """ + Filters the keys specified by ``key_exclude`` for the specified format + + :param key_values: log dictionary to be filtered + :param key_excluded: keys to be excluded per format + :param _format: format for which this filter is run + :return: dict without the excluded keys + """ + + def is_excluded(key: str) -> bool: + return key in key_excluded and key_excluded[key] is not None and _format in key_excluded[key] + + return {key: value for key, value in key_values.items() if not is_excluded(key)} + + +class JSONOutputFormat(KVWriter): + """ + Log to a file, in the JSON format + + :param filename: the file to write the log to + """ + + def __init__(self, filename: str): + self.file = open(filename, "wt") + + def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], + step: int = 0) -> None: + def cast_to_json_serializable(value: Any): + if isinstance(value, Video): + raise FormatUnsupportedError(["json"], "video") + if isinstance(value, Figure): + raise FormatUnsupportedError(["json"], "figure") + if isinstance(value, Image): + raise FormatUnsupportedError(["json"], "image") + if hasattr(value, "dtype"): + if value.shape == () or len(value) == 1: + # if value is a dimensionless numpy array or of length 1, serialize as a float + return float(value) + else: + # otherwise, a value is a numpy array, serialize as a list or nested lists + return value.tolist() + return value + + key_values = { + key: cast_to_json_serializable(value) + for key, value in filter_excluded_keys(key_values, key_excluded, "json").items() + } + self.file.write(json.dumps(key_values) + "\n") + self.file.flush() + + def close(self) -> None: + """ + closes the file + """ + + self.file.close() + + +class CSVOutputFormat(KVWriter): + """ + Log to a file, in a CSV format + + :param filename: the file to write the log to + """ + + def __init__(self, filename: str): + self.file = open(filename, "w+t") + self.keys = [] + self.separator = "," + self.quotechar = '"' + + def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], + step: int = 0) -> None: + # Add our current row to the history + key_values = filter_excluded_keys(key_values, 
key_excluded, "csv") + extra_keys = key_values.keys() - self.keys + if extra_keys: + self.keys.extend(extra_keys) + self.file.seek(0) + lines = self.file.readlines() + self.file.seek(0) + for (i, key) in enumerate(self.keys): + if i > 0: + self.file.write(",") + self.file.write(key) + self.file.write("\n") + for line in lines[1:]: + self.file.write(line[:-1]) + self.file.write(self.separator * len(extra_keys)) + self.file.write("\n") + for i, key in enumerate(self.keys): + if i > 0: + self.file.write(",") + value = key_values.get(key) + + if isinstance(value, Video): + raise FormatUnsupportedError(["csv"], "video") + + elif isinstance(value, Figure): + raise FormatUnsupportedError(["csv"], "figure") + + elif isinstance(value, Image): + raise FormatUnsupportedError(["csv"], "image") + + elif isinstance(value, str): + # escape quotechars by prepending them with another quotechar + value = value.replace(self.quotechar, self.quotechar + self.quotechar) + + # additionally wrap text with quotechars so that any delimiters in the text are ignored by csv readers + self.file.write(self.quotechar + value + self.quotechar) + + elif value is not None: + self.file.write(str(value)) + self.file.write("\n") + self.file.flush() + + def close(self) -> None: + """ + closes the file + """ + self.file.close() + + +class TensorBoardOutputFormat(KVWriter): + """ + Dumps key/value pairs into TensorBoard's numeric format. + + :param folder: the folder to write the log to + """ + + def __init__(self, folder: str): + assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so" + self.writer = SummaryWriter(log_dir=folder) + + def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], + step: int = 0) -> None: + for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())): + + if excluded is not None and "tensorboard" in excluded: + continue + + if isinstance(value, np.ScalarType): + if isinstance(value, str): + # str is considered a np.ScalarType + self.writer.add_text(key, value, step) + else: + self.writer.add_scalar(key, value, step) + + if isinstance(value, th.Tensor): + self.writer.add_histogram(key, value, step) + + if isinstance(value, Video): + self.writer.add_video(key, value.frames, step, value.fps) + + if isinstance(value, Figure): + self.writer.add_figure(key, value.figure, step, close=value.close) + + if isinstance(value, Image): + self.writer.add_image(key, value.image, step, dataformats=value.dataformats) + + # Flush the output to the file + self.writer.flush() + + def close(self) -> None: + """ + closes the file + """ + if self.writer: + self.writer.close() + self.writer = None + + +def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWriter: + """ + return a logger for the requested format + + :param _format: the requested format to log to ('stdout', 'log', 'json' or 'csv' or 'tensorboard') + :param log_dir: the logging directory + :param log_suffix: the suffix for the log file + :return: the logger + """ + os.makedirs(log_dir, exist_ok=True) + if _format == "stdout": + return HumanOutputFormat(sys.stdout) + elif _format == "log": + return HumanOutputFormat(os.path.join(log_dir, f"log{log_suffix}.txt")) + elif _format == "json": + return JSONOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.json")) + elif _format == "csv": + return CSVOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.csv")) + elif _format == "tensorboard": + return 
TensorBoardOutputFormat(log_dir) + else: + raise ValueError(f"Unknown format specified: {_format}") + + +# ================================================================ +# Backend +# ================================================================ + + +class Logger: + """ + The logger class. + + :param folder: the logging location + :param output_formats: the list of output formats + """ + + def __init__(self, folder: Optional[str], output_formats: List[KVWriter]): + self.name_to_value = defaultdict(float) # values this iteration + self.name_to_count = defaultdict(int) + self.name_to_excluded = defaultdict(str) + self.level = INFO + self.dir = folder + self.output_formats = output_formats + + def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: + """ + Log a value of some diagnostic + Call this once for each diagnostic quantity, each iteration + If called many times, last value will be used. + + :param key: save to log this key + :param value: save to log this value + :param exclude: outputs to be excluded + """ + self.name_to_value[key] = value + self.name_to_excluded[key] = exclude + + def record_mean(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: + """ + The same as record(), but if called many times, values averaged. + + :param key: save to log this key + :param value: save to log this value + :param exclude: outputs to be excluded + """ + if value is None: + self.name_to_value[key] = None + return + old_val, count = self.name_to_value[key], self.name_to_count[key] + self.name_to_value[key] = old_val * count / (count + 1) + value / (count + 1) + self.name_to_count[key] = count + 1 + self.name_to_excluded[key] = exclude + + def dump(self, step: int = 0) -> None: + """ + Write all of the diagnostics from the current iteration + """ + if self.level == DISABLED: + return + for _format in self.output_formats: + if isinstance(_format, KVWriter): + _format.write(self.name_to_value, self.name_to_excluded, step) + + self.name_to_value.clear() + self.name_to_count.clear() + self.name_to_excluded.clear() + + def log(self, *args, level: int = INFO) -> None: + """ + Write the sequence of args, with no separators, + to the console and output files (if you've configured an output file). + + level: int. (see logger.py docs) If the global logger level is higher than + the level argument here, don't print to stdout. + + :param args: log the arguments + :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) + """ + if self.level <= level: + self._do_log(args) + + def debug(self, *args) -> None: + """ + Write the sequence of args, with no separators, + to the console and output files (if you've configured an output file). + Using the DEBUG level. + + :param args: log the arguments + """ + self.log(*args, level=DEBUG) + + def info(self, *args) -> None: + """ + Write the sequence of args, with no separators, + to the console and output files (if you've configured an output file). + Using the INFO level. + + :param args: log the arguments + """ + self.log(*args, level=INFO) + + def warn(self, *args) -> None: + """ + Write the sequence of args, with no separators, + to the console and output files (if you've configured an output file). + Using the WARN level. 
+ + :param args: log the arguments + """ + self.log(*args, level=WARN) + + def error(self, *args) -> None: + """ + Write the sequence of args, with no separators, + to the console and output files (if you've configured an output file). + Using the ERROR level. + + :param args: log the arguments + """ + self.log(*args, level=ERROR) + + # Configuration + # ---------------------------------------- + def set_level(self, level: int) -> None: + """ + Set logging threshold on current logger. + + :param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) + """ + self.level = level + + def get_dir(self) -> str: + """ + Get directory that log files are being written to. + will be None if there is no output directory (i.e., if you didn't call start) + + :return: the logging directory + """ + return self.dir + + def close(self) -> None: + """ + closes the file + """ + for _format in self.output_formats: + _format.close() + + # Misc + # ---------------------------------------- + def _do_log(self, args) -> None: + """ + log to the requested format outputs + + :param args: the arguments to log + """ + for _format in self.output_formats: + if isinstance(_format, SeqWriter): + _format.write_sequence(map(str, args)) + + +def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] = None) -> Logger: + """ + Configure the current logger. + + :param folder: the save location + (if None, $SB3_LOGDIR, if still None, tempdir/SB3-[date & time]) + :param format_strings: the output logging format + (if None, $SB3_LOG_FORMAT, if still None, ['stdout', 'log', 'csv']) + :return: The logger object. + """ + if folder is None: + folder = os.getenv("SB3_LOGDIR") + if folder is None: + folder = os.path.join(tempfile.gettempdir(), datetime.datetime.now().strftime("SB3-%Y-%m-%d-%H-%M-%S-%f")) + assert isinstance(folder, str) + os.makedirs(folder, exist_ok=True) + + log_suffix = "" + if format_strings is None: + format_strings = os.getenv("SB3_LOG_FORMAT", "stdout,log,csv").split(",") + + format_strings = list(filter(None, format_strings)) + output_formats = [make_output_format(f, folder, log_suffix) for f in format_strings] + + logger = Logger(folder=folder, output_formats=output_formats) + # Only print when some files will be saved + if len(format_strings) > 0 and format_strings != ["stdout"]: + logger.log(f"Logging to {folder}") + return logger + + +# ================================================================ +# Readers +# ================================================================ + + +def read_json(filename: str) -> pandas.DataFrame: + """ + read a json file using pandas + + :param filename: the file path to read + :return: the data in the json + """ + data = [] + with open(filename) as file_handler: + for line in file_handler: + data.append(json.loads(line)) + return pandas.DataFrame(data) + + +def read_csv(filename: str) -> pandas.DataFrame: + """ + read a csv file using pandas + + :param filename: the file path to read + :return: the data in the csv + """ + return pandas.read_csv(filename, index_col=None, comment="#") diff --git a/dexart-release/stable_baselines3/common/monitor.py b/dexart-release/stable_baselines3/common/monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..a482b72be66f58f5f9776b74d7ce595b400ad6eb --- /dev/null +++ b/dexart-release/stable_baselines3/common/monitor.py @@ -0,0 +1,239 @@ +__all__ = ["Monitor", "ResultsWriter", "get_monitor_files", "load_results"] + +import csv +import json +import os +import time 
+from glob import glob +from typing import Dict, List, Optional, Tuple, Union + +import gym +import numpy as np +import pandas + +from stable_baselines3.common.type_aliases import GymObs, GymStepReturn + + +class Monitor(gym.Wrapper): + """ + A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data. + + :param env: The environment + :param filename: the location to save a log file, can be None for no log + :param allow_early_resets: allows the reset of the environment before it is done + :param reset_keywords: extra keywords for the reset call, + if extra parameters are needed at reset + :param info_keywords: extra information to log, from the information return of env.step() + """ + + EXT = "monitor.csv" + + def __init__( + self, + env: gym.Env, + filename: Optional[str] = None, + allow_early_resets: bool = True, + reset_keywords: Tuple[str, ...] = (), + info_keywords: Tuple[str, ...] = (), + ): + super().__init__(env=env) + self.t_start = time.time() + if filename is not None: + self.results_writer = ResultsWriter( + filename, + header={"t_start": self.t_start, "env_id": env.spec and env.spec.id}, + extra_keys=reset_keywords + info_keywords, + ) + else: + self.results_writer = None + self.reset_keywords = reset_keywords + self.info_keywords = info_keywords + self.allow_early_resets = allow_early_resets + self.rewards = None + self.needs_reset = True + self.episode_returns = [] + self.episode_lengths = [] + self.episode_times = [] + self.total_steps = 0 + self.current_reset_info = {} # extra info about the current episode, that was passed in during reset() + + def reset(self, **kwargs) -> GymObs: + """ + Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True + + :param kwargs: Extra keywords saved for the next episode. only if defined by reset_keywords + :return: the first observation of the environment + """ + if not self.allow_early_resets and not self.needs_reset: + raise RuntimeError( + "Tried to reset an environment before done. 
If you want to allow early resets, " + "wrap your env with Monitor(env, path, allow_early_resets=True)" + ) + self.rewards = [] + self.needs_reset = False + for key in self.reset_keywords: + value = kwargs.get(key) + if value is None: + raise ValueError(f"Expected you to pass keyword argument {key} into reset") + self.current_reset_info[key] = value + return self.env.reset(**kwargs) + + def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: + """ + Step the environment with the given action + + :param action: the action + :return: observation, reward, done, information + """ + if self.needs_reset: + raise RuntimeError("Tried to step environment that needs reset") + observation, reward, done, info = self.env.step(action) + self.rewards.append(reward) + if done: + self.needs_reset = True + ep_rew = sum(self.rewards) + ep_len = len(self.rewards) + ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(time.time() - self.t_start, 6)} + for key in self.info_keywords: + ep_info[key] = info[key] + self.episode_returns.append(ep_rew) + self.episode_lengths.append(ep_len) + self.episode_times.append(time.time() - self.t_start) + ep_info.update(self.current_reset_info) + if self.results_writer: + self.results_writer.write_row(ep_info) + info["episode"] = ep_info + self.total_steps += 1 + return observation, reward, done, info + + def close(self) -> None: + """ + Closes the environment + """ + super().close() + if self.results_writer is not None: + self.results_writer.close() + + def get_total_steps(self) -> int: + """ + Returns the total number of timesteps + + :return: + """ + return self.total_steps + + def get_episode_rewards(self) -> List[float]: + """ + Returns the rewards of all the episodes + + :return: + """ + return self.episode_returns + + def get_episode_lengths(self) -> List[int]: + """ + Returns the number of timesteps of all the episodes + + :return: + """ + return self.episode_lengths + + def get_episode_times(self) -> List[float]: + """ + Returns the runtime in seconds of all the episodes + + :return: + """ + return self.episode_times + + +class LoadMonitorResultsError(Exception): + """ + Raised when loading the monitor log fails. + """ + + pass + + +class ResultsWriter: + """ + A result writer that saves the data from the `Monitor` class + + :param filename: the location to save a log file, can be None for no log + :param header: the header dictionary object of the saved csv + :param reset_keywords: the extra information to log, typically is composed of + ``reset_keywords`` and ``info_keywords`` + """ + + def __init__( + self, + filename: str = "", + header: Optional[Dict[str, Union[float, str]]] = None, + extra_keys: Tuple[str, ...] = (), + ): + if header is None: + header = {} + if not filename.endswith(Monitor.EXT): + if os.path.isdir(filename): + filename = os.path.join(filename, Monitor.EXT) + else: + filename = filename + "." 
+ Monitor.EXT
+        # Prevent newline issue on Windows, see GH issue #692
+        self.file_handler = open(filename, "wt", newline="\n")
+        self.file_handler.write("#%s\n" % json.dumps(header))
+        self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + extra_keys)
+        self.logger.writeheader()
+        self.file_handler.flush()
+
+    def write_row(self, epinfo: Dict[str, Union[float, int]]) -> None:
+        """
+        Write a row of episode statistics to the log file and flush it
+
+        :param epinfo: the information on episodic return, length, and time
+        """
+        if self.logger:
+            self.logger.writerow(epinfo)
+            self.file_handler.flush()
+
+    def close(self) -> None:
+        """
+        Close the file handler
+        """
+        self.file_handler.close()
+
+
+def get_monitor_files(path: str) -> List[str]:
+    """
+    Get all the monitor files in the given path
+
+    :param path: the logging folder
+    :return: the log files
+    """
+    return glob(os.path.join(path, "*" + Monitor.EXT))
+
+
+def load_results(path: str) -> pandas.DataFrame:
+    """
+    Load all Monitor logs from a given directory path matching ``*monitor.csv``
+
+    :param path: the directory path containing the log file(s)
+    :return: the logged data
+    """
+    monitor_files = get_monitor_files(path)
+    if len(monitor_files) == 0:
+        raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}")
+    data_frames, headers = [], []
+    for file_name in monitor_files:
+        with open(file_name) as file_handler:
+            first_line = file_handler.readline()
+            assert first_line[0] == "#"
+            header = json.loads(first_line[1:])
+            data_frame = pandas.read_csv(file_handler, index_col=None)
+            headers.append(header)
+            data_frame["t"] += header["t_start"]
+        data_frames.append(data_frame)
+    data_frame = pandas.concat(data_frames)
+    data_frame.sort_values("t", inplace=True)
+    data_frame.reset_index(inplace=True)
+    data_frame["t"] -= min(header["t_start"] for header in headers)
+    return data_frame
diff --git a/dexart-release/stable_baselines3/common/noise.py b/dexart-release/stable_baselines3/common/noise.py new file mode 100644 index 0000000000000000000000000000000000000000..119ed362ec47363fcfe91343c8d823581b925e17 --- /dev/null +++ b/dexart-release/stable_baselines3/common/noise.py @@ -0,0 +1,167 @@
+import copy
+from abc import ABC, abstractmethod
+from typing import Iterable, List, Optional
+
+import numpy as np
+
+
+class ActionNoise(ABC):
+    """
+    The action noise base class
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def reset(self) -> None:
+        """
+        Call end-of-episode reset for the noise
+        """
+        pass
+
+    @abstractmethod
+    def __call__(self) -> np.ndarray:
+        raise NotImplementedError()
+
+
+class NormalActionNoise(ActionNoise):
+    """
+    A Gaussian action noise
+
+    :param mean: the mean value of the noise
+    :param sigma: the scale of the noise (std here)
+    """
+
+    def __init__(self, mean: np.ndarray, sigma: np.ndarray):
+        self._mu = mean
+        self._sigma = sigma
+        super().__init__()
+
+    def __call__(self) -> np.ndarray:
+        return np.random.normal(self._mu, self._sigma)
+
+    def __repr__(self) -> str:
+        return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"
+
+
+class OrnsteinUhlenbeckActionNoise(ActionNoise):
+    """
+    An Ornstein Uhlenbeck action noise; it is designed to approximate Brownian motion with friction.
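+
+    The discretised update implemented in ``__call__`` below is
+
+        x_{t+1} = x_t + theta * (mu - x_t) * dt + sigma * sqrt(dt) * N(0, 1)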
+
+    Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
+
+    :param mean: the mean of the noise
+    :param sigma: the scale of the noise
+    :param theta: the rate of mean reversion
+    :param dt: the timestep for the noise
+    :param initial_noise: the initial value for the noise output, (if None: 0)
+    """
+
+    def __init__(
+        self,
+        mean: np.ndarray,
+        sigma: np.ndarray,
+        theta: float = 0.15,
+        dt: float = 1e-2,
+        initial_noise: Optional[np.ndarray] = None,
+    ):
+        self._theta = theta
+        self._mu = mean
+        self._sigma = sigma
+        self._dt = dt
+        self.initial_noise = initial_noise
+        self.noise_prev = np.zeros_like(self._mu)
+        self.reset()
+        super().__init__()
+
+    def __call__(self) -> np.ndarray:
+        noise = (
+            self.noise_prev
+            + self._theta * (self._mu - self.noise_prev) * self._dt
+            + self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
+        )
+        self.noise_prev = noise
+        return noise
+
+    def reset(self) -> None:
+        """
+        Reset the Ornstein Uhlenbeck noise to the initial position
+        """
+        self.noise_prev = self.initial_noise if self.initial_noise is not None else np.zeros_like(self._mu)
+
+    def __repr__(self) -> str:
+        return f"OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})"
+
+
+class VectorizedActionNoise(ActionNoise):
+    """
+    A Vectorized action noise for parallel environments.
+
+    :param base_noise: ActionNoise The noise generator to use
+    :param n_envs: The number of parallel environments
+    """
+
+    def __init__(self, base_noise: ActionNoise, n_envs: int):
+        try:
+            self.n_envs = int(n_envs)
+            assert self.n_envs > 0
+        except (TypeError, AssertionError):
+            raise ValueError(f"Expected n_envs={n_envs} to be a positive integer greater than 0")
+
+        self.base_noise = base_noise
+        self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)]
+
+    def reset(self, indices: Optional[Iterable[int]] = None) -> None:
+        """
+        Reset all the noise processes, or those listed in indices
+
+        :param indices: Optional[Iterable[int]] The indices to reset. Default: None.
+            If the parameter is None, then all processes are reset to their initial position.
+        """
+        if indices is None:
+            indices = range(len(self.noises))
+
+        for index in indices:
+            self.noises[index].reset()
+
+    def __repr__(self) -> str:
+        return f"VecNoise(BaseNoise={repr(self.base_noise)}, n_envs={len(self.noises)})"
+
+    def __call__(self) -> np.ndarray:
+        """
+        Generate and stack the action noise from each noise object
+        """
+        noise = np.stack([noise() for noise in self.noises])
+        return noise
+
+    @property
+    def base_noise(self) -> ActionNoise:
+        return self._base_noise
+
+    @base_noise.setter
+    def base_noise(self, base_noise: ActionNoise) -> None:
+        if base_noise is None:
+            raise ValueError("Expected base_noise to be an instance of ActionNoise, not None", ActionNoise)
+        if not isinstance(base_noise, ActionNoise):
+            raise TypeError("Expected base_noise to be an instance of type ActionNoise", ActionNoise)
+        self._base_noise = base_noise
+
+    @property
+    def noises(self) -> List[ActionNoise]:
+        return self._noises
+
+    @noises.setter
+    def noises(self, noises: List[ActionNoise]) -> None:
+        noises = list(noises)  # raises TypeError if not iterable
+        assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}."
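+        # All sub-noises must share the type of base_noise so that reset()
+        # and __call__() behave consistently across the vectorized copies.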
+ + different_types = [i for i, noise in enumerate(noises) if not isinstance(noise, type(self.base_noise))] + + if len(different_types): + raise ValueError( + f"Noise instances at indices {different_types} don't match the type of base_noise", type(self.base_noise) + ) + + self._noises = noises + for noise in noises: + noise.reset() diff --git a/dexart-release/stable_baselines3/common/on_policy_algorithm.py b/dexart-release/stable_baselines3/common/on_policy_algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..8b870fc2af6b607af01c04b6dad067fc0d67fbe3 --- /dev/null +++ b/dexart-release/stable_baselines3/common/on_policy_algorithm.py @@ -0,0 +1,320 @@ +import time +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import gym +import numpy as np +import torch as th + +from stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import obs_as_tensor, safe_mean +from stable_baselines3.common.vec_env import VecEnv + + +class OnPolicyAlgorithm(BaseAlgorithm): + """ + The base for On-Policy algorithms (ex: A2C/PPO). + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) + :param n_steps: The number of steps to run for each environment per update + (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel) + :param gamma: Discount factor + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator. + Equivalent to classic advantage when set to 1. + :param ent_coef: Entropy coefficient for the loss calculation + :param vf_coef: Value function coefficient for the loss calculation + :param max_grad_norm: The maximum value for the gradient clipping + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) + instead of action noise exploration (default: False) + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE + Default: -1 (only sample at the beginning of the rollout) + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param create_eval_env: Whether to create a second environment that will be + used for evaluating the agent periodically. (Only available when passing string for the environment) + :param monitor_wrapper: When creating an environment, whether to wrap it + or not in a Monitor wrapper. + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. + :param _init_setup_model: Whether or not to build the network at the creation of the instance + :param supported_action_spaces: The action spaces supported by the algorithm. 
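+
+    For example, with ``n_steps=2048`` and 8 parallel environments, each
+    policy update consumes a rollout of 2048 * 8 = 16384 transitions.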
+ """ + + def __init__( + self, + policy: Union[str, Type[ActorCriticPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule], + n_steps: int, + gamma: float, + gae_lambda: float, + ent_coef: float, + vf_coef: float, + max_grad_norm: float, + use_sde: bool, + sde_sample_freq: int, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + monitor_wrapper: bool = True, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None, + ): + + super().__init__( + policy=policy, + env=env, + learning_rate=learning_rate, + policy_kwargs=policy_kwargs, + verbose=verbose, + device=device, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + create_eval_env=create_eval_env, + support_multi_env=True, + seed=seed, + tensorboard_log=tensorboard_log, + supported_action_spaces=supported_action_spaces, + ) + + self.n_steps = n_steps + self.gamma = gamma + self.gae_lambda = gae_lambda + self.ent_coef = ent_coef + self.vf_coef = vf_coef + self.max_grad_norm = max_grad_norm + self.rollout_buffer = None + + self.last_rollout_reward = -np.inf + self.need_restore = False + self.last_policy_saved: List[Dict] = [{}, {}] + self.current_restore_step = 0 + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + self._setup_lr_schedule() + self.set_random_seed(self.seed) + + buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RolloutBuffer + + self.rollout_buffer = buffer_cls( + self.n_steps, + self.observation_space, + self.action_space, + device=self.device, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + n_envs=self.n_envs, + ) + self.policy = self.policy_class( # pytype:disable=not-instantiable + self.observation_space, + self.action_space, + self.lr_schedule, + use_sde=self.use_sde, + **self.policy_kwargs # pytype:disable=not-instantiable + ) + self.policy = self.policy.to(self.device) + + def collect_rollouts( + self, + env: VecEnv, + callback: BaseCallback, + rollout_buffer: RolloutBuffer, + n_rollout_steps: int, + ) -> bool: + """ + Collect experiences using the current policy and fill a ``RolloutBuffer``. + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + + :param env: The training environment + :param callback: Callback that will be called at each step + (and at the beginning and end of the rollout) + :param rollout_buffer: Buffer to fill with rollouts + :param n_steps: Number of experiences to collect per environment + :return: True if function returned with at least `n_rollout_steps` + collected, False if callback terminated rollout prematurely. 
+ """ + assert self._last_obs is not None, "No previous observation was provided" + # Switch to eval mode (this affects batch norm / dropout) + self.policy.set_training_mode(False) + last_episode_reward = self.last_rollout_reward + self.last_rollout_reward = 0 + num_rollouts = 0 + n_steps = 0 + rollout_buffer.reset() + # Sample new weights for the state dependent exploration + if self.use_sde: + self.policy.reset_noise(env.num_envs) + callback.on_rollout_start() + while n_steps < n_rollout_steps: + if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: + # Sample a new noise matrix + self.policy.reset_noise(env.num_envs) + + with th.no_grad(): + # Convert to pytorch tensor or to TensorDict + obs_tensor = obs_as_tensor(self._last_obs, self.device) + actions, values, log_probs = self.policy(obs_tensor) + actions = actions.cpu().numpy() + + # Rescale and perform action + clipped_actions = actions + # Clip the actions to avoid out of bound error + if isinstance(self.action_space, gym.spaces.Box): + clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + + new_obs, rewards, dones, infos = env.step(clipped_actions) + + self.num_timesteps += env.num_envs + + # Give access to local variables + callback.update_locals(locals()) + if callback.on_step() is False: + return False + + self._update_info_buffer(infos) + n_steps += 1 + + if isinstance(self.action_space, gym.spaces.Discrete): + # Reshape in case of discrete action + actions = actions.reshape(-1, 1) + + # Handle timeout by bootstraping with value function + # see GitHub issue #633 + for idx, done in enumerate(dones): + if ( + done + and infos[idx].get("terminal_observation") is not None + and infos[idx].get("TimeLimit.truncated", False) + ): + terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] + with th.no_grad(): + terminal_value = self.policy.predict_values(terminal_obs)[0] + rewards[idx] += self.gamma * terminal_value + + if done: + num_rollouts += 1 + + self.last_rollout_reward += rewards.sum() + + rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs) + self._last_obs = new_obs + self._last_episode_starts = dones + with th.no_grad(): + # Compute value for the last timestep + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) + + rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) + + self.last_rollout_reward /= num_rollouts + reward_gap = last_episode_reward - self.last_rollout_reward + self.need_restore = False + self.current_restore_step = 0 + + callback.on_rollout_end() + + return True + + def train(self) -> None: + """ + Consume current rollout data and update policy parameters. + Implemented by individual algorithms. 
+ """ + raise NotImplementedError + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "OnPolicyAlgorithm", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + iter_start=0, + ) -> "OnPolicyAlgorithm": + iteration = iter_start + + total_timesteps, callback = self._setup_learn( + total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, + tb_log_name + ) + + callback.on_training_start(locals(), globals()) + + while self.num_timesteps < total_timesteps: + + x = time.time() + continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, + n_rollout_steps=self.n_steps) + print("Rollout time:", time.time() - x) + + if continue_training is False: + break + + if self.need_restore and self.current_restore_step < 5: + print(f"Large performance drop detected. Restore previous model.") + self.set_parameters(self.last_policy_saved[0], exact_match=True, device=self.device) + continue + else: + self.last_policy_saved[0] = self.last_policy_saved[1] + self.last_policy_saved[1] = self.get_parameters() + + iteration += 1 + self._update_current_progress_remaining(self.num_timesteps, total_timesteps) + + # Display training infos + + if log_interval is not None and iteration % log_interval == 0: + fps = int((self.num_timesteps - self._num_timesteps_at_start) / (time.time() - self.start_time)) + self.logger.record("time/iterations", iteration, exclude="wandb") + self.logger.record("rollout/rollout_rew_mean", self.last_rollout_reward) + if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: + self.logger.record("rollout/ep_rew_mean", + safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) + self.logger.record("rollout/ep_len_mean", + safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) + self.logger.record("time/fps", fps) + self.logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard") + self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") + self.logger.dump(step=iteration) + + x = time.time() + self.train() + print("Train time:", time.time() - x) + + callback.on_training_end() + + return self + + def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: + state_dicts = ["policy", "policy.optimizer"] + + return state_dicts, [] + + def _excluded_save_params(self) -> List[str]: + """ + Returns the names of the parameters that should be excluded from being + saved by pickling. E.g. replay buffers are skipped by default + as they take up a lot of space. PyTorch variables should be excluded + with this so they can be stored with ``th.save``. + + :return: List of parameters that should be excluded from being saved with pickle. 
+ """ + return super()._excluded_save_params() + ["last_policy_saved"] diff --git a/dexart-release/stable_baselines3/common/policies.py b/dexart-release/stable_baselines3/common/policies.py new file mode 100644 index 0000000000000000000000000000000000000000..051adc2ed19eab9e91a3fac2c3162e2f6699d16e --- /dev/null +++ b/dexart-release/stable_baselines3/common/policies.py @@ -0,0 +1,898 @@ +"""Policies: abstract base class and concrete implementations.""" + +import collections +import copy +import warnings +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import gym +import numpy as np +import torch as th +from torch import nn + +from stable_baselines3.common.distributions import ( + BernoulliDistribution, + CategoricalDistribution, + DiagGaussianDistribution, + Distribution, + MultiCategoricalDistribution, + StateDependentNoiseDistribution, + make_proba_distribution, +) +from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, maybe_transpose, preprocess_obs +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + MlpExtractor, + NatureCNN, + create_mlp, +) +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor + + +class BaseModel(nn.Module, ABC): + """ + The base model object: makes predictions in response to observations. + + In the case of policies, the prediction is an action. In the case of critics, it is the + estimated value of the observation. + + :param observation_space: The observation space of the environment + :param action_space: The action space of the environment + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. 
+ :param features_extractor: Network to extract features + (a CNN when using images, a nn.Flatten() layer otherwise) + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor: Optional[nn.Module] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__() + + if optimizer_kwargs is None: + optimizer_kwargs = {} + + if features_extractor_kwargs is None: + features_extractor_kwargs = {} + + self.observation_space = observation_space + self.action_space = action_space + self.features_extractor = features_extractor + self.normalize_images = normalize_images + + self.optimizer_class = optimizer_class + self.optimizer_kwargs = optimizer_kwargs + self.optimizer = None # type: Optional[th.optim.Optimizer] + + self.features_extractor_class = features_extractor_class + self.features_extractor_kwargs = features_extractor_kwargs + + @abstractmethod + def forward(self, *args, **kwargs): + pass + + def _update_features_extractor( + self, + net_kwargs: Dict[str, Any], + features_extractor: Optional[BaseFeaturesExtractor] = None, + ) -> Dict[str, Any]: + """ + Update the network keyword arguments and create a new features extractor object if needed. + If a ``features_extractor`` object is passed, then it will be shared. + + :param net_kwargs: the base network keyword arguments, without the ones + related to features extractor + :param features_extractor: a features extractor object. + If None, a new object will be created. + :return: The updated keyword arguments + """ + net_kwargs = net_kwargs.copy() + if features_extractor is None: + # The features extractor is not shared, create a new one + features_extractor = self.make_features_extractor() + net_kwargs.update(dict(features_extractor=features_extractor, features_dim=features_extractor.features_dim)) + return net_kwargs + + def make_features_extractor(self) -> BaseFeaturesExtractor: + """Helper method to create a features extractor.""" + return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs) + + def extract_features(self, obs: th.Tensor) -> th.Tensor: + """ + Preprocess the observation if needed and extract features. + + :param obs: + :return: + """ + assert self.features_extractor is not None, "No features extractor was set" + preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) + return self.features_extractor(preprocessed_obs) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + """ + Get data that need to be saved in order to re-create the model when loading it from disk. + + :return: The dictionary to pass to the as kwargs constructor when reconstruction this model. 
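+
+        Illustrative round trip (``CustomModel`` is a hypothetical subclass)::
+
+            model = CustomModel(observation_space, action_space)
+            model.save("model.pth")                # stores state_dict + these kwargs
+            model = CustomModel.load("model.pth")  # rebuilt via cls(**data)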
+ """ + return dict( + observation_space=self.observation_space, + action_space=self.action_space, + # Passed to the constructor by child class + # squash_output=self.squash_output, + # features_extractor=self.features_extractor + normalize_images=self.normalize_images, + ) + + @property + def device(self) -> th.device: + """Infer which device this policy lives on by inspecting its parameters. + If it has no parameters, the 'cpu' device is used as a fallback. + + :return:""" + for param in self.parameters(): + return param.device + return get_device("cpu") + + def save(self, path: str) -> None: + """ + Save model to a given location. + + :param path: + """ + th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) + + @classmethod + def load(cls, path: str, device: Union[th.device, str] = "auto") -> "BaseModel": + """ + Load model from path. + + :param path: + :param device: Device on which the policy should be loaded. + :return: + """ + device = get_device(device) + saved_variables = th.load(path, map_location=device) + + # Allow to load policy saved with older version of SB3 + if "sde_net_arch" in saved_variables["data"]: + warnings.warn( + "sde_net_arch is deprecated, please downgrade to SB3 v1.2.0 if you need such parameter.", + DeprecationWarning, + ) + del saved_variables["data"]["sde_net_arch"] + + # Create policy object + model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable + # Load weights + model.load_state_dict(saved_variables["state_dict"]) + model.to(device) + return model + + def load_from_vector(self, vector: np.ndarray) -> None: + """ + Load parameters from a 1D vector. + + :param vector: + """ + th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters()) + + def parameters_to_vector(self) -> np.ndarray: + """ + Convert the parameters to a 1D vector. + + :return: + """ + return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy() + + def set_training_mode(self, mode: bool) -> None: + """ + Put the policy in either training or evaluation mode. + + This affects certain modules, such as batch normalisation and dropout. + + :param mode: if true, set to training mode, else set to evaluation mode + """ + self.train(mode) + + def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[th.Tensor, bool]: + """ + Convert an input observation to a PyTorch tensor that can be fed to a model. + Includes sugar-coating to handle different observations (e.g. normalizing images). 
+
+        :param observation: the input observation
+        :return: The observation as PyTorch tensor
+            and whether the observation is vectorized or not
+        """
+        vectorized_env = False
+        if isinstance(observation, dict):
+            # need to copy the dict as the dict in VecFrameStack will become a torch tensor
+            observation = copy.deepcopy(observation)
+            for key, obs in observation.items():
+                # Skip extra keys that are not part of the observation space.
+                # Note: ``Dict.contains()`` checks sample membership, not key
+                # presence, so a key lookup is the correct test here.
+                if key not in self.observation_space.spaces:
+                    continue
+                obs_space = self.observation_space.spaces[key]
+                if is_image_space(obs_space):
+                    obs_ = maybe_transpose(obs, obs_space)
+                else:
+                    obs_ = np.array(obs)
+                vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space)
+                # Add batch dimension if needed
+                observation[key] = obs_.reshape((-1,) + self.observation_space[key].shape)
+
+        elif is_image_space(self.observation_space):
+            # Handle the different cases for images
+            # as PyTorch uses the channel-first format
+            observation = maybe_transpose(observation, self.observation_space)
+
+        else:
+            observation = np.array(observation)
+
+        if not isinstance(observation, dict):
+            # Dict obs need to be handled separately
+            vectorized_env = is_vectorized_observation(observation, self.observation_space)
+            # Add batch dimension if needed
+            observation = observation.reshape((-1,) + self.observation_space.shape)
+
+        observation = obs_as_tensor(observation, self.device)
+        return observation, vectorized_env
+
+
+class BasePolicy(BaseModel):
+    """The base policy object.
+
+    Parameters are mostly the same as ``BaseModel``; additions are documented below.
+
+    :param args: positional arguments passed through to ``BaseModel``.
+    :param kwargs: keyword arguments passed through to ``BaseModel``.
+    :param squash_output: For continuous actions, whether the output is squashed
+        or not using a ``tanh()`` function.
+    """
+
+    def __init__(self, *args, squash_output: bool = False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._squash_output = squash_output
+
+    @staticmethod
+    def _dummy_schedule(progress_remaining: float) -> float:
+        """(float) Useful for pickling policy."""
+        del progress_remaining
+        return 0.0
+
+    @property
+    def squash_output(self) -> bool:
+        """(bool) Getter for squash_output."""
+        return self._squash_output
+
+    @staticmethod
+    def init_weights(module: nn.Module, gain: float = 1) -> None:
+        """
+        Orthogonal initialization (used in PPO and A2C)
+        """
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            nn.init.orthogonal_(module.weight, gain=gain)
+            if module.bias is not None:
+                module.bias.data.fill_(0.0)
+
+    @abstractmethod
+    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+        """
+        Get the action according to the policy for a given observation.
+
+        This is an abstract method: every concrete policy must implement it
+        (pure critics do not act and subclass ``BaseModel`` instead).
+
+        :param observation:
+        :param deterministic: Whether to use stochastic or deterministic actions
+        :return: Taken action according to the policy
+        """
+
+    def predict(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
+        deterministic: bool = False,
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        """
+        Get the policy action from an observation (and optional hidden state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
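+
+        Typical inference loop (sketch; ``env`` is any Gym environment whose
+        spaces match the policy)::
+
+            obs = env.reset()
+            done = False
+            while not done:
+                action, _state = policy.predict(obs, deterministic=True)
+                obs, reward, done, info = env.step(action)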
+ + :param observation: the input observation + :param state: The last hidden states (can be None, used in recurrent policies) + :param episode_start: The last masks (can be None, used in recurrent policies) + this correspond to beginning of episodes, + where the hidden states of the RNN must be reset. + :param deterministic: Whether or not to return deterministic actions. + :return: the model's action and the next hidden state + (used in recurrent policies) + """ + # TODO (GH/1): add support for RNN policies + # if state is None: + # state = self.initial_state + # if episode_start is None: + # episode_start = [False for _ in range(self.n_envs)] + # Switch to eval mode (this affects batch norm / dropout) + self.set_training_mode(False) + + observation, vectorized_env = self.obs_to_tensor(observation) + + with th.no_grad(): + actions = self._predict(observation, deterministic=deterministic) + # Convert to numpy + actions = actions.cpu().numpy() + + if isinstance(self.action_space, gym.spaces.Box): + if self.squash_output: + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = np.clip(actions, self.action_space.low, self.action_space.high) + + # Remove batch dimension if needed + if not vectorized_env: + actions = actions[0] + + return actions, state + + def scale_action(self, action: np.ndarray) -> np.ndarray: + """ + Rescale the action from [low, high] to [-1, 1] + (no need for symmetric action space) + + :param action: Action to scale + :return: Scaled action + """ + low, high = self.action_space.low, self.action_space.high + return 2.0 * ((action - low) / (high - low)) - 1.0 + + def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray: + """ + Rescale the action from [-1, 1] to [low, high] + (no need for symmetric action space) + + :param scaled_action: Action to un-scale + """ + low, high = self.action_space.low, self.action_space.high + return low + (0.5 * (scaled_action + 1.0) * (high - low)) + + +class ActorCriticPolicy(BasePolicy): + """ + Policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. 
+ :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + + if optimizer_kwargs is None: + optimizer_kwargs = {} + # Small values to avoid NaN in Adam optimizer + if optimizer_class == th.optim.Adam: + optimizer_kwargs["eps"] = 1e-5 + + super().__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=squash_output, + ) + + # Default network architecture, from stable-baselines + if net_arch is None: + if features_extractor_class == NatureCNN: + net_arch = [] + else: + net_arch = [dict(pi=[64, 64], vf=[64, 64])] + + self.net_arch = net_arch + self.activation_fn = activation_fn + self.ortho_init = ortho_init + + self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) + self.features_dim = self.features_extractor.features_dim + + self.normalize_images = normalize_images + self.log_std_init = log_std_init + dist_kwargs = None + # Keyword arguments for gSDE distribution + if use_sde: + dist_kwargs = { + "full_std": full_std, + "squash_output": squash_output, + "use_expln": use_expln, + "learn_features": False, + } + + if sde_net_arch is not None: + warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning) + + self.use_sde = use_sde + self.dist_kwargs = dist_kwargs + + # Action distribution + self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs) + + self._build(lr_schedule) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() + + default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) + + data.update( + dict( + net_arch=self.net_arch, + activation_fn=self.activation_fn, + use_sde=self.use_sde, + log_std_init=self.log_std_init, + squash_output=default_none_kwargs["squash_output"], + full_std=default_none_kwargs["full_std"], + use_expln=default_none_kwargs["use_expln"], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + ortho_init=self.ortho_init, + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) + return data + + def reset_noise(self, n_envs: int = 1) -> None: + """ + Sample new weights for the 
exploration matrix. + + :param n_envs: + """ + assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE" + self.action_dist.sample_weights(self.log_std, batch_size=n_envs) + + def _build_mlp_extractor(self) -> None: + """ + Create the policy and value networks. + Part of the layers can be shared. + """ + # Note: If net_arch is None and some features extractor is used, + # net_arch here is an empty list and mlp_extractor does not + # really contain any layers (acts like an identity module). + self.mlp_extractor = MlpExtractor( + self.features_dim, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + device=self.device, + ) + + def _build(self, lr_schedule: Schedule) -> None: + """ + Create the networks and the optimizer. + + :param lr_schedule: Learning rate schedule + lr_schedule(1) is the initial learning rate + """ + self._build_mlp_extractor() + + latent_dim_pi = self.mlp_extractor.latent_dim_pi + + if isinstance(self.action_dist, DiagGaussianDistribution): + self.action_net, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi, log_std_init=self.log_std_init + ) + elif isinstance(self.action_dist, StateDependentNoiseDistribution): + self.action_net, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init + ) + elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)): + self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + else: + raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.") + + self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1) + # Init weights: use orthogonal initialization + # with small initial weight for the output + if self.ortho_init: # TODO: Helin + # TODO: check for features_extractor + # Values from stable-baselines. + # features_extractor/mlp values are + # originally from openai/baselines (default gains/init_scales). + module_gains = { + self.features_extractor: np.sqrt(2), + self.mlp_extractor: np.sqrt(2), + self.action_net: 0.01, + self.value_net: 1, + } + for module, gain in module_gains.items(): + module.apply(partial(self.init_weights, gain=gain)) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class(filter(lambda p: p.requires_grad, self.parameters()), lr=lr_schedule(1), **self.optimizer_kwargs) + + def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Forward pass in all the networks (actor and critic) + + :param obs: Observation + :param deterministic: Whether to sample or use deterministic actions + :return: action, value and log probability of the action + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + latent_pi, latent_vf = self.mlp_extractor(features) + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + return actions, values, log_prob + + def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution: + """ + Retrieve action distribution given the latent codes. 
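+        The branch below mirrors ``make_proba_distribution``: ``Box`` maps to
+        ``DiagGaussianDistribution`` (or ``StateDependentNoiseDistribution`` with
+        gSDE), ``Discrete`` to ``CategoricalDistribution``, ``MultiDiscrete`` to
+        ``MultiCategoricalDistribution``, and ``MultiBinary`` to ``BernoulliDistribution``.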
+ + :param latent_pi: Latent code for the actor + :return: Action distribution + """ + mean_actions = self.action_net(latent_pi) + + if isinstance(self.action_dist, DiagGaussianDistribution): + return self.action_dist.proba_distribution(mean_actions, self.log_std) + elif isinstance(self.action_dist, CategoricalDistribution): + # Here mean_actions are the logits before the softmax + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, MultiCategoricalDistribution): + # Here mean_actions are the flattened logits + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, BernoulliDistribution): + # Here mean_actions are the logits (before rounding to get the binary actions) + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, StateDependentNoiseDistribution): + return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi) + else: + raise ValueError("Invalid action distribution") + + def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + """ + Get the action according to the policy for a given observation. + + :param observation: + :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy + """ + return self.get_distribution(observation).get_actions(deterministic=deterministic) + + def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: + :param actions: + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + latent_pi, latent_vf = self.mlp_extractor(features) + distribution = self._get_action_dist_from_latent(latent_pi) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_vf) + return values, log_prob, distribution.entropy() + + def get_distribution(self, obs: th.Tensor) -> Distribution: + """ + Get the current policy distribution given the observations. + + :param obs: + :return: the action distribution. + """ + features = self.extract_features(obs) + latent_pi = self.mlp_extractor.forward_actor(features) + return self._get_action_dist_from_latent(latent_pi) + + def predict_values(self, obs: th.Tensor) -> th.Tensor: + """ + Get the estimated values according to the current policy given the observations. + + :param obs: + :return: the estimated values. + """ + features = self.extract_features(obs) + latent_vf = self.mlp_extractor.forward_critic(features) + return self.value_net(latent_vf) + + +class ActorCriticCnnPolicy(ActorCriticPolicy): + """ + CNN policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. 
+ :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + sde_net_arch, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + +class MultiInputActorCriticPolicy(ActorCriticPolicy): + """ + MultiInputActorClass policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space (Tuple) + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. 
+ :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Uses the CombinedExtractor + :param features_extractor_kwargs: Keyword arguments + to pass to the feature extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Dict, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + sde_net_arch, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + +class ContinuousCritic(BaseModel): + """ + Critic network(s) for DDPG/SAC/TD3. + It represents the action-state value function (Q-value function). + Compared to A2C/PPO critics, this one represents the Q-value + and takes the continuous action as input. It is concatenated with the state + and then fed to the network which outputs a single value: Q(s, a). + For more recent algorithms like SAC/TD3, multiple networks + are created to give different estimates. + + By default, it creates two critic networks used to reduce overestimation + thanks to clipped Q-learning (cf TD3 paper). + + :param observation_space: Obervation space + :param action_space: Action space + :param net_arch: Network architecture + :param features_extractor: Network to extract features + (a CNN when using images, a nn.Flatten() layer otherwise) + :param features_dim: Number of features + :param activation_fn: Activation function + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param n_critics: Number of critic networks to create. 
+ :param share_features_extractor: Whether the features extractor is shared or not + between the actor and the critic (this saves computation time) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + net_arch: List[int], + features_extractor: nn.Module, + features_dim: int, + activation_fn: Type[nn.Module] = nn.ReLU, + normalize_images: bool = True, + n_critics: int = 2, + share_features_extractor: bool = True, + ): + super().__init__( + observation_space, + action_space, + features_extractor=features_extractor, + normalize_images=normalize_images, + ) + + action_dim = get_action_dim(self.action_space) + + self.share_features_extractor = share_features_extractor + self.n_critics = n_critics + self.q_networks = [] + for idx in range(n_critics): + q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) + q_net = nn.Sequential(*q_net) + self.add_module(f"qf{idx}", q_net) + self.q_networks.append(q_net) + + def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]: + # Learn the features extractor using the policy loss only + # when the features_extractor is shared with the actor + with th.set_grad_enabled(not self.share_features_extractor): + features = self.extract_features(obs) + qvalue_input = th.cat([features, actions], dim=1) + return tuple(q_net(qvalue_input) for q_net in self.q_networks) + + def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor: + """ + Only predict the Q-value using the first network. + This allows to reduce computation when all the estimates are not needed + (e.g. when updating the policy in TD3). + """ + with th.no_grad(): + features = self.extract_features(obs) + return self.q_networks[0](th.cat([features, actions], dim=1)) diff --git a/dexart-release/stable_baselines3/common/preprocessing.py b/dexart-release/stable_baselines3/common/preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a2f9466fc26ce3b288bc2fac275aeb3d6b336c --- /dev/null +++ b/dexart-release/stable_baselines3/common/preprocessing.py @@ -0,0 +1,217 @@ +import warnings +from typing import Dict, Tuple, Union + +import numpy as np +import torch as th +from gym import spaces +from torch.nn import functional as F + + +def is_image_space_channels_first(observation_space: spaces.Box) -> bool: + """ + Check if an image observation space (see ``is_image_space``) + is channels-first (CxHxW, True) or channels-last (HxWxC, False). + + Use a heuristic that channel dimension is the smallest of the three. + If second dimension is smallest, raise an exception (no support). + + :param observation_space: + :return: True if observation space is channels-first image, False if channels-last. + """ + smallest_dimension = np.argmin(observation_space.shape).item() + if smallest_dimension == 1: + warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.") + return smallest_dimension == 0 + + +def is_image_space( + observation_space: spaces.Space, + check_channels: bool = False, +) -> bool: + """ + Check if a observation space has the shape, limits and dtype + of a valid image. + The check is conservative, so that it returns False if there is a doubt. + + Valid images: RGB, RGBD, GrayScale with values in [0, 255] + + :param observation_space: + :param check_channels: Whether to do or not the check for the number of channels. + e.g., with frame-stacking, the observation space may have more channels than expected. 
+
+    :return:
+    """
+    if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3:
+        # Check the type
+        if observation_space.dtype != np.uint8:
+            return False
+
+        # Check the value range
+        if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
+            return False
+
+        # Skip channels check
+        if not check_channels:
+            return True
+        # Check the number of channels
+        if is_image_space_channels_first(observation_space):
+            n_channels = observation_space.shape[0]
+        else:
+            n_channels = observation_space.shape[-1]
+        # RGB, RGBD, GrayScale
+        return n_channels in [1, 3, 4]
+    return False
+
+
+def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> np.ndarray:
+    """
+    Handle the different cases for images, as PyTorch uses the channel-first format.
+
+    :param observation:
+    :param observation_space:
+    :return: channel-first observation if observation is an image
+    """
+    # Avoid circular import
+    from stable_baselines3.common.vec_env import VecTransposeImage
+
+    if is_image_space(observation_space):
+        if not (observation.shape == observation_space.shape or observation.shape[1:] == observation_space.shape):
+            # Try to re-order the channels
+            transpose_obs = VecTransposeImage.transpose_image(observation)
+            if transpose_obs.shape == observation_space.shape or transpose_obs.shape[1:] == observation_space.shape:
+                observation = transpose_obs
+    return observation
+
+
+def preprocess_obs(
+    obs: th.Tensor,
+    observation_space: spaces.Space,
+    normalize_images: bool = True
+) -> Union[th.Tensor, Dict[str, th.Tensor]]:
+    """
+    Preprocess observation to be fed to a neural network.
+    For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]).
+    For discrete observations, it creates a one-hot vector.
+
+    :param obs: Observation
+    :param observation_space:
+    :param normalize_images: Whether to normalize images or not
+        (True by default)
+    :return:
+    """
+    if isinstance(observation_space, spaces.Box):
+        if is_image_space(observation_space) and normalize_images:
+            return obs.float() / 255.0
+        return obs.float()
+
+    elif isinstance(observation_space, spaces.Discrete):
+        # One hot encoding and convert to float to avoid errors
+        return F.one_hot(obs.long(), num_classes=observation_space.n).float()
+
+    elif isinstance(observation_space, spaces.MultiDiscrete):
+        # Tensor concatenation of one hot encodings of each Categorical sub-space
+        return th.cat(
+            [
+                F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float()
+                for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1))
+            ],
+            dim=-1,
+        ).view(obs.shape[0], sum(observation_space.nvec))
+
+    elif isinstance(observation_space, spaces.MultiBinary):
+        return obs.float()
+
+    elif isinstance(observation_space, spaces.Dict):
+        # Do not modify by reference the original observation
+        preprocessed_obs = {}
+        for key, _obs in obs.items():
+            # Silently skip keys that are not part of the observation space
+            if key in observation_space.spaces:
+                preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
+        return preprocessed_obs
+
+    else:
+        raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")
+
+
+def get_obs_shape(
+    observation_space: spaces.Space,
+) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]:
+    """
+    Get the shape of the observation (useful for the buffers).
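+    For example::
+
+        Box(low=0, high=255, shape=(84, 84, 3))  ->  (84, 84, 3)
+        Discrete(4)                              ->  (1,)
+        MultiDiscrete([3, 5])                    ->  (2,)
+        MultiBinary(6)                           ->  (6,)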
+ + :param observation_space: + :return: + """ + if isinstance(observation_space, spaces.Box): + return observation_space.shape + elif isinstance(observation_space, spaces.Discrete): + # Observation is an int + return (1,) + elif isinstance(observation_space, spaces.MultiDiscrete): + # Number of discrete features + return (int(len(observation_space.nvec)),) + elif isinstance(observation_space, spaces.MultiBinary): + # Number of binary features + return (int(observation_space.n),) + elif isinstance(observation_space, spaces.Dict): + return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} + + else: + raise NotImplementedError(f"{observation_space} observation space is not supported") + + +def get_flattened_obs_dim(observation_space: spaces.Space) -> int: + """ + Get the dimension of the observation space when flattened. + It does not apply to image observation space. + + Used by the ``FlattenExtractor`` to compute the input shape. + + :param observation_space: + :return: + """ + # See issue https://github.com/openai/gym/issues/1915 + # it may be a problem for Dict/Tuple spaces too... + if isinstance(observation_space, spaces.MultiDiscrete): + return sum(observation_space.nvec) + else: + # Use Gym internal method + return spaces.utils.flatdim(observation_space) + + +def get_action_dim(action_space: spaces.Space) -> int: + """ + Get the dimension of the action space. + + :param action_space: + :return: + """ + if isinstance(action_space, spaces.Box): + return int(np.prod(action_space.shape)) + elif isinstance(action_space, spaces.Discrete): + # Action is an int + return 1 + elif isinstance(action_space, spaces.MultiDiscrete): + # Number of discrete actions + return int(len(action_space.nvec)) + elif isinstance(action_space, spaces.MultiBinary): + # Number of binary actions + return int(action_space.n) + else: + raise NotImplementedError(f"{action_space} action space is not supported") + + +def check_for_nested_spaces(obs_space: spaces.Space): + """ + Make sure the observation space does not have nested spaces (Dicts/Tuples inside Dicts/Tuples). + If so, raise an Exception informing that there is no support for this. + + :param obs_space: an observation space + :return: + """ + if isinstance(obs_space, (spaces.Dict, spaces.Tuple)): + sub_spaces = obs_space.spaces.values() if isinstance(obs_space, spaces.Dict) else obs_space.spaces + for sub_space in sub_spaces: + if isinstance(sub_space, (spaces.Dict, spaces.Tuple)): + raise NotImplementedError( + "Nested observation spaces are not supported (Tuple/Dict space inside Tuple/Dict space)." + ) diff --git a/dexart-release/stable_baselines3/common/running_mean_std.py b/dexart-release/stable_baselines3/common/running_mean_std.py new file mode 100644 index 0000000000000000000000000000000000000000..b48f9223c9a1a6814ab5db0fe416e30522374797 --- /dev/null +++ b/dexart-release/stable_baselines3/common/running_mean_std.py @@ -0,0 +1,57 @@ +from typing import Tuple, Union + +import numpy as np + + +class RunningMeanStd: + def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] 
= ()): + """ + Calulates the running mean and std of a data stream + https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + + :param epsilon: helps with arithmetic issues + :param shape: the shape of the data stream's output + """ + self.mean = np.zeros(shape, np.float64) + self.var = np.ones(shape, np.float64) + self.count = epsilon + + def copy(self) -> "RunningMeanStd": + """ + :return: Return a copy of the current object. + """ + new_object = RunningMeanStd(shape=self.mean.shape) + new_object.mean = self.mean.copy() + new_object.var = self.var.copy() + new_object.count = float(self.count) + return new_object + + def combine(self, other: "RunningMeanStd") -> None: + """ + Combine stats from another ``RunningMeanStd`` object. + + :param other: The other object to combine with. + """ + self.update_from_moments(other.mean, other.var, other.count) + + def update(self, arr: np.ndarray) -> None: + batch_mean = np.mean(arr, axis=0) + batch_var = np.var(arr, axis=0) + batch_count = arr.shape[0] + self.update_from_moments(batch_mean, batch_var, batch_count) + + def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: Union[int, float]) -> None: + delta = batch_mean - self.mean + tot_count = self.count + batch_count + + new_mean = self.mean + delta * batch_count / tot_count + m_a = self.var * self.count + m_b = batch_var * batch_count + m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) + new_var = m_2 / (self.count + batch_count) + + new_count = batch_count + self.count + + self.mean = new_mean + self.var = new_var + self.count = new_count diff --git a/dexart-release/stable_baselines3/common/save_util.py b/dexart-release/stable_baselines3/common/save_util.py new file mode 100644 index 0000000000000000000000000000000000000000..a3e3acd0b84d9427ab1f4cf5ffe0b58ea3885f02 --- /dev/null +++ b/dexart-release/stable_baselines3/common/save_util.py @@ -0,0 +1,447 @@ +""" +Save util taken from stable_baselines +used to serialize data (class parameters) of model classes +""" +import base64 +import functools +import io +import json +import os +import pathlib +import pickle +import warnings +import zipfile +from typing import Any, Dict, Optional, Tuple, Union + +import cloudpickle +import torch as th + +import stable_baselines3 as sb3 +from stable_baselines3.common.type_aliases import TensorDict +from stable_baselines3.common.utils import get_device, get_system_info + + +def recursive_getattr(obj: Any, attr: str, *args) -> Any: + """ + Recursive version of getattr + taken from https://stackoverflow.com/questions/31174295 + + Ex: + > MyObject.sub_object = SubObject(name='test') + > recursive_getattr(MyObject, 'sub_object.name') # return test + :param obj: + :param attr: Attribute to retrieve + :return: The attribute + """ + + def _getattr(obj: Any, attr: str) -> Any: + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split(".")) + + +def recursive_setattr(obj: Any, attr: str, val: Any) -> None: + """ + Recursive version of setattr + taken from https://stackoverflow.com/questions/31174295 + + Ex: + > MyObject.sub_object = SubObject(name='test') + > recursive_setattr(MyObject, 'sub_object.name', 'hello') + :param obj: + :param attr: Attribute to set + :param val: New value of the attribute + """ + pre, _, post = attr.rpartition(".") + return setattr(recursive_getattr(obj, pre) if pre else obj, post, val) + + +def is_json_serializable(item: Any) -> bool: + """ + Test if an 
object is serializable into JSON + + :param item: The object to be tested for JSON serialization. + :return: True if object is JSON serializable, false otherwise. + """ + # Try with try-except struct. + json_serializable = True + try: + _ = json.dumps(item) + except TypeError: + json_serializable = False + return json_serializable + + +def data_to_json(data: Dict[str, Any]) -> str: + """ + Turn data (class parameters) into a JSON string for storing + + :param data: Dictionary of class parameters to be + stored. Items that are not JSON serializable will be + pickled with Cloudpickle and stored as bytearray in + the JSON file + :return: JSON string of the data serialized. + """ + # First, check what elements can not be JSONfied, + # and turn them into byte-strings + serializable_data = {} + for data_key, data_item in data.items(): + # See if object is JSON serializable + if is_json_serializable(data_item): + # All good, store as it is + serializable_data[data_key] = data_item + else: + # Not serializable, cloudpickle it into + # bytes and convert to base64 string for storing. + # Also store type of the class for consumption + # from other languages/humans, so we have an + # idea what was being stored. + base64_encoded = base64.b64encode(cloudpickle.dumps(data_item)).decode() + + # Use ":" to make sure we do + # not override these keys + # when we include variables of the object later + cloudpickle_serialization = { + ":type:": str(type(data_item)), + ":serialized:": base64_encoded, + } + + # Add first-level JSON-serializable items of the + # object for further details (but not deeper than this to + # avoid deep nesting). + # First we check that object has attributes (not all do, + # e.g. numpy scalars) + if hasattr(data_item, "__dict__") or isinstance(data_item, dict): + # Take elements from __dict__ for custom classes + item_generator = data_item.items if isinstance(data_item, dict) else data_item.__dict__.items + for variable_name, variable_item in item_generator(): + # Check if serializable. If not, just include the + # string-representation of the object. + if is_json_serializable(variable_item): + cloudpickle_serialization[variable_name] = variable_item + else: + cloudpickle_serialization[variable_name] = str(variable_item) + + serializable_data[data_key] = cloudpickle_serialization + json_string = json.dumps(serializable_data, indent=4) + return json_string + + +def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """ + Turn JSON serialization of class-parameters back into dictionary. + + :param json_string: JSON serialization of the class-parameters + that should be loaded. + :param custom_objects: Dictionary of objects to replace + upon loading. If a variable is present in this dictionary as a + key, it will not be deserialized and the corresponding item + will be used instead. Similar to custom_objects in + ``keras.models.load_model``. Useful when you have an object in + file that can not be deserialized. + :return: Loaded class parameters. 
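+
+    Illustrative usage (``my_schedule`` is a placeholder callable)::
+
+        json_string = data_to_json({"gamma": 0.99, "lr_schedule": my_schedule})
+        data = json_to_data(json_string)  # callable restored via cloudpickle
+        # Or bypass deserialization and substitute your own object:
+        data = json_to_data(json_string, custom_objects={"lr_schedule": my_schedule})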
+ """ + if custom_objects is not None and not isinstance(custom_objects, dict): + raise ValueError("custom_objects argument must be a dict or None") + + json_dict = json.loads(json_string) + # This will be filled with deserialized data + return_data = {} + for data_key, data_item in json_dict.items(): + if custom_objects is not None and data_key in custom_objects.keys(): + # If item is provided in custom_objects, replace + # the one from JSON with the one in custom_objects + return_data[data_key] = custom_objects[data_key] + elif isinstance(data_item, dict) and ":serialized:" in data_item.keys(): + # If item is dictionary with ":serialized:" + # key, this means it is serialized with cloudpickle. + serialization = data_item[":serialized:"] + # Try-except deserialization in case we run into + # errors. If so, we can tell bit more information to + # user. + try: + base64_object = base64.b64decode(serialization.encode()) + deserialized_object = cloudpickle.loads(base64_object) + except (RuntimeError, TypeError): + warnings.warn( + f"Could not deserialize object {data_key}. " + + "Consider using `custom_objects` argument to replace " + + "this object." + ) + return_data[data_key] = deserialized_object + else: + # Read as it is + return_data[data_key] = data_item + return return_data + + +@functools.singledispatch +def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose: int = 0, suffix: Optional[str] = None): + """ + Opens a path for reading or writing with a preferred suffix and raises debug information. + If the provided path is a derivative of io.BufferedIOBase it ensures that the file + matches the provided mode, i.e. If the mode is read ("r", "read") it checks that the path is readable. + If the mode is write ("w", "write") it checks that the file is writable. + + If the provided path is a string or a pathlib.Path, it ensures that it exists. If the mode is "read" + it checks that it exists, if it doesn't exist it attempts to read path.suffix if a suffix is provided. + If the mode is "write" and the path does not exist, it creates all the parent folders. If the path + points to a folder, it changes the path to path_2. If the path already exists and verbose == 2, + it raises a warning. + + :param path: the path to open. + if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the + path actually exists. If path is a io.BufferedIOBase the path exists. + :param mode: how to open the file. "w"|"write" for writing, "r"|"read" for reading. + :param verbose: Verbosity level, 0 means only warnings, 2 means debug information. + :param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix. + If mode is "r" then we attempt to open the path. If an error is raised and the suffix + is not None, we attempt to open the path with the suffix. 
+ :return: + """ + if not isinstance(path, io.BufferedIOBase): + raise TypeError("Path parameter has invalid type.", io.BufferedIOBase) + if path.closed: + raise ValueError("File stream is closed.") + mode = mode.lower() + try: + mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode] + except KeyError: + raise ValueError("Expected mode to be either 'w' or 'r'.") + if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable(): + e1 = "writable" if "w" == mode else "readable" + raise ValueError(f"Expected a {e1} file.") + return path + + +@open_path.register(str) +def open_path_str(path: str, mode: str, verbose: int = 0, suffix: Optional[str] = None) -> io.BufferedIOBase: + """ + Open a path given by a string. If writing to the path, the function ensures + that the path exists. + + :param path: the path to open. If mode is "w" then it ensures that the path exists + by creating the necessary folders and renaming path if it points to a folder. + :param mode: how to open the file. "w" for writing, "r" for reading. + :param verbose: Verbosity level, 0 means only warnings, 2 means debug information. + :param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix. + If mode is "r" then we attempt to open the path. If an error is raised and the suffix + is not None, we attempt to open the path with the suffix. + :return: + """ + return open_path(pathlib.Path(path), mode, verbose, suffix) + + +@open_path.register(pathlib.Path) +def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: Optional[str] = None) -> io.BufferedIOBase: + """ + Open a path given by a string. If writing to the path, the function ensures + that the path exists. + + :param path: the path to check. If mode is "w" then it + ensures that the path exists by creating the necessary folders and + renaming path if it points to a folder. + :param mode: how to open the file. "w" for writing, "r" for reading. + :param verbose: Verbosity level, 0 means only warnings, 2 means debug information. + :param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix. + If mode is "r" then we attempt to open the path. If an error is raised and the suffix + is not None, we attempt to open the path with the suffix. + :return: + """ + if mode not in ("w", "r"): + raise ValueError("Expected mode to be either 'w' or 'r'.") + + if mode == "r": + try: + path = path.open("rb") + except FileNotFoundError as error: + if suffix is not None and suffix != "": + newpath = pathlib.Path(f"{path}.{suffix}") + if verbose == 2: + warnings.warn(f"Path '{path}' not found. Attempting {newpath}.") + path, suffix = newpath, None + else: + raise error + else: + try: + if path.suffix == "" and suffix is not None and suffix != "": + path = pathlib.Path(f"{path}.{suffix}") + if path.exists() and path.is_file() and verbose == 2: + warnings.warn(f"Path '{path}' exists, will overwrite it.") + path = path.open("wb") + except IsADirectoryError: + warnings.warn(f"Path '{path}' is a folder. Will save instead to {path}_2") + path = pathlib.Path(f"{path}_2") + except FileNotFoundError: # Occurs when the parent folder doesn't exist + warnings.warn(f"Path '{path.parent}' does not exist. 
Will create it.") + path.parent.mkdir(exist_ok=True, parents=True) + + # if opening was successful uses the identity function + # if opening failed with IsADirectory|FileNotFound, calls open_path_pathlib + # with corrections + # if reading failed with FileNotFoundError, calls open_path_pathlib with suffix + + return open_path(path, mode, verbose, suffix) + + +def save_to_zip_file( + save_path: Union[str, pathlib.Path, io.BufferedIOBase], + data: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, + pytorch_variables: Optional[Dict[str, Any]] = None, + verbose: int = 0, +) -> None: + """ + Save model data to a zip archive. + + :param save_path: Where to store the model. + if save_path is a str or pathlib.Path ensures that the path actually exists. + :param data: Class parameters being stored (non-PyTorch variables) + :param params: Model parameters being stored expected to contain an entry for every + state_dict with its name and the state_dict. + :param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable. + :param verbose: Verbosity level, 0 means only warnings, 2 means debug information + """ + save_path = open_path(save_path, "w", verbose=0, suffix="zip") + # data/params can be None, so do not + # try to serialize them blindly + if data is not None: + serialized_data = data_to_json(data) + + # Create a zip-archive and write our objects there. + with zipfile.ZipFile(save_path, mode="w") as archive: + # Do not try to save "None" elements + if data is not None: + archive.writestr("data", serialized_data) + if pytorch_variables is not None: + with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file: + th.save(pytorch_variables, pytorch_variables_file) + if params is not None: + for file_name, dict_ in params.items(): + with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file: + th.save(dict_, param_file) + # Save metadata: library version when file was saved + archive.writestr("_stable_baselines3_version", sb3.__version__) + # Save system info about the current python env + # archive.writestr("system_info.txt", get_system_info(print_info=False)[1]) + + +def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None: + """ + Save an object to path creating the necessary folders along the way. + If the path exists and is a directory, it will raise a warning and rename the path. + If a suffix is provided in the path, it will use that suffix, otherwise, it will use '.pkl'. + + :param path: the path to open. + if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the + path actually exists. If path is a io.BufferedIOBase the path exists. + :param obj: The object to save. + :param verbose: Verbosity level, 0 means only warnings, 2 means debug information. + """ + with open_path(path, "w", verbose=verbose, suffix="pkl") as file_handler: + # Use protocol>=4 to support saving replay buffers >= 4Gb + # See https://docs.python.org/3/library/pickle.html + pickle.dump(obj, file_handler, protocol=pickle.HIGHEST_PROTOCOL) + + +def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: int = 0) -> Any: + """ + Load an object from the path. If a suffix is provided in the path, it will use that suffix. + If the path does not exist, it will attempt to load using the .pkl suffix. + + :param path: the path to open. 
+ if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the + path actually exists. If path is a io.BufferedIOBase the path exists. + :param verbose: Verbosity level, 0 means only warnings, 2 means debug information. + """ + with open_path(path, "r", verbose=verbose, suffix="pkl") as file_handler: + return pickle.load(file_handler) + + +def load_from_zip_file( + load_path: Union[str, pathlib.Path, io.BufferedIOBase], + load_data: bool = True, + custom_objects: Optional[Dict[str, Any]] = None, + device: Union[th.device, str] = "auto", + verbose: int = 0, + print_system_info: bool = False, +) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]): + """ + Load model data from a .zip archive + + :param load_path: Where to load the model from + :param load_data: Whether we should load and return data + (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights) + :param custom_objects: Dictionary of objects to replace + upon loading. If a variable is present in this dictionary as a + key, it will not be deserialized and the corresponding item + will be used instead. Similar to custom_objects in + ``keras.models.load_model``. Useful when you have an object in + file that can not be deserialized. + :param device: Device on which the code should run. + :param verbose: Verbosity level, 0 means only warnings, 2 means debug information. + :param print_system_info: Whether to print or not the system info + about the saved model. + :return: Class parameters, model state_dicts (aka "params", dict of state_dict) + and dict of pytorch variables + """ + load_path = open_path(load_path, "r", verbose=verbose, suffix="zip") + + # set device to cpu if cuda is not available + device = get_device(device=device) + + # Open the zip archive and load data + try: + with zipfile.ZipFile(load_path) as archive: + namelist = archive.namelist() + # If data or parameters is not in the + # zip archive, assume they were stored + # as None (_save_to_file_zip allows this). + data = None + pytorch_variables = None + params = {} + + # Debug system info first + if print_system_info: + if "system_info.txt" in namelist: + print("== SAVED MODEL SYSTEM INFO ==") + print(archive.read("system_info.txt").decode()) + else: + warnings.warn( + "The model was saved with SB3 <= 1.2.0 and thus cannot print system information.", + UserWarning, + ) + + if "data" in namelist and load_data: + # Load class parameters that are stored + # with either JSON or pickle (not PyTorch variables). + json_data = archive.read("data").decode() + data = json_to_data(json_data, custom_objects=custom_objects) + + # Check for all .pth files and load them using th.load. + # "pytorch_variables.pth" stores PyTorch variables, and any other .pth + # files store state_dicts of variables with custom names (e.g. policy, policy.optimizer) + pth_files = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth"] + for file_path in pth_files: + with archive.open(file_path, mode="r") as param_file: + # File has to be seekable, but param_file is not, so load in BytesIO first + # fixed in python >= 3.7 + file_content = io.BytesIO() + file_content.write(param_file.read()) + # go to start of file + file_content.seek(0) + # Load the parameters with the right ``map_location``. 
+ # Remove ".pth" ending with splitext + th_object = th.load(file_content, map_location=device) + # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138 + if file_path == "pytorch_variables.pth" or file_path == "tensors.pth": + # PyTorch variables (not state_dicts) + pytorch_variables = th_object + else: + # State dicts. Store into params dictionary + # with same name as in .zip file (without .pth) + params[os.path.splitext(file_path)[0]] = th_object + except zipfile.BadZipFile: + # load_path wasn't a zip file + raise ValueError(f"Error: the file {load_path} wasn't a zip-file") + return data, params, pytorch_variables diff --git a/dexart-release/stable_baselines3/common/torch_layers.py b/dexart-release/stable_baselines3/common/torch_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..c558ef63e3def98cdadd1938f0a93a857be96e44 --- /dev/null +++ b/dexart-release/stable_baselines3/common/torch_layers.py @@ -0,0 +1,401 @@ +from itertools import zip_longest +from typing import Dict, List, Tuple, Type, Union, Optional + +import gym +import torch +import torch as th +from torch import nn + +from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space +from stable_baselines3.common.type_aliases import TensorDict +from stable_baselines3.common.utils import get_device + + +class BaseFeaturesExtractor(nn.Module): + """ + Base class that represents a features extractor. + + :param observation_space: + :param features_dim: Number of features extracted. + """ + + def __init__(self, observation_space: gym.Space, features_dim: int = 0): + super().__init__() + assert features_dim > 0 + self._observation_space = observation_space + self._features_dim = features_dim + + @property + def features_dim(self) -> int: + return self._features_dim + + def forward(self, observations: th.Tensor) -> th.Tensor: + raise NotImplementedError() + + +class FlattenExtractor(BaseFeaturesExtractor): + """ + Feature extract that flatten the input. + Used as a placeholder when feature extraction is not needed. + + :param observation_space: + """ + + def __init__(self, observation_space: gym.Space): + super().__init__(observation_space, get_flattened_obs_dim(observation_space)) + self.flatten = nn.Flatten() + + def forward(self, observations: th.Tensor) -> th.Tensor: + return self.flatten(observations) + + +class NatureCNN(BaseFeaturesExtractor): + """ + CNN from DQN nature paper: + Mnih, Volodymyr, et al. + "Human-level control through deep reinforcement learning." + Nature 518.7540 (2015): 529-533. + + :param observation_space: + :param features_dim: Number of features extracted. + This corresponds to the number of unit for the last layer. 
+ """ + + def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512): + super().__init__(observation_space, features_dim) + # We assume CxHxW images (channels first) + # Re-ordering will be done by pre-preprocessing or wrapper + assert is_image_space(observation_space, check_channels=False), ( + "You should use NatureCNN " + f"only with images not with {observation_space}\n" + "(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n" + "If you are using a custom environment,\n" + "please check it using our env checker:\n" + "https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html" + ) + n_input_channels = observation_space.shape[0] + self.cnn = nn.Sequential( + nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), + nn.ReLU(), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0), + nn.ReLU(), + nn.Flatten(), + ) + + # Compute shape by doing one forward pass + with th.no_grad(): + n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1] + + self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU()) + + def forward(self, observations: th.Tensor) -> th.Tensor: + return self.linear(self.cnn(observations)) + + +def create_mlp( + input_dim: int, + output_dim: int, + net_arch: List[int], + activation_fn: Type[nn.Module] = nn.ReLU, + squash_output: bool = False, +) -> List[nn.Module]: + """ + Create a multi layer perceptron (MLP), which is + a collection of fully-connected layers each followed by an activation function. + + :param input_dim: Dimension of the input vector + :param output_dim: + :param net_arch: Architecture of the neural net + It represents the number of units per layer. + The length of this list is the number of layers. + :param activation_fn: The activation function + to use after each layer. + :param squash_output: Whether to squash the output using a Tanh + activation function + :return: + """ + + if len(net_arch) > 0: + modules = [nn.Linear(input_dim, net_arch[0]), activation_fn()] + else: + modules = [] + + for idx in range(len(net_arch) - 1): + modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1])) + modules.append(activation_fn()) + + if output_dim > 0: + last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim + modules.append(nn.Linear(last_layer_dim, output_dim)) + if squash_output: + modules.append(nn.Tanh()) + return modules + + +class MlpExtractor(nn.Module): + """ + Constructs an MLP that receives the output from a previous feature extractor (i.e. a CNN) or directly + the observations (if no feature extractor is applied) as an input and outputs a latent representation + for the policy and a value network. + The ``net_arch`` parameter allows to specify the amount and size of the hidden layers and how many + of them are shared between the policy network and the value network. It is assumed to be a list with the following + structure: + + 1. An arbitrary length (zero allowed) number of integers each specifying the number of units in a shared layer. + If the number of ints is zero, there will be no shared layers. + 2. An optional dict, to specify the following non-shared layers for the value network and the policy network. + It is formatted like ``dict(vf=[], pi=[])``. + If it is missing any of the keys (pi or vf), no non-shared layers (empty list) is assumed. 
+ + For example to construct a network with one shared layer of size 55 followed by two non-shared layers for the value + network of size 255 and a single non-shared layer of size 128 for the policy network, the following layers_spec + would be used: ``[55, dict(vf=[255, 255], pi=[128])]``. A simple shared network topology with two layers of size 128 + would be specified as [128, 128]. + + Adapted from Stable Baselines. + + :param feature_dim: Dimension of the feature vector (can be the output of a CNN) + :param net_arch: The specification of the policy and value networks. + See above for details on its formatting. + :param activation_fn: The activation function to use for the networks. + :param device: + """ + + def __init__( + self, + feature_dim: int, + net_arch: List[Union[int, Dict[str, List[int]]]], + activation_fn: Type[nn.Module], + device: Union[th.device, str] = "auto", + ): + super().__init__() + device = get_device(device) + shared_net, policy_net, value_net = [], [], [] + policy_only_layers = [] # Layer sizes of the network that only belongs to the policy network + value_only_layers = [] # Layer sizes of the network that only belongs to the value network + last_layer_dim_shared = feature_dim + + # Iterate through the shared layers and build the shared parts of the network + for layer in net_arch: + if isinstance(layer, int): # Check that this is a shared layer + # TODO: give layer a meaningful name + shared_net.append(nn.Linear(last_layer_dim_shared, layer)) # add linear of size layer + shared_net.append(activation_fn()) + last_layer_dim_shared = layer + else: + assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts" + if "pi" in layer: + assert isinstance(layer["pi"], list), "Error: net_arch[-1]['pi'] must contain a list of integers." + policy_only_layers = layer["pi"] + + if "vf" in layer: + assert isinstance(layer["vf"], list), "Error: net_arch[-1]['vf'] must contain a list of integers." + value_only_layers = layer["vf"] + break # From here on the network splits up in policy and value network + + last_layer_dim_pi = last_layer_dim_shared + last_layer_dim_vf = last_layer_dim_shared + + # Build the non-shared part of the network + for pi_layer_size, vf_layer_size in zip_longest(policy_only_layers, value_only_layers): + if pi_layer_size is not None: + assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers." + policy_net.append(nn.Linear(last_layer_dim_pi, pi_layer_size)) + policy_net.append(activation_fn()) + last_layer_dim_pi = pi_layer_size + + if vf_layer_size is not None: + assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers." + value_net.append(nn.Linear(last_layer_dim_vf, vf_layer_size)) + value_net.append(activation_fn()) + last_layer_dim_vf = vf_layer_size + + # Save dim, used to create the distributions + self.latent_dim_pi = last_layer_dim_pi + self.latent_dim_vf = last_layer_dim_vf + + # Create networks + # If the list of layers is empty, the network will just act as an Identity module + self.shared_net = nn.Sequential(*shared_net).to(device) + self.policy_net = nn.Sequential(*policy_net).to(device) + self.value_net = nn.Sequential(*value_net).to(device) + + def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + """ + :return: latent_policy, latent_value of the specified network. 
+ If all layers are shared, then ``latent_policy == latent_value`` + """ + shared_latent = self.shared_net(features) + return self.policy_net(shared_latent), self.value_net(shared_latent) + + def forward_actor(self, features: th.Tensor) -> th.Tensor: + return self.policy_net(self.shared_net(features)) + + def forward_critic(self, features: th.Tensor) -> th.Tensor: + return self.value_net(self.shared_net(features)) + + +class CombinedExtractor(BaseFeaturesExtractor): + """ + Combined feature extractor for Dict observation spaces. + Builds a feature extractor for each key of the space. Input from each space + is fed through a separate submodule (CNN or MLP, depending on input shape), + the output features are concatenated and fed through additional MLP network ("combined"). + + :param observation_space: + :param cnn_output_dim: Number of features to output from each CNN submodule(s). Defaults to + 256 to avoid exploding network sizes. + """ + + def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256): + # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty! + super().__init__(observation_space, features_dim=1) + + extractors = {} + + total_concat_size = 0 + for key, subspace in observation_space.spaces.items(): + if is_image_space(subspace): + extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim) + total_concat_size += cnn_output_dim + else: + # The observation key is a vector, flatten it if needed + extractors[key] = nn.Flatten() + total_concat_size += get_flattened_obs_dim(subspace) + + self.extractors = nn.ModuleDict(extractors) + + # Update the features dim manually + self._features_dim = total_concat_size + + def forward(self, observations: TensorDict) -> th.Tensor: + encoded_tensor_list = [] + + for key, extractor in self.extractors.items(): + encoded_tensor_list.append(extractor(observations[key])) + return th.cat(encoded_tensor_list, dim=1) + + +def get_actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> Tuple[List[int], List[int]]: + """ + Get the actor and critic network architectures for off-policy actor-critic algorithms (SAC, TD3, DDPG). + + The ``net_arch`` parameter allows to specify the amount and size of the hidden layers, + which can be different for the actor and the critic. + It is assumed to be a list of ints or a dict. + + 1. If it is a list, actor and critic networks will have the same architecture. + The architecture is represented by a list of integers (of arbitrary length (zero allowed)) + each specifying the number of units per layer. + If the number of ints is zero, the network will be linear. + 2. If it is a dict, it should have the following structure: + ``dict(qf=[], pi=[])``. + where the network architecture is a list as described in 1. + + For example, to have actor and critic that share the same network architecture, + you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256 units each). + + If you want a different architecture for the actor and the critic, + then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``. + + .. note:: + Compared to their on-policy counterparts, no shared layers (other than the features extractor) + between the actor and the critic are allowed (to prevent issues with target networks). + + :param net_arch: The specification of the actor and critic networks. + See above for details on its formatting. 
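+
+        An illustrative sketch of both accepted forms::
+
+            get_actor_critic_arch([256, 256])                        # -> ([256, 256], [256, 256])
+            get_actor_critic_arch(dict(pi=[64, 64], qf=[400, 300]))  # -> ([64, 64], [400, 300])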
+    :return: The network architectures for the actor and the critic
+    """
+    if isinstance(net_arch, list):
+        actor_arch, critic_arch = net_arch, net_arch
+    else:
+        assert isinstance(net_arch, dict), "Error: the net_arch can only be a list of ints or a dict"
+        assert "pi" in net_arch, "Error: no key 'pi' was provided in net_arch for the actor network"
+        assert "qf" in net_arch, "Error: no key 'qf' was provided in net_arch for the critic network"
+        actor_arch, critic_arch = net_arch["pi"], net_arch["qf"]
+    return actor_arch, critic_arch
+
+
+class PointNetImaginationExtractorGP(BaseFeaturesExtractor):
+    def __init__(self, observation_space: gym.spaces.Dict, pc_key: str, feat_key: Optional[str] = None,
+                 out_channel=256, extractor_name="smallpn",
+                 gt_key: Optional[str] = None, imagination_keys=("imagination_robot",), state_key="state",
+                 state_mlp_size=(64, 64), state_mlp_activation_fn=nn.ReLU, **kwargs):
+        self.imagination_key = imagination_keys
+        # Init state representation
+        self.use_state = state_key is not None
+        self.state_key = state_key
+
+        print(f"extractor use state = {self.use_state}")
+        if self.use_state:
+            if state_key not in observation_space.spaces.keys():
+                raise RuntimeError(f"State key {state_key} not in observation space: {observation_space}")
+            self.state_space = observation_space[self.state_key]
+        if feat_key is not None:
+            if feat_key not in list(observation_space.keys()):
+                raise RuntimeError(f"Feature key {feat_key} not in observation space.")
+        if pc_key not in list(observation_space.keys()):
+            raise RuntimeError(f"Point cloud key {pc_key} not in observation space.")
+
+        super().__init__(observation_space, out_channel)
+        # Point cloud input should have size (n, 3), spec size (n, 3), feat size (n, m)
+        self.pc_key = pc_key
+        self.has_feat = feat_key is not None
+        self.feat_key = feat_key
+        self.gt_key = gt_key
+
+        if extractor_name == "smallpn":
+            from stable_baselines3.networks.pretrain_nets import PointNet
+            self.extractor = PointNet()
+        elif extractor_name == "mediumpn":
+            from stable_baselines3.networks.pretrain_nets import PointNetMedium
+            self.extractor = PointNetMedium()
+        elif extractor_name == "largepn":
+            from stable_baselines3.networks.pretrain_nets import PointNetLarge
+            self.extractor = PointNetLarge()
+        else:
+            raise NotImplementedError(f"Extractor {extractor_name} not implemented. Available:\
+smallpn, mediumpn, largepn")
+
+        # self.n_input_channels = n_input_channels
+        self.n_output_channels = out_channel
+        assert self.n_output_channels == 256
+
+        if self.use_state:
+            self.state_dim = self.state_space.shape[0]
+            if len(state_mlp_size) == 0:
+                raise RuntimeError("State mlp size is empty")
+            elif len(state_mlp_size) == 1:
+                net_arch = []
+            else:
+                net_arch = state_mlp_size[:-1]
+            output_dim = state_mlp_size[-1]
+
+            self.n_output_channels = out_channel + output_dim
+            self._features_dim = self.n_output_channels
+            self.state_mlp = nn.Sequential(*create_mlp(self.state_dim, output_dim, net_arch, state_mlp_activation_fn))
+
+    def forward(self, observations: TensorDict) -> th.Tensor:
+        # get raw point cloud segmentation mask
+        points = observations[self.pc_key][..., :3]  # B * N * 3
+
+        b, _, _ = points.shape
+        if len(self.imagination_key) > 0:
+            for key in self.imagination_key:
+                obs = observations[key]
+                if len(obs.shape) == 2:
+                    obs = obs.unsqueeze(0)
+                img_points = obs[:, :, :3]
+                points = torch.concat([points, img_points], dim=1)
+
+        # points = torch.transpose(points, 1, 2)  # B * 3 * N
+        # points: B * 3 * (N + sum(Ni))
+        pn_feat = self.extractor(points)  # B * 256
+        if self.use_state:
+            state_feat = self.state_mlp(observations[self.state_key])
+            return torch.cat([pn_feat, state_feat], dim=-1)
+        else:
+            return pn_feat
+
diff --git a/dexart-release/stable_baselines3/common/type_aliases.py b/dexart-release/stable_baselines3/common/type_aliases.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f110d6d24f56845b4dc92a420a33bc54ccb7e3a
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/type_aliases.py
@@ -0,0 +1,82 @@
+"""Common aliases for type hints"""
+
+from enum import Enum
+from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union
+
+import gym
+import numpy as np
+import torch as th
+
+from stable_baselines3.common import callbacks, vec_env
+
+GymEnv = Union[gym.Env, vec_env.VecEnv]
+GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
+GymStepReturn = Tuple[GymObs, float, bool, Dict]
+TensorDict = Dict[Union[str, int], th.Tensor]
+OptimizerStateDict = Dict[str, Any]
+MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback]
+
+# A schedule takes the remaining progress as input
+# and outputs a scalar (e.g. learning rate, clip range, ...)
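+# e.g. a linear learning-rate schedule can be written as
+# ``lambda progress_remaining: progress_remaining * 3e-4`` (the value is illustrative)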
+Schedule = Callable[[float], float] + + +class RolloutBufferSamples(NamedTuple): + observations: th.Tensor + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + + +class DictRolloutBufferSamples(RolloutBufferSamples): + observations: TensorDict + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + + +class DictSSLRolloutBufferSamples(NamedTuple): + observations: TensorDict + next_observations: List[TensorDict] + actions: th.Tensor + next_actions: List[th.Tensor] + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + + +class ReplayBufferSamples(NamedTuple): + observations: th.Tensor + actions: th.Tensor + next_observations: th.Tensor + dones: th.Tensor + rewards: th.Tensor + + +class DictReplayBufferSamples(ReplayBufferSamples): + observations: TensorDict + actions: th.Tensor + next_observations: th.Tensor + dones: th.Tensor + rewards: th.Tensor + + +class RolloutReturn(NamedTuple): + episode_timesteps: int + n_episodes: int + continue_training: bool + + +class TrainFrequencyUnit(Enum): + STEP = "step" + EPISODE = "episode" + + +class TrainFreq(NamedTuple): + frequency: int + unit: TrainFrequencyUnit # either "step" or "episode" diff --git a/dexart-release/stable_baselines3/common/utils.py b/dexart-release/stable_baselines3/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e73ffe4efdfa88868d0b67c6c65a9f0553fcb50b --- /dev/null +++ b/dexart-release/stable_baselines3/common/utils.py @@ -0,0 +1,508 @@ +import glob +import os +import platform +import random +from collections import deque +from itertools import zip_longest +from typing import Dict, Iterable, Optional, Tuple, Union + +import gym +import numpy as np +import torch as th + +import stable_baselines3 as sb3 + +# Check if tensorboard is available for pytorch +try: + from torch.utils.tensorboard import SummaryWriter +except ImportError: + SummaryWriter = None + +from stable_baselines3.common.logger import Logger, configure +from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit + + +def set_random_seed(seed: int, using_cuda: bool = False) -> None: + """ + Seed the different random generators. + + :param seed: + :param using_cuda: + """ + # Seed python RNG + random.seed(seed) + # Seed numpy RNG + np.random.seed(seed) + # seed the RNG for all devices (both CPU and CUDA) + th.manual_seed(seed) + + if using_cuda: + # Deterministic operations for CuDNN, it may impact performances + th.backends.cudnn.deterministic = True + th.backends.cudnn.benchmark = False + + +# From stable baselines +def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: + """ + Computes fraction of variance that ypred explains about y. + Returns 1 - Var[y-ypred] / Var[y] + + interpretation: + ev=0 => might as well have predicted zero + ev=1 => perfect prediction + ev<0 => worse than just predicting zero + + :param y_pred: the prediction + :param y_true: the expected value + :return: explained variance of ypred and y + """ + assert y_true.ndim == 1 and y_pred.ndim == 1 + var_y = np.var(y_true) + return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + +def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None: + """ + Update the learning rate for a given optimizer. + Useful when doing linear schedule. 
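+
+    A hedged sketch of typical use inside a training loop (``model`` and
+    ``lr_schedule`` are illustrative names)::
+
+        update_learning_rate(model.policy.optimizer, lr_schedule(progress_remaining))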
+ + :param optimizer: + :param learning_rate: + """ + for param_group in optimizer.param_groups: + param_group["lr"] = learning_rate + + +def get_schedule_fn(value_schedule: Union[Schedule, float, int]) -> Schedule: + """ + Transform (if needed) learning rate and clip range (for PPO) + to callable. + + :param value_schedule: + :return: + """ + # If the passed schedule is a float + # create a constant function + if isinstance(value_schedule, (float, int)): + # Cast to float to avoid errors + value_schedule = constant_fn(float(value_schedule)) + else: + assert callable(value_schedule) + return value_schedule + + +def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule: + """ + Create a function that interpolates linearly between start and end + between ``progress_remaining`` = 1 and ``progress_remaining`` = ``end_fraction``. + This is used in DQN for linearly annealing the exploration fraction + (epsilon for the epsilon-greedy strategy). + + :params start: value to start with if ``progress_remaining`` = 1 + :params end: value to end with if ``progress_remaining`` = 0 + :params end_fraction: fraction of ``progress_remaining`` + where end is reached e.g 0.1 then end is reached after 10% + of the complete training process. + :return: + """ + + def func(progress_remaining: float) -> float: + if (1 - progress_remaining) > end_fraction: + return end + else: + return start + (1 - progress_remaining) * (end - start) / end_fraction + + return func + + +def constant_fn(val: float) -> Schedule: + """ + Create a function that returns a constant + It is useful for learning rate schedule (to avoid code duplication) + + :param val: + :return: + """ + + def func(_): + return val + + return func + + +def get_device(device: Union[th.device, str] = "auto") -> th.device: + """ + Retrieve PyTorch device. + It checks that the requested device is available first. + For now, it supports only cpu and cuda. + By default, it tries to use the gpu. + + :param device: One for 'auto', 'cuda', 'cpu' + :return: + """ + # Cuda by default + if device == "auto": + device = "cuda" + # Force conversion to th.device + device = th.device(device) + + # Cuda not available + if device.type == th.device("cuda").type and not th.cuda.is_available(): + return th.device("cpu") + + return device + + +def get_latest_run_id(log_path: str = "", log_name: str = "") -> int: + """ + Returns the latest run number for the given log name and log path, + by finding the greatest number in the directories. + + :param log_path: Path to the log folder containing several runs. + :param log_name: Name of the experiment. Each run is stored + in a folder named ``log_name_1``, ``log_name_2``, ... + :return: latest run number + """ + max_run_id = 0 + for path in glob.glob(os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")): + file_name = path.split(os.sep)[-1] + ext = file_name.split("_")[-1] + if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id: + max_run_id = int(ext) + return max_run_id + + +def configure_logger( + verbose: int = 0, + tensorboard_log: Optional[str] = None, + tb_log_name: str = "", + reset_num_timesteps: bool = True, +) -> Logger: + """ + Configure the logger's outputs. + + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param tb_log_name: tensorboard log + :param reset_num_timesteps: Whether the ``num_timesteps`` attribute is reset or not. 
+ It allows to continue a previous learning curve (``reset_num_timesteps=False``) + or start from t=0 (``reset_num_timesteps=True``, the default). + :return: The logger object + """ + save_path, format_strings = None, ["stdout"] + + if tensorboard_log is not None and SummaryWriter is None: + raise ImportError("Trying to log data to tensorboard but tensorboard is not installed.") + + if tensorboard_log is not None and SummaryWriter is not None: + latest_run_id = get_latest_run_id(tensorboard_log, tb_log_name) + if not reset_num_timesteps: + # Continue training in the same directory + latest_run_id -= 1 + save_path = os.path.join(tensorboard_log, f"{tb_log_name}_{latest_run_id + 1}") + if verbose >= 1: + format_strings = ["stdout", "tensorboard"] + else: + format_strings = ["tensorboard"] + elif verbose == 0: + format_strings = [""] + return configure(save_path, format_strings=format_strings) + + +def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, action_space: gym.spaces.Space) -> None: + """ + Checks that the environment has same spaces as provided ones. Used by BaseAlgorithm to check if + spaces match after loading the model with given env. + Checked parameters: + - observation_space + - action_space + + :param env: Environment to check for valid spaces + :param observation_space: Observation space to check against + :param action_space: Action space to check against + """ + if observation_space != env.observation_space: + raise ValueError(f"Observation spaces do not match: {observation_space} != {env.observation_space}") + if action_space != env.action_space: + raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}") + + +def is_vectorized_box_observation(observation: np.ndarray, observation_space: gym.spaces.Box) -> bool: + """ + For box observation type, detects and validates the shape, + then returns whether or not the observation is vectorized. + + :param observation: the input observation to validate + :param observation_space: the observation space + :return: whether the given observation is vectorized or not + """ + if observation.shape == observation_space.shape: + return False + elif observation.shape[1:] == observation_space.shape: + return True + else: + raise ValueError( + f"Error: Unexpected observation shape {observation.shape} for " + + f"Box environment, please use {observation_space.shape} " + + "or (n_env, {}) for the observation shape.".format(", ".join(map(str, observation_space.shape))) + ) + + +def is_vectorized_discrete_observation(observation: Union[int, np.ndarray], observation_space: gym.spaces.Discrete) -> bool: + """ + For discrete observation type, detects and validates the shape, + then returns whether or not the observation is vectorized. + + :param observation: the input observation to validate + :param observation_space: the observation space + :return: whether the given observation is vectorized or not + """ + if isinstance(observation, int) or observation.shape == (): # A numpy array of a number, has shape empty tuple '()' + return False + elif len(observation.shape) == 1: + return True + else: + raise ValueError( + f"Error: Unexpected observation shape {observation.shape} for " + + "Discrete environment, please use () or (n_env,) for the observation shape." 
+ ) + + +def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation_space: gym.spaces.MultiDiscrete) -> bool: + """ + For multidiscrete observation type, detects and validates the shape, + then returns whether or not the observation is vectorized. + + :param observation: the input observation to validate + :param observation_space: the observation space + :return: whether the given observation is vectorized or not + """ + if observation.shape == (len(observation_space.nvec),): + return False + elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec): + return True + else: + raise ValueError( + f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete " + + f"environment, please use ({len(observation_space.nvec)},) or " + + f"(n_env, {len(observation_space.nvec)}) for the observation shape." + ) + + +def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: gym.spaces.MultiBinary) -> bool: + """ + For multibinary observation type, detects and validates the shape, + then returns whether or not the observation is vectorized. + + :param observation: the input observation to validate + :param observation_space: the observation space + :return: whether the given observation is vectorized or not + """ + if observation.shape == (observation_space.n,): + return False + elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n: + return True + else: + raise ValueError( + f"Error: Unexpected observation shape {observation.shape} for MultiBinary " + + f"environment, please use ({observation_space.n},) or " + + f"(n_env, {observation_space.n}) for the observation shape." + ) + + +def is_vectorized_dict_observation(observation: np.ndarray, observation_space: gym.spaces.Dict) -> bool: + """ + For dict observation type, detects and validates the shape, + then returns whether or not the observation is vectorized. + + :param observation: the input observation to validate + :param observation_space: the observation space + :return: whether the given observation is vectorized or not + """ + # We first assume that all observations are not vectorized + all_non_vectorized = True + for key, subspace in observation_space.spaces.items(): + # This fails when the observation is not vectorized + # or when it has the wrong shape + if observation[key].shape != subspace.shape: + all_non_vectorized = False + break + + if all_non_vectorized: + return False + + all_vectorized = True + # Now we check that all observation are vectorized and have the correct shape + for key, subspace in observation_space.spaces.items(): + if observation[key].shape[1:] != subspace.shape: + all_vectorized = False + break + + if all_vectorized: + return True + else: + # Retrieve error message + error_msg = "" + try: + is_vectorized_observation(observation[key], observation_space.spaces[key]) + except ValueError as e: + error_msg = f"{e}" + raise ValueError( + f"There seems to be a mix of vectorized and non-vectorized observations. " + f"Unexpected observation shape {observation[key].shape} for key {key} " + f"of type {observation_space.spaces[key]}. {error_msg}" + ) + + +def is_vectorized_observation(observation: Union[int, np.ndarray], observation_space: gym.spaces.Space) -> bool: + """ + For every observation type, detects and validates the shape, + then returns whether or not the observation is vectorized. 
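+
+    An illustrative sketch (shapes are hypothetical)::
+
+        space = gym.spaces.Box(-1, 1, shape=(3,))
+        is_vectorized_observation(np.zeros(3), space)       # False (single observation)
+        is_vectorized_observation(np.zeros((8, 3)), space)  # True (batch of 8)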
+ + :param observation: the input observation to validate + :param observation_space: the observation space + :return: whether the given observation is vectorized or not + """ + + is_vec_obs_func_dict = { + gym.spaces.Box: is_vectorized_box_observation, + gym.spaces.Discrete: is_vectorized_discrete_observation, + gym.spaces.MultiDiscrete: is_vectorized_multidiscrete_observation, + gym.spaces.MultiBinary: is_vectorized_multibinary_observation, + gym.spaces.Dict: is_vectorized_dict_observation, + } + + for space_type, is_vec_obs_func in is_vec_obs_func_dict.items(): + if isinstance(observation_space, space_type): + return is_vec_obs_func(observation, observation_space) + else: + # for-else happens if no break is called + raise ValueError(f"Error: Cannot determine if the observation is vectorized with the space type {observation_space}.") + + +def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray: + """ + Compute the mean of an array if there is at least one element. + For empty array, return NaN. It is used for logging only. + + :param arr: + :return: + """ + return np.nan if len(arr) == 0 else np.mean(arr) + + +def zip_strict(*iterables: Iterable) -> Iterable: + r""" + ``zip()`` function but enforces that iterables are of equal length. + Raises ``ValueError`` if iterables not of equal length. + Code inspired by Stackoverflow answer for question #32954486. + + :param \*iterables: iterables to ``zip()`` + """ + # As in Stackoverflow #32954486, use + # new object for "empty" in case we have + # Nones in iterable. + sentinel = object() + for combo in zip_longest(*iterables, fillvalue=sentinel): + if sentinel in combo: + raise ValueError("Iterables have different lengths") + yield combo + + +def polyak_update( + params: Iterable[th.nn.Parameter], + target_params: Iterable[th.nn.Parameter], + tau: float, +) -> None: + """ + Perform a Polyak average update on ``target_params`` using ``params``: + target parameters are slowly updated towards the src parameters. + ``tau``, the soft update coefficient controls the interpolation: + ``tau=1`` corresponds to copying the parameters to the target ones whereas nothing happens when ``tau=0``. + The Polyak update is done in place, with ``no_grad``, and therefore does not create intermediate tensors, + or a computation graph, reducing memory cost and improving performance. We scale the target params + by ``1-tau`` (in-place), add the new weights, scaled by ``tau`` and store the result of the sum in the target + params (in place). + See https://github.com/DLR-RM/stable-baselines3/issues/93 + + :param params: parameters to use to update the target params + :param target_params: parameters to update + :param tau: the soft update coefficient ("Polyak update", between 0 and 1) + """ + with th.no_grad(): + # zip does not raise an exception if length of parameters does not match. + for param, target_param in zip_strict(params, target_params): + target_param.data.mul_(1 - tau) + th.add(target_param.data, param.data, alpha=tau, out=target_param.data) + + +def obs_as_tensor( + obs: Union[np.ndarray, Dict[Union[str, int], np.ndarray]], device: th.device +) -> Union[th.Tensor, TensorDict]: + """ + Moves the observation to the given device. + + :param obs: + :param device: PyTorch device + :return: PyTorch tensor of the observation on a desired device. 
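+
+    A small sketch (inputs illustrative)::
+
+        obs_as_tensor(np.zeros((4, 3)), th.device("cpu"))           # -> th.Tensor
+        obs_as_tensor({"pos": np.zeros((4, 3))}, th.device("cpu"))  # -> dict of th.Tensor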
+ """ + if isinstance(obs, np.ndarray): + return th.as_tensor(obs).to(device) + elif isinstance(obs, dict): + return {key: th.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} + else: + raise Exception(f"Unrecognized type of observation {type(obs)}") + + +def should_collect_more_steps( + train_freq: TrainFreq, + num_collected_steps: int, + num_collected_episodes: int, +) -> bool: + """ + Helper used in ``collect_rollouts()`` of off-policy algorithms + to determine the termination condition. + + :param train_freq: How much experience should be collected before updating the policy. + :param num_collected_steps: The number of already collected steps. + :param num_collected_episodes: The number of already collected episodes. + :return: Whether to continue or not collecting experience + by doing rollouts of the current policy. + """ + if train_freq.unit == TrainFrequencyUnit.STEP: + return num_collected_steps < train_freq.frequency + + elif train_freq.unit == TrainFrequencyUnit.EPISODE: + return num_collected_episodes < train_freq.frequency + + else: + raise ValueError( + "The unit of the `train_freq` must be either TrainFrequencyUnit.STEP " + f"or TrainFrequencyUnit.EPISODE not '{train_freq.unit}'!" + ) + + +def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: + """ + Retrieve system and python env info for the current system. + + :param print_info: Whether to print or not those infos + :return: Dictionary summing up the version for each relevant package + and a formatted string. + """ + env_info = { + "OS": f"{platform.platform()} {platform.version()}", + "Python": platform.python_version(), + "Stable-Baselines3": sb3.__version__, + "PyTorch": th.__version__, + "GPU Enabled": str(th.cuda.is_available()), + "Numpy": np.__version__, + "Gym": gym.__version__, + } + env_info_str = "" + for key, value in env_info.items(): + env_info_str += f"{key}: {value}\n" + if print_info: + print(env_info_str) + return env_info, env_info_str diff --git a/dexart-release/stable_baselines3/common/vec_env/__init__.py b/dexart-release/stable_baselines3/common/vec_env/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3880fbd53d9f7ed4683428149de2538042385877 --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/__init__.py @@ -0,0 +1,74 @@ +# flake8: noqa F401 +import typing +from copy import deepcopy +from typing import Optional, Type, Union + +from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper +from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv +from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations +from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv +from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan +from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs +from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack +from stable_baselines3.common.vec_env.vec_monitor import VecMonitor +from stable_baselines3.common.vec_env.vec_normalize import VecNormalize +from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage +from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder + +# Avoid circular import +if typing.TYPE_CHECKING: + from stable_baselines3.common.type_aliases import GymEnv + + +def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> 
Optional[VecEnvWrapper]: + """ + Retrieve a ``VecEnvWrapper`` object by recursively searching. + + :param env: + :param vec_wrapper_class: + :return: + """ + env_tmp = env + while isinstance(env_tmp, VecEnvWrapper): + if isinstance(env_tmp, vec_wrapper_class): + return env_tmp + env_tmp = env_tmp.venv + return None + + +def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]: + """ + :param env: + :return: + """ + return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type + + +def is_vecenv_wrapped(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> bool: + """ + Check if an environment is already wrapped by a given ``VecEnvWrapper``. + + :param env: + :param vec_wrapper_class: + :return: + """ + return unwrap_vec_wrapper(env, vec_wrapper_class) is not None + + +# Define here to avoid circular import +def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None: + """ + Sync eval env and train env when using VecNormalize + + :param env: + :param eval_env: + """ + env_tmp, eval_env_tmp = env, eval_env + while isinstance(env_tmp, VecEnvWrapper): + if isinstance(env_tmp, VecNormalize): + # Only synchronize if observation normalization exists + if hasattr(env_tmp, "obs_rms"): + eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) + eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms) + env_tmp = env_tmp.venv + eval_env_tmp = eval_env_tmp.venv diff --git a/dexart-release/stable_baselines3/common/vec_env/base_vec_env.py b/dexart-release/stable_baselines3/common/vec_env/base_vec_env.py new file mode 100644 index 0000000000000000000000000000000000000000..98706050c7c48684b623de6637a95b5bde40bf51 --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/base_vec_env.py @@ -0,0 +1,374 @@ +import inspect +import warnings +from abc import ABC, abstractmethod +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union + +import cloudpickle +import gym +import numpy as np + +# Define type aliases here to avoid circular import +# Used when we want to access one or more VecEnv +VecEnvIndices = Union[None, int, Iterable[int]] +# VecEnvObs is what is returned by the reset() method +# it contains the observation for each env +VecEnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]] +# VecEnvStepReturn is what is returned by the step() method +# it contains the observation, reward, done, info for each env +VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]] + + +def tile_images(img_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover + """ + Tile N images into one big PxQ image + (P,Q) are chosen to be as close as possible, and if N + is square, then P=Q. + + :param img_nhwc: list or array of images, ndim=4 once turned into array. 
img nhwc + n = batch index, h = height, w = width, c = channel + :return: img_HWc, ndim=3 + """ + img_nhwc = np.asarray(img_nhwc) + n_images, height, width, n_channels = img_nhwc.shape + # new_height was named H before + new_height = int(np.ceil(np.sqrt(n_images))) + # new_width was named W before + new_width = int(np.ceil(float(n_images) / new_height)) + img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(n_images, new_height * new_width)]) + # img_HWhwc + out_image = img_nhwc.reshape((new_height, new_width, height, width, n_channels)) + # img_HhWwc + out_image = out_image.transpose(0, 2, 1, 3, 4) + # img_Hh_Ww_c + out_image = out_image.reshape((new_height * height, new_width * width, n_channels)) + return out_image + + +class VecEnv(ABC): + """ + An abstract asynchronous, vectorized environment. + + :param num_envs: the number of environments + :param observation_space: the observation space + :param action_space: the action space + """ + + metadata = {"render.modes": ["human", "rgb_array"]} + + def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_space: gym.spaces.Space): + self.num_envs = num_envs + self.observation_space = observation_space + self.action_space = action_space + + @abstractmethod + def reset(self) -> VecEnvObs: + """ + Reset all the environments and return an array of + observations, or a tuple of observation arrays. + + If step_async is still doing work, that work will + be cancelled and step_wait() should not be called + until step_async() is invoked again. + + :return: observation + """ + raise NotImplementedError() + + @abstractmethod + def step_async(self, actions: np.ndarray) -> None: + """ + Tell all the environments to start taking a step + with the given actions. + Call step_wait() to get the results of the step. + + You should not call this if a step_async run is + already pending. + """ + raise NotImplementedError() + + @abstractmethod + def step_wait(self) -> VecEnvStepReturn: + """ + Wait for the step taken with step_async(). + + :return: observation, reward, done, information + """ + raise NotImplementedError() + + @abstractmethod + def close(self) -> None: + """ + Clean up the environment's resources. + """ + raise NotImplementedError() + + @abstractmethod + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: + """ + Return attribute from vectorized environment. + + :param attr_name: The name of the attribute whose value to return + :param indices: Indices of envs to get attribute from + :return: List of values of 'attr_name' in all environments + """ + raise NotImplementedError() + + @abstractmethod + def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: + """ + Set attribute inside vectorized environments. + + :param attr_name: The name of attribute to assign new value + :param value: Value to assign to `attr_name` + :param indices: Indices of envs to assign value + :return: + """ + raise NotImplementedError() + + @abstractmethod + def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: + """ + Call instance methods of vectorized environments. + + :param method_name: The name of the environment method to invoke. 
+ :param indices: Indices of envs whose method to call + :param method_args: Any positional arguments to provide in the call + :param method_kwargs: Any keyword arguments to provide in the call + :return: List of items returned by the environment's method call + """ + raise NotImplementedError() + + @abstractmethod + def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: + """ + Check if environments are wrapped with a given wrapper. + + :param method_name: The name of the environment method to invoke. + :param indices: Indices of envs whose method to call + :param method_args: Any positional arguments to provide in the call + :param method_kwargs: Any keyword arguments to provide in the call + :return: True if the env is wrapped, False otherwise, for each env queried. + """ + raise NotImplementedError() + + def step(self, actions: np.ndarray) -> VecEnvStepReturn: + """ + Step the environments with the given action + + :param actions: the action + :return: observation, reward, done, information + """ + self.step_async(actions) + return self.step_wait() + + def get_images(self) -> Sequence[np.ndarray]: + """ + Return RGB images from each environment + """ + raise NotImplementedError + + def render(self, mode: str = "human") -> Optional[np.ndarray]: + """ + Gym environment rendering + + :param mode: the rendering type + """ + try: + imgs = self.get_images() + except NotImplementedError: + warnings.warn(f"Render not defined for {self}") + return + + # Create a big image by tiling images from subprocesses + bigimg = tile_images(imgs) + if mode == "human": + import cv2 # pytype:disable=import-error + + cv2.imshow("vecenv", bigimg[:, :, ::-1]) + cv2.waitKey(1) + elif mode == "rgb_array": + return bigimg + else: + raise NotImplementedError(f"Render mode {mode} is not supported by VecEnvs") + + @abstractmethod + def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + """ + Sets the random seeds for all environments, based on a given seed. + Each individual environment will still get its own seed, by incrementing the given seed. + + :param seed: The random seed. May be None for completely random seeding. + :return: Returns a list containing the seeds for each individual env. + Note that all list elements may be None, if the env does not return anything when being seeded. + """ + pass + + @property + def unwrapped(self) -> "VecEnv": + if isinstance(self, VecEnvWrapper): + return self.venv.unwrapped + else: + return self + + def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]: + """Check if an attribute reference is being hidden in a recursive call to __getattr__ + + :param name: name of attribute to check for + :param already_found: whether this attribute has already been found in a wrapper + :return: name of module whose attribute is being shadowed, if any. + """ + if hasattr(self, name) and already_found: + return f"{type(self).__module__}.{type(self).__name__}" + else: + return None + + def _get_indices(self, indices: VecEnvIndices) -> Iterable[int]: + """ + Convert a flexibly-typed reference to environment indices to an implied list of indices. + + :param indices: refers to indices of envs. + :return: the implied list of indices. 
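+
+        For example (with ``num_envs == 4``, illustrative): ``None`` -> ``range(0, 4)``,
+        ``2`` -> ``[2]``, and ``[0, 3]`` is returned unchanged.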
+ """ + if indices is None: + indices = range(self.num_envs) + elif isinstance(indices, int): + indices = [indices] + return indices + + +class VecEnvWrapper(VecEnv): + """ + Vectorized environment base class + + :param venv: the vectorized environment to wrap + :param observation_space: the observation space (can be None to load from venv) + :param action_space: the action space (can be None to load from venv) + """ + + def __init__( + self, + venv: VecEnv, + observation_space: Optional[gym.spaces.Space] = None, + action_space: Optional[gym.spaces.Space] = None, + ): + self.venv = venv + VecEnv.__init__( + self, + num_envs=venv.num_envs, + observation_space=observation_space or venv.observation_space, + action_space=action_space or venv.action_space, + ) + self.class_attributes = dict(inspect.getmembers(self.__class__)) + + def step_async(self, actions: np.ndarray) -> None: + self.venv.step_async(actions) + + @abstractmethod + def reset(self) -> VecEnvObs: + pass + + @abstractmethod + def step_wait(self) -> VecEnvStepReturn: + pass + + def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + return self.venv.seed(seed) + + def close(self) -> None: + return self.venv.close() + + def render(self, mode: str = "human") -> Optional[np.ndarray]: + return self.venv.render(mode=mode) + + def get_images(self) -> Sequence[np.ndarray]: + return self.venv.get_images() + + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: + return self.venv.get_attr(attr_name, indices) + + def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: + return self.venv.set_attr(attr_name, value, indices) + + def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: + return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs) + + def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: + return self.venv.env_is_wrapped(wrapper_class, indices=indices) + + def __getattr__(self, name: str) -> Any: + """Find attribute from wrapped venv(s) if this wrapper does not have it. + Useful for accessing attributes from venvs which are wrapped with multiple wrappers + which have unique attributes of interest. + """ + blocked_class = self.getattr_depth_check(name, already_found=False) + if blocked_class is not None: + own_class = f"{type(self).__module__}.{type(self).__name__}" + error_str = ( + f"Error: Recursive attribute lookup for {name} from {own_class} is " + f"ambiguous and hides attribute from {blocked_class}" + ) + raise AttributeError(error_str) + + return self.getattr_recursive(name) + + def _get_all_attributes(self) -> Dict[str, Any]: + """Get all (inherited) instance and class attributes + + :return: all_attributes + """ + all_attributes = self.__dict__.copy() + all_attributes.update(self.class_attributes) + return all_attributes + + def getattr_recursive(self, name: str) -> Any: + """Recursively check wrappers to find attribute. + + :param name: name of attribute to look for + :return: attribute + """ + all_attributes = self._get_all_attributes() + if name in all_attributes: # attribute is present in this wrapper + attr = getattr(self, name) + elif hasattr(self.venv, "getattr_recursive"): + # Attribute not present, child is wrapper. Call getattr_recursive rather than getattr + # to avoid a duplicate call to getattr_depth_check. 
+            attr = self.venv.getattr_recursive(name)
+        else:  # attribute not present, child is an unwrapped VecEnv
+            attr = getattr(self.venv, name)
+
+        return attr
+
+    def getattr_depth_check(self, name: str, already_found: bool) -> str:
+        """See base class.
+
+        :return: name of module whose attribute is being shadowed, if any.
+        """
+        all_attributes = self._get_all_attributes()
+        if name in all_attributes and already_found:
+            # this venv's attribute is being hidden because of a higher venv.
+            shadowed_wrapper_class = f"{type(self).__module__}.{type(self).__name__}"
+        elif name in all_attributes and not already_found:
+            # we have found the first reference to the attribute. Now check for duplicates.
+            shadowed_wrapper_class = self.venv.getattr_depth_check(name, True)
+        else:
+            # this wrapper does not have the attribute. Keep searching.
+            shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found)
+
+        return shadowed_wrapper_class
+
+
+class CloudpickleWrapper:
+    """
+    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
+
+    :param var: the variable you wish to wrap for pickling with cloudpickle
+    """
+
+    def __init__(self, var: Any):
+        self.var = var
+
+    def __getstate__(self) -> Any:
+        return cloudpickle.dumps(self.var)
+
+    def __setstate__(self, var: Any) -> None:
+        self.var = cloudpickle.loads(var)
diff --git a/dexart-release/stable_baselines3/common/vec_env/dummy_vec_env.py b/dexart-release/stable_baselines3/common/vec_env/dummy_vec_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0efc8cafcdaca9596e21d3014b3db32ac92cdd3
--- /dev/null
+++ b/dexart-release/stable_baselines3/common/vec_env/dummy_vec_env.py
@@ -0,0 +1,127 @@
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Any, Callable, List, Optional, Sequence, Type, Union
+
+import gym
+import numpy as np
+
+from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
+from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
+
+
+class DummyVecEnv(VecEnv):
+    """
+    Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
+    Python process. This is useful for computationally simple environments such as ``CartPole-v1``,
+    as the overhead of multiprocessing or multithreading outweighs the environment computation time.
+    This can also be used for RL methods that
+    require a vectorized environment, but where you want to train with a single environment.
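+
+    Example (editorial sketch, not part of the original docstring; assumes the
+    standard Gym ``CartPole-v1`` registration)::
+
+        env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
+        obs = env.reset()  # batched observation; leading axis is num_envs
+        actions = np.array([env.action_space.sample() for _ in range(2)])
+        obs, rewards, dones, infos = env.step(actions)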
+ + :param env_fns: a list of functions + that return environments to vectorize + """ + + def __init__(self, env_fns: List[Callable[[], gym.Env]]): + self.envs = [fn() for fn in env_fns] + env = self.envs[0] + VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) + obs_space = env.observation_space + self.keys, shapes, dtypes = obs_space_info(obs_space) + + self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys]) + self.buf_dones = np.zeros((self.num_envs,), dtype=bool) + self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) + self.buf_infos = [{} for _ in range(self.num_envs)] + self.actions = None + self.metadata = env.metadata + + def step_async(self, actions: np.ndarray) -> None: + self.actions = actions + + def step_wait(self) -> VecEnvStepReturn: + for env_idx in range(self.num_envs): + obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step( + self.actions[env_idx] + ) + if self.buf_dones[env_idx]: + # save final observation where user can get it, then reset + self.buf_infos[env_idx]["terminal_observation"] = obs + obs = self.envs[env_idx].reset() + self._save_obs(env_idx, obs) + return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) + + def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + if seed is None: + seed = np.random.randint(0, 2**32 - 1) + seeds = [] + for idx, env in enumerate(self.envs): + seeds.append(env.seed(seed + idx)) + return seeds + + def reset(self) -> VecEnvObs: + for env_idx in range(self.num_envs): + obs = self.envs[env_idx].reset() + self._save_obs(env_idx, obs) + return self._obs_from_buf() + + def close(self) -> None: + for env in self.envs: + env.close() + + def get_images(self) -> Sequence[np.ndarray]: + return [env.render(mode="rgb_array") for env in self.envs] + + def render(self, mode: str = "human") -> Optional[np.ndarray]: + """ + Gym environment rendering. If there are multiple environments then + they are tiled together in one image via ``BaseVecEnv.render()``. + Otherwise (if ``self.num_envs == 1``), we pass the render call directly to the + underlying environment. + + Therefore, some arguments such as ``mode`` will have values that are valid + only when ``num_envs == 1``. + + :param mode: The rendering type. 
+ """ + if self.num_envs == 1: + return self.envs[0].render(mode=mode) + else: + return super().render(mode=mode) + + def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None: + for key in self.keys: + if key is None: + self.buf_obs[key][env_idx] = obs + else: + self.buf_obs[key][env_idx] = obs[key] + + def _obs_from_buf(self) -> VecEnvObs: + return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs)) + + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: + """Return attribute from vectorized environment (see base class).""" + target_envs = self._get_target_envs(indices) + return [getattr(env_i, attr_name) for env_i in target_envs] + + def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: + """Set attribute inside vectorized environments (see base class).""" + target_envs = self._get_target_envs(indices) + for env_i in target_envs: + setattr(env_i, attr_name, value) + + def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: + """Call instance methods of vectorized environments.""" + target_envs = self._get_target_envs(indices) + return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs] + + def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: + """Check if worker environments are wrapped with a given wrapper""" + target_envs = self._get_target_envs(indices) + # Import here to avoid a circular import + from stable_baselines3.common import env_util + + return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs] + + def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]: + indices = self._get_indices(indices) + return [self.envs[i] for i in indices] diff --git a/dexart-release/stable_baselines3/common/vec_env/maniskill2_utils_common.py b/dexart-release/stable_baselines3/common/vec_env/maniskill2_utils_common.py new file mode 100644 index 0000000000000000000000000000000000000000..b6b90c0de331a975f73a6ab4cf5e5bfcbdc835b8 --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/maniskill2_utils_common.py @@ -0,0 +1,233 @@ +from collections import defaultdict, OrderedDict +from typing import Dict, Sequence + +import gym +import numpy as np +from gym import spaces + + + +# -------------------------------------------------------------------------- # +# Basic +# -------------------------------------------------------------------------- # +def merge_dicts(ds: Sequence[Dict], asarray=False): + """Merge multiple dicts with the same keys to a single one.""" + # NOTE(jigu): To be compatible with generator, we only iterate once. + ret = defaultdict(list) + for d in ds: + for k in d: + ret[k].append(d[k]) + ret = dict(ret) + # Sanity check (length) + assert len(set(len(v) for v in ret.values())) == 1, "Keys are not same." 
+ if asarray: + ret = {k: np.concatenate(v) for k, v in ret.items()} + return ret + + +# -------------------------------------------------------------------------- # +# Numpy +# -------------------------------------------------------------------------- # +def normalize_vector(x, eps=1e-6): + x = np.asarray(x) + assert x.ndim == 1, x.ndim + norm = np.linalg.norm(x) + return np.zeros_like(x) if norm < eps else (x / norm) + + +def compute_angle_between(x1, x2): + """Compute angle (radian) between two vectors.""" + x1, x2 = normalize_vector(x1), normalize_vector(x2) + dot_prod = np.clip(np.dot(x1, x2), -1, 1) + return np.arccos(dot_prod).item() + + +class np_random: + """Context manager for numpy random state""" + + def __init__(self, seed): + self.seed = seed + self.state = None + + def __enter__(self): + self.state = np.random.get_state() + np.random.seed(self.seed) + return self.state + + def __exit__(self, exc_type, exc_val, exc_tb): + np.random.set_state(self.state) + + +def random_choice(x: Sequence, rng: np.random.RandomState = np.random): + assert len(x) > 0 + if len(x) == 1: + return x[0] + else: + return x[rng.randint(len(x))] + + +def get_dtype_bounds(dtype: np.dtype): + if np.issubdtype(dtype, np.floating): + info = np.finfo(dtype) + return info.min, info.max + elif np.issubdtype(dtype, np.integer): + info = np.iinfo(dtype) + return info.min, info.max + elif np.issubdtype(dtype, np.bool_): + return 0, 1 + else: + raise TypeError(dtype) + + +# ---------------------------------------------------------------------------- # +# OpenAI gym +# ---------------------------------------------------------------------------- # +def convert_observation_to_space(observation, prefix=""): + """Convert observation to OpenAI gym observation space (recursively). 
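+    For example, a dict of float arrays becomes a nested ``spaces.Dict`` of
+    unbounded ``spaces.Box`` entries (editorial note).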
+ Modified from `gym.envs.mujoco_env` + """ + if isinstance(observation, (dict)): + space = spaces.Dict( + { + k: convert_observation_to_space(v, prefix + "/" + k) + for k, v in observation.items() + } + ) + elif isinstance(observation, np.ndarray): + shape = observation.shape + dtype = observation.dtype + low, high = get_dtype_bounds(dtype) + if np.issubdtype(dtype, np.floating): + low, high = -np.inf, np.inf + space = spaces.Box(low, high, shape=shape, dtype=dtype) + elif isinstance(observation, (float, np.float32, np.float64)): + print(f"The observation ({prefix}) is a (float) scalar") + space = spaces.Box(-np.inf, np.inf, shape=[1], dtype=np.float32) + elif isinstance(observation, (int, np.int32, np.int64)): + print(f"The observation ({prefix}) is a (integer) scalar") + space = spaces.Box(-np.inf, np.inf, shape=[1], dtype=int) + elif isinstance(observation, (bool, np.bool_)): + print(f"The observation ({prefix}) is a (bool) scalar") + space = spaces.Box(0, 1, shape=[1], dtype=np.bool_) + else: + raise NotImplementedError(type(observation), observation) + + return space + + +def normalize_action_space(action_space: spaces.Box): + assert isinstance(action_space, spaces.Box), type(action_space) + return spaces.Box(-1, 1, shape=action_space.shape, dtype=action_space.dtype) + + +def clip_and_scale_action(action, low, high): + """Clip action to [-1, 1] and scale according to a range [low, high].""" + low, high = np.asarray(low), np.asarray(high) + action = np.clip(action, -1, 1) + return 0.5 * (high + low) + 0.5 * (high - low) * action + + +def inv_clip_and_scale_action(action, low, high): + """Inverse of `clip_and_scale_action`.""" + low, high = np.asarray(low), np.asarray(high) + action = (action - 0.5 * (high + low)) / (0.5 * (high - low)) + return np.clip(action, -1.0, 1.0) + + +def inv_scale_action(action, low, high): + """Inverse of `clip_and_scale_action` without clipping.""" + low, high = np.asarray(low), np.asarray(high) + return (action - 0.5 * (high + low)) / (0.5 * (high - low)) + + +def flatten_state_dict(state_dict: dict) -> np.ndarray: + """Flatten a dictionary containing states recursively. + Args: + state_dict: a dictionary containing scalars or 1-dim vectors. + Raises: + AssertionError: If a value of @state_dict is an ndarray with ndim > 2. + Returns: + np.ndarray: flattened states. + Notes: + The input is recommended to be ordered (e.g. OrderedDict). + However, since python 3.7, dictionary order is guaranteed to be insertion order. 
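+    Example (editorial sketch):
+        ``flatten_state_dict({"a": 1.0, "b": {"c": np.zeros(2)}})`` returns
+        ``array([1., 0., 0.])``.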
+ """ + states = [] + for key, value in state_dict.items(): + if isinstance(value, dict): + state = flatten_state_dict(value) + if state.size == 0: + state = None + elif isinstance(value, (tuple, list)): + state = None if len(value) == 0 else value + elif isinstance(value, (bool, np.bool_, int, np.int32, np.int64)): + # x = np.array(1) > 0 is np.bool_ instead of ndarray + state = int(value) + elif isinstance(value, (float, np.float32, np.float64)): + state = np.float32(value) + elif isinstance(value, np.ndarray): + if value.ndim > 2: + raise AssertionError( + "The dimension of {} should not be more than 2.".format(key) + ) + state = value if value.size > 0 else None + else: + raise TypeError("Unsupported type: {}".format(type(value))) + if state is not None: + states.append(state) + if len(states) == 0: + return np.empty(0) + else: + return np.hstack(states) + + +def flatten_dict_keys(d: dict, prefix=""): + """Flatten a dict by expanding its keys recursively.""" + out = dict() + for k, v in d.items(): + if isinstance(v, dict): + out.update(flatten_dict_keys(v, prefix + k + "/")) + else: + out[prefix + k] = v + return out + + +def extract_scalars_from_info(info: dict, blacklist=()) -> Dict[str, float]: + """Recursively extract scalar metrics from info dict. + Args: + info (dict): info dict + blacklist (tuple, optional): keys to exclude. + Returns: + Dict[str, float]: scalar metrics + """ + ret = {} + for k, v in info.items(): + if k in blacklist: + continue + + # Ignore placeholder + if v is None: + continue + + # Recursively extract scalars + elif isinstance(v, dict): + ret2 = extract_scalars_from_info(v, blacklist=blacklist) + ret2 = {f"{k}.{k2}": v2 for k2, v2 in ret2.items()} + ret2 = {k2: v2 for k2, v2 in ret2.items() if k2 not in blacklist} + + # Things that are scalar-like will have an np.size of 1. 
+ # Strings also have an np.size of 1, so explicitly ban those + elif np.size(v) == 1 and not isinstance(v, str): + ret[k] = float(v) + return ret + + +def flatten_dict_space_keys(space: spaces.Dict, prefix="") -> spaces.Dict: + """Flatten a dict of spaces by expanding its keys recursively.""" + out = OrderedDict() + for k, v in space.spaces.items(): + if isinstance(v, spaces.Dict): + out.update(flatten_dict_space_keys(v, prefix + k + "/").spaces) + else: + out[prefix + k] = v + return spaces.Dict(out) \ No newline at end of file diff --git a/dexart-release/stable_baselines3/common/vec_env/maniskill2_utils_wrappers_obs.py b/dexart-release/stable_baselines3/common/vec_env/maniskill2_utils_wrappers_obs.py new file mode 100644 index 0000000000000000000000000000000000000000..c989378ac6684aa9cbfaf96e195b42fe4e8587a0 --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/maniskill2_utils_wrappers_obs.py @@ -0,0 +1,233 @@ +from collections import OrderedDict +from copy import deepcopy +from typing import Sequence + +import gym +import numpy as np +from gym import spaces + +from maniskill2_utils_common import (flatten_dict_keys, flatten_dict_space_keys, merge_dicts) + + +class RGBDObservationWrapper(gym.ObservationWrapper): + """Map raw textures (Color and Position) to rgb and depth.""" + + def __init__(self, env): + super().__init__(env) + self.observation_space = deepcopy(env.observation_space) + self.update_observation_space(self.observation_space) + + @staticmethod + def update_observation_space(space: spaces.Dict): + # Update image observation space + image_space: spaces.Dict = space.spaces["image"] + for cam_uid in image_space: + ori_cam_space = image_space[cam_uid] + new_cam_space = OrderedDict() + for key in ori_cam_space: + if key == "Color": + height, width = ori_cam_space[key].shape[:2] + new_cam_space["rgb"] = spaces.Box( + low=0, high=255, shape=(height, width, 3), dtype=np.uint8 + ) + elif key == "Position": + height, width = ori_cam_space[key].shape[:2] + new_cam_space["depth"] = spaces.Box( + low=0, high=np.inf, shape=(height, width, 1), dtype=np.float32 + ) + else: + new_cam_space[key] = ori_cam_space[key] + image_space.spaces[cam_uid] = spaces.Dict(new_cam_space) + + def observation(self, observation: dict): + image_obs = observation["image"] + for cam_uid, ori_images in image_obs.items(): + new_images = OrderedDict() + for key in ori_images: + if key == "Color": + rgb = ori_images[key][..., :3] # [H, W, 4] + rgb = np.clip(rgb * 255, 0, 255).astype(np.uint8) + new_images["rgb"] = rgb # [H, W, 4] + elif key == "Position": + depth = -ori_images[key][..., [2]] # [H, W, 1] + new_images["depth"] = depth + else: + new_images[key] = ori_images[key] + image_obs[cam_uid] = new_images + return observation + + +def merge_dict_spaces(dict_spaces: Sequence[spaces.Dict]): + reverse_spaces = merge_dicts([x.spaces for x in dict_spaces]) + for key in reverse_spaces: + low, high = [], [] + for x in reverse_spaces[key]: + assert isinstance(x, spaces.Box), type(x) + low.append(x.low) + high.append(x.high) + low = np.concatenate(low) + high = np.concatenate(high) + new_space = spaces.Box(low=low, high=high, dtype=low.dtype) + reverse_spaces[key] = new_space + return spaces.Dict(OrderedDict(reverse_spaces)) + + +class PointCloudObservationWrapper(gym.ObservationWrapper): + """Convert Position textures to world-space point cloud.""" + + def __init__(self, env): + super().__init__(env) + self.observation_space = deepcopy(env.observation_space) + 
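+        # Editorial note: update_observation_space mutates the deepcopy in
+        # place, replacing the "image" and "camera_param" entries with a single
+        # merged "pointcloud" space.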
self.update_observation_space(self.observation_space) + self._buffer = {} + + @staticmethod + def update_observation_space(space: spaces.Dict): + # Replace image observation spaces with point cloud ones + image_space: spaces.Dict = space.spaces.pop("image") + space.spaces.pop("camera_param") + pcd_space = OrderedDict() + + for cam_uid in image_space: + cam_image_space = image_space[cam_uid] + cam_pcd_space = OrderedDict() + + h, w = cam_image_space["Position"].shape[:2] + cam_pcd_space["xyzw"] = spaces.Box( + low=-np.inf, high=np.inf, shape=(h * w, 4), dtype=np.float32 + ) + + # Extra keys + if "Color" in cam_image_space.spaces: + cam_pcd_space["rgb"] = spaces.Box( + low=0, high=255, shape=(h * w, 3), dtype=np.uint8 + ) + if "Segmentation" in cam_image_space.spaces: + cam_pcd_space["Segmentation"] = spaces.Box( + low=0, high=(2 ** 32 - 1), shape=(h * w, 4), dtype=np.uint32 + ) + + pcd_space[cam_uid] = spaces.Dict(cam_pcd_space) + + pcd_space = merge_dict_spaces(pcd_space.values()) + space.spaces["pointcloud"] = pcd_space + + def observation(self, observation: dict): + image_obs = observation.pop("image") + camera_params = observation.pop("camera_param") + pointcloud_obs = OrderedDict() + + for cam_uid, images in image_obs.items(): + cam_pcd = {} + + # Each pixel is (x, y, z, z_buffer_depth) in OpenGL camera space + position = images["Position"] + # position[..., 3] = position[..., 3] < 1 + position[..., 3] = position[..., 2] < 0 + + # Convert to world space + cam2world = camera_params[cam_uid]["cam2world_gl"] + xyzw = position.reshape(-1, 4) @ cam2world.T + cam_pcd["xyzw"] = xyzw + + # Extra keys + if "Color" in images: + rgb = images["Color"][..., :3] + rgb = np.clip(rgb * 255, 0, 255).astype(np.uint8) + cam_pcd["rgb"] = rgb.reshape(-1, 3) + if "Segmentation" in images: + cam_pcd["Segmentation"] = images["Segmentation"].reshape(-1, 4) + + pointcloud_obs[cam_uid] = cam_pcd + + pointcloud_obs = merge_dicts(pointcloud_obs.values()) + for key, value in pointcloud_obs.items(): + buffer = self._buffer.get(key, None) + pointcloud_obs[key] = np.concatenate(value, out=buffer) + self._buffer[key] = pointcloud_obs[key] + + observation["pointcloud"] = pointcloud_obs + return observation + + +class RobotSegmentationObservationWrapper(gym.ObservationWrapper): + """Add a binary mask for robot links.""" + + def __init__(self, env, replace=True): + super().__init__(env) + self.observation_space = deepcopy(env.observation_space) + self.init_observation_space(self.observation_space, replace=replace) + self.replace = replace + # Cache robot link ids + self.robot_link_ids = self.env.robot_link_ids + + @staticmethod + def init_observation_space(space: spaces.Dict, replace: bool): + # Update image observation spaces + if "image" in space.spaces: + image_space = space["image"] + for cam_uid in image_space: + cam_space = image_space[cam_uid] + if "Segmentation" not in cam_space.spaces: + continue + height, width = cam_space["Segmentation"].shape[:2] + new_space = spaces.Box( + low=0, high=1, shape=(height, width, 1), dtype="bool" + ) + if replace: + cam_space.spaces.pop("Segmentation") + cam_space.spaces["robot_seg"] = new_space + + # Update pointcloud observation spaces + if "pointcloud" in space.spaces: + pcd_space = space["pointcloud"] + if "Segmentation" in pcd_space.spaces: + n = pcd_space["Segmentation"].shape[0] + new_space = spaces.Box(low=0, high=1, shape=(n, 1), dtype="bool") + if replace: + pcd_space.spaces.pop("Segmentation") + pcd_space.spaces["robot_seg"] = new_space + + def reset(self, **kwargs): + 
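+        """Reset the wrapped env and refresh the cached ``robot_link_ids``
+        (editorial docstring; the ids are re-read because they can change when
+        the scene is rebuilt on reset)."""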
observation = self.env.reset(**kwargs) + self.robot_link_ids = self.env.robot_link_ids + return self.observation(observation) + + def observation_image(self, observation: dict): + image_obs = observation["image"] + for cam_images in image_obs.values(): + if "Segmentation" not in cam_images: + continue + seg = cam_images["Segmentation"] + robot_seg = np.isin(seg[..., 1:2], self.robot_link_ids) + if self.replace: + cam_images.pop("Segmentation") + cam_images["robot_seg"] = robot_seg + return observation + + def observation_pointcloud(self, observation: dict): + pointcloud_obs = observation["pointcloud"] + if "Segmentation" not in pointcloud_obs: + return observation + seg = pointcloud_obs["Segmentation"] + robot_seg = np.isin(seg[..., 1:2], self.robot_link_ids) + if self.replace: + pointcloud_obs.pop("Segmentation") + pointcloud_obs["robot_seg"] = robot_seg + return observation + + def observation(self, observation: dict): + if "image" in observation: + observation = self.observation_image(observation) + if "pointcloud" in observation: + observation = self.observation_pointcloud(observation) + return observation + + +class FlattenObservationWrapper(gym.ObservationWrapper): + def __init__(self, env) -> None: + super().__init__(env) + self.observation_space = flatten_dict_space_keys(self.observation_space) + + def observation(self, observation): + return flatten_dict_keys(observation) diff --git a/dexart-release/stable_baselines3/common/vec_env/maniskill2_vec_env.py b/dexart-release/stable_baselines3/common/vec_env/maniskill2_vec_env.py new file mode 100644 index 0000000000000000000000000000000000000000..dee1c7165fca39dd996c68faed9d65e3b348a5aa --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/maniskill2_vec_env.py @@ -0,0 +1,607 @@ +"""ManiSkill2 vectorized environment. 
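+
+Observations are gathered from worker processes while image textures stay on the
+GPU as torch tensors (editorial summary).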
+
+See also:
+    https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/subproc_vec_env.py
+"""
+
+import multiprocessing as mp
+import os
+from collections import defaultdict
+from copy import deepcopy
+from functools import partial
+from multiprocessing.connection import Connection
+from typing import Callable, Dict, List, Optional, Sequence, Type, Union
+
+import gym
+import numpy as np
+import sapien.core as sapien
+from gym import spaces
+from gym.vector.utils.shared_memory import *
+
+try:
+    import torch
+except ImportError:
+    raise ImportError("To use ManiSkill2 VecEnv, please install PyTorch first.")
+
+# from mani_skill2 import logger
+# from mani_skill2.envs.sapien_env import BaseEnv
+from gym import Env as BaseEnv
+
+def find_available_port():
+    # https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
+    import socket
+
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        port = s.getsockname()[1]
+    server_address = f"localhost:{port}"
+    return server_address
+
+
+def _worker(
+    rank: int,
+    remote: Connection,
+    parent_remote: Connection,
+    env_fn: Callable[..., BaseEnv],
+):
+    # NOTE(jigu): Set environment variables for ManiSkill2
+    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+    os.environ["MKL_NUM_THREADS"] = "1"
+    os.environ["NUMEXPR_NUM_THREADS"] = "1"
+    os.environ["OMP_NUM_THREADS"] = "1"
+
+    parent_remote.close()
+
+    env = None
+    try:
+        env = env_fn()
+        while True:
+            cmd, data = remote.recv()
+            if cmd == "step":
+                obs, reward, done, info = env.step(data)
+                remote.send((obs, reward, done, info))
+            elif cmd == "reset":
+                obs = env.reset()
+                remote.send(obs)
+            elif cmd == "close":
+                remote.close()
+                break
+            elif cmd == "env_method":
+                method = getattr(env, data[0])
+                remote.send(method(*data[1], **data[2]))
+            elif cmd == "get_attr":
+                remote.send(getattr(env, data))
+            elif cmd == "set_attr":
+                remote.send(setattr(env, data[0], data[1]))
+            elif cmd == "handshake":
+                remote.send(None)
+            else:
+                raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
+    except KeyboardInterrupt:
+        print("Worker KeyboardInterrupt")
+    except EOFError:
+        print("Worker EOF")
+    except Exception as err:
+        print(err)
+    finally:
+        # Close the env only if it was successfully created
+        if env is not None:
+            env.close()
+
+
+class VecEnv:
+    """Vectorized environment modified from Stable Baselines3 for ManiSkill2.
+    Image observations can stay on GPU to avoid unnecessary data transfer.
+
+    Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
+    process, allowing significant speed up when the environment is computationally complex.
+
+    For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
+    number of logical cores on your CPU.
+
+    .. warning::
+
+        Only 'forkserver' and 'spawn' start methods are thread-safe,
+        which is important when TensorFlow sessions or other non thread-safe
+        libraries are used in the parent (see issue #217). However, compared to
+        'fork' they incur a small start-up cost and have restrictions on
+        global variables. With those methods, users must wrap the code in an
+        ``if __name__ == "__main__":`` block.
+        For more information, see the multiprocessing documentation.
+
+    .. warning::
+        The tensor observations are buffered. Make a copy to avoid overwriting them.
+
+    :param env_fns: Environments to run in subprocesses
+    :param start_method: method used to start the subprocesses.
+        Must be one of the methods returned by multiprocessing.get_all_start_methods().
+        Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
+    :param server_address: The network address of the SAPIEN RenderServer.
+        If "auto", the server will be created automatically at an available port.
+        Otherwise, it should be a network address, e.g. "localhost:12345".
+    :param server_kwargs: keyword arguments for sapien.RenderServer
+    """
+
+    device: torch.device
+
+    def __init__(
+        self,
+        env_fns: List[Callable[[], BaseEnv]],
+        start_method: Optional[str] = None,
+        server_address: str = "auto",
+        server_kwargs: dict = None,
+    ):
+        self.waiting = False
+        self.closed = False
+
+        if start_method is None:
+            # Fork is not a thread safe method (see issue #217)
+            # but is more user friendly (does not require to wrap the code in
+            # a `if __name__ == "__main__":`)
+            forkserver_available = "forkserver" in mp.get_all_start_methods()
+            start_method = "forkserver" if forkserver_available else "spawn"
+        ctx = mp.get_context(start_method)
+
+        # ---------------------------------------------------------------------------- #
+        # Acquire observation space to construct buffer
+        # NOTE(jigu): Use a separate process to avoid creating sapien resources in the main process
+        remote, work_remote = ctx.Pipe()
+        args = (0, work_remote, remote, env_fns[0])
+        process = ctx.Process(target=_worker, args=args, daemon=True)
+        process.start()
+        work_remote.close()
+        remote.send(("get_attr", "observation_space"))
+        self.observation_space: spaces.Dict = remote.recv()
+        remote.send(("get_attr", "action_space"))
+        self.action_space: spaces.Space = remote.recv()
+        remote.send(("close", None))
+        remote.close()
+        process.join()
+        # ---------------------------------------------------------------------------- #
+
+        n_envs = len(env_fns)
+        self.num_envs = n_envs
+
+        # Allocate numpy buffers
+        self.non_image_obs_space = deepcopy(self.observation_space)
+        self.image_obs_space = self.non_image_obs_space.spaces.pop("image")
+        self._last_obs_np = [None for _ in range(n_envs)]
+        self._obs_np_buffer = create_np_buffer(self.non_image_obs_space, n=n_envs)
+
+        # Start RenderServer
+        if server_address == "auto":
+            server_address = find_available_port()
+        self.server_address = server_address
+        server_kwargs = {} if server_kwargs is None else server_kwargs
+        self.server = sapien.RenderServer(**server_kwargs)
+        self.server.start(self.server_address)
+        print(f"RenderServer is running at: {server_address}")
+
+        # Wrap env_fn
+        for i, env_fn in enumerate(env_fns):
+            client_kwargs = {"address": self.server_address, "process_index": i}
+            env_fns[i] = partial(
+                env_fn, renderer="client", renderer_kwargs=client_kwargs
+            )
+
+        # Initialize workers
+        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)])
+        self.processes = []
+        for rank in range(n_envs):
+            work_remote = self.work_remotes[rank]
+            remote = self.remotes[rank]
+            env_fn = env_fns[rank]
+            args = (rank, work_remote, remote, env_fn)
+            # daemon=True: if the main process crashes, we should not cause things to hang
+            process = ctx.Process(
+                target=_worker, args=args, daemon=True
+            )  # pytype:disable=attribute-error
+            process.start()
+            self.processes.append(process)
+            work_remote.close()
+
+        # To make sure environments are initialized in all workers
+        for remote in self.remotes:
+            remote.send(("handshake", None))
+        for remote in self.remotes:
+            remote.recv()
+
+        # Infer texture names
+        texture_names = set()
+        for cam_space in self.image_obs_space.spaces.values():
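+            # Editorial note: this collects the union of texture names (e.g.
+            # "Color", "Position") across cameras, so that one shared torch
+            # buffer per texture can be allocated by the render server below.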
texture_names.update(cam_space.spaces.keys()) + self.texture_names = tuple(texture_names) + + # Allocate torch buffers + # A list of [n_envs, n_cams, H, W, C] tensors + self._obs_torch_buffer: List[ + torch.Tensor + ] = self.server.auto_allocate_torch_tensors(self.texture_names) + self.device = self._obs_torch_buffer[0].device + + # ---------------------------------------------------------------------------- # + # Observations + # ---------------------------------------------------------------------------- # + def _update_np_buffer(self, obs_list, indices=None): + indices = self._get_indices(indices) + for i, obs in zip(indices, obs_list): + self._last_obs_np[i] = obs + return stack_obs( + self._last_obs_np, self.non_image_obs_space, self._obs_np_buffer + ) + + @torch.no_grad() + def _get_torch_observations(self): + self.server.wait_all() + + tensor_dict = {} + for i, name in enumerate(self.texture_names): + tensor_dict[name] = self._obs_torch_buffer[i] + + # NOTE(jigu): Efficiency might not be optimized when using more cameras + image_obs = {} + for cam_idx, cam_uid in enumerate(self.image_obs_space.spaces.keys()): + image_obs[cam_uid] = {} + cam_space = self.image_obs_space[cam_uid] + for tex_name in cam_space: + tensor = tensor_dict[tex_name][:, cam_idx] # [B, H, W, C] + if tensor.shape[1:3] != cam_space[tex_name].shape[0:2]: + h, w = cam_space[tex_name].shape[0:2] + tensor = tensor[:, :h, :w] + image_obs[cam_uid][tex_name] = tensor + + return dict(image=image_obs) + + # ---------------------------------------------------------------------------- # + # Interfaces + # ---------------------------------------------------------------------------- # + def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + if seed is None: + seed = np.random.randint(0, 2**32) + for idx, remote in enumerate(self.remotes): + remote.send(("env_method", ("seed", [seed + idx], {}))) + return [remote.recv() for remote in self.remotes] + + def reset_async(self, indices=None): + remotes = self._get_target_remotes(indices) + for remote in remotes: + remote.send(("reset", None)) + self.waiting = True + + def reset_wait(self, indices=None): + remotes = self._get_target_remotes(indices) + results = [remote.recv() for remote in remotes] + self.waiting = False + vec_obs = self._get_torch_observations() + self._update_np_buffer(results, indices) + vec_obs.update(deepcopy(self._obs_np_buffer)) + return vec_obs + + def reset(self, indices=None): + self.reset_async(indices=indices) + return self.reset_wait(indices=indices) + + def step_async(self, actions: np.ndarray) -> None: + for remote, action in zip(self.remotes, actions): + remote.send(("step", action)) + self.waiting = True + + def step_wait(self): + results = [remote.recv() for remote in self.remotes] + self.waiting = False + obs_list, rews, dones, infos = zip(*results) + vec_obs = self._get_torch_observations() + self._update_np_buffer(obs_list) + vec_obs.update(deepcopy(self._obs_np_buffer)) + return vec_obs, np.array(rews), np.array(dones), infos + + def step(self, actions): + self.step_async(actions) + return self.step_wait() + + def close(self) -> None: + if self.closed: + return + if self.waiting: + for remote in self.remotes: + remote.recv() + for remote in self.remotes: + remote.send(("close", None)) + for process in self.processes: + process.join() + self.closed = True + + def render(self, mode=""): + raise NotImplementedError + + def get_attr(self, attr_name: str, indices=None) -> List: + """Return attribute from vectorized environment (see 
base class).""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + remote.send(("get_attr", attr_name)) + return [remote.recv() for remote in target_remotes] + + def set_attr(self, attr_name: str, value, indices=None) -> None: + """Set attribute inside vectorized environments (see base class).""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + remote.send(("set_attr", (attr_name, value))) + for remote in target_remotes: + remote.recv() + + def env_method( + self, + method_name: str, + *method_args, + indices=None, + **method_kwargs, + ) -> List: + """Call instance methods of vectorized environments.""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + remote.send(("env_method", (method_name, method_args, method_kwargs))) + return [remote.recv() for remote in target_remotes] + + def env_is_wrapped( + self, wrapper_class: Type[gym.Wrapper], indices=None + ) -> List[bool]: + """Check if environments are wrapped with a given wrapper.""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + remote.send(("is_wrapped", wrapper_class)) + return [remote.recv() for remote in target_remotes] + + @property + def unwrapped(self) -> "VecEnv": + if isinstance(self, VecEnvWrapper): + return self.venv.unwrapped + else: + return self + + def _get_indices(self, indices) -> List[int]: + """ + Convert a flexibly-typed reference to environment indices to an implied list of indices. + + :param indices: refers to indices of envs. + :return: the implied list of indices. + """ + if indices is None: + indices = list(range(self.num_envs)) + elif isinstance(indices, int): + indices = [indices] + return indices + + def _get_target_remotes(self, indices) -> List[Connection]: + """ + Get the connection object needed to communicate with the wanted + envs that are in subprocesses. + + :param indices: refers to indices of envs. + :return: Connection object to communicate between processes. 
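+
+        Editorial example: ``indices=[0, 2]`` returns the pipes of workers 0
+        and 2, so a command can be addressed to a subset of the environments.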
+ """ + indices = self._get_indices(indices) + return [self.remotes[i] for i in indices] + + def __repr__(self): + return "{}({})".format( + self.__class__.__name__, self.env_method("__repr__", indices=0)[0] + ) + + +def stack_observation_space(space: spaces.Space, n: int): + if isinstance(space, spaces.Dict): + sub_spaces = [ + (key, stack_observation_space(subspace, n)) + for key, subspace in space.spaces.items() + ] + return spaces.Dict(sub_spaces) + elif isinstance(space, spaces.Box): + shape = (n,) + space.shape + low = np.broadcast_to(space.low, shape) + high = np.broadcast_to(space.high, shape) + return spaces.Box(low=low, high=high, shape=shape, dtype=space.dtype) + else: + raise NotImplementedError( + "Unsupported observation space: {}".format(type(space)) + ) + + +def create_np_buffer(space: spaces.Space, n: int): + if isinstance(space, spaces.Dict): + return { + key: create_np_buffer(subspace, n) for key, subspace in space.spaces.items() + } + elif isinstance(space, spaces.Box): + return np.zeros((n,) + space.shape, dtype=space.dtype) + else: + raise NotImplementedError( + "Unsupported observation space: {}".format(type(space)) + ) + + +def stack_obs(obs: Sequence, space: spaces.Space, buffer: Optional[np.ndarray] = None): + if isinstance(space, spaces.Dict): + ret = {} + for key in space: + _obs = [o[key] for o in obs] + _buffer = None if buffer is None else buffer[key] + ret[key] = stack_obs(_obs, space[key], buffer=_buffer) + return ret + elif isinstance(space, spaces.Box): + return np.stack(obs, out=buffer) + else: + raise NotImplementedError(type(space)) + + +class RGBDVecEnv(VecEnv): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + from maniskill2_utils_wrappers_obs import RGBDObservationWrapper + + RGBDObservationWrapper.update_observation_space(self.observation_space) + + def _get_torch_observations(self): + observation = super()._get_torch_observations() + + image_obs = observation["image"] + for cam_uid, ori_images in image_obs.items(): + new_images = {} + for key in ori_images: + if key == "Color": + rgb = ori_images[key][..., :3] + rgb = torch.clamp(rgb * 255, 0, 255).to(dtype=torch.uint8) + new_images["rgb"] = rgb + elif key == "Position": + depth = -ori_images[key][..., [2]] + new_images["depth"] = depth + else: + new_images[key] = ori_images[key] + image_obs[cam_uid] = new_images + return observation + + +class PointCloudVecEnv(VecEnv): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + from maniskill2_utils_wrappers_obs import PointCloudObservationWrapper + + PointCloudObservationWrapper.update_observation_space(self.observation_space) + self._buffer = {} + + def _get_torch_observations(self): + observation = super()._get_torch_observations() + + image_obs = observation.pop("image") + pointcloud_obs = {} + + for cam_uid, cam_images in image_obs.items(): + cam_pcd = {} + + # Each pixel is (x, y, z, z_buffer_depth) in OpenGL camera space + position = cam_images["Position"] + bs = position.size(0) + + # Homogeneous coordinates + xyzw = torch.cat([position[..., :3], position[..., [3]] < 1], dim=-1) + cam_pcd["xyzw"] = xyzw.reshape(bs, -1, 4) + + if "Color" in cam_images: + rgb = cam_images["Color"][..., :3] + rgb = torch.clamp(rgb * 255, 0, 255).to(torch.uint8) + cam_pcd["rgb"] = rgb.reshape(bs, -1, 3) + + if "Segmentation" in cam_images: + seg = cam_images["Segmentation"] + cam_pcd["Segmentation"] = seg.reshape(bs, -1, 4) + + pointcloud_obs[cam_uid] = cam_pcd + + observation["pointcloud"] = pointcloud_obs + 
return observation + + @torch.no_grad() + def observation(self, observation: dict): + # Move camera parameters to device + camera_params = observation.pop("camera_param") + camera_params2 = {} + for cam_uid in camera_params: + cam2world = camera_params[cam_uid]["cam2world_gl"] + cam2world = torch.from_numpy(cam2world).to( + device=self.device, non_blocking=True + ) + camera_params2[cam_uid] = cam2world + + pointcloud_obs = observation["pointcloud"] + pointcloud_obs2 = defaultdict(list) + for cam_uid, cam_pcd in pointcloud_obs.items(): + # Transform coordinates to world space + xyzw = cam_pcd["xyzw"] # [B, H*W, 4] + cam2world = camera_params2[cam_uid] # [B, 4, 4] + cam_pcd["xyzw"] = torch.bmm(xyzw, cam2world.transpose(1, 2)) + for k, v in cam_pcd.items(): + pointcloud_obs2[k].append(v) + + for key, value in pointcloud_obs2.items(): + buffer = self._buffer.get(key, None) + self._buffer[key] = torch.cat(value, dim=1, out=buffer) + pointcloud_obs2[key] = self._buffer[key] + + observation["pointcloud"] = pointcloud_obs2 + return observation + + def reset_wait(self, *args, **kargs): + obs = super().reset_wait(*args, **kargs) + return self.observation(obs) + + def step_wait(self): + obs, rews, dones, infos = super().step_wait() + return self.observation(obs), rews, dones, infos + + +class VecEnvWrapper(VecEnv): + def __init__(self, venv: VecEnv): + self.venv = venv + self.num_envs = venv.num_envs + self.observation_space = venv.observation_space + self.action_space = venv.action_space + + def seed(self, seed: Optional[int] = None): + return self.venv.seed(seed) + + def reset_async(self, *args, **kwargs): + self.venv.reset_async(*args, **kwargs) + + def reset_wait(self, *args, **kwargs): + return self.venv.reset_wait(*args, **kwargs) + + def step_async(self, actions: np.ndarray): + self.venv.step_async(actions) + + def step_wait(self): + return self.venv.step_wait() + + def close(self): + return self.venv.close() + + def render(self, mode=""): + return self.venv.render(mode) + + def get_attr(self, attr_name: str, indices=None) -> List: + return self.venv.get_attr(attr_name, indices) + + def set_attr(self, attr_name: str, value, indices=None) -> None: + return self.venv.set_attr(attr_name, value, indices) + + def env_method( + self, + method_name: str, + *method_args, + indices=None, + **method_kwargs, + ) -> List: + return self.venv.env_method( + method_name, *method_args, indices=indices, **method_kwargs + ) + + def env_is_wrapped( + self, wrapper_class: Type[gym.Wrapper], indices=None + ) -> List[bool]: + return self.venv.env_is_wrapped(wrapper_class, indices) + + def __getattr__(self, name): + if name in self.__dict__: + return self.__dict__[name] + else: + return getattr(self.venv, name) + + +class VecEnvObservationWrapper(VecEnvWrapper): + def reset_wait(self, **kwargs): + observation = self.venv.reset_wait(**kwargs) + return self.observation(observation) + + def step_wait(self): + observation, reward, done, info = self.venv.step_wait() + return self.observation(observation), reward, done, info + + def observation(self, observation): + raise NotImplementedError \ No newline at end of file diff --git a/dexart-release/stable_baselines3/common/vec_env/maniskill2_wrapper_obs.py b/dexart-release/stable_baselines3/common/vec_env/maniskill2_wrapper_obs.py new file mode 100644 index 0000000000000000000000000000000000000000..5607dc49acaa50f67aaa024a284cffee2369fe4d --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/maniskill2_wrapper_obs.py @@ -0,0 +1,102 @@ +from copy import deepcopy 
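+# Editorial worked example for batch_isin below: x = [[1, 2], [1, 2]] with
+# inds = [[1], [2]] yields [[True, False], [False, True]]; the cumulative-sum
+# offset keeps each row's indices disjoint so a single torch.isin call suffices.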
+ +import torch + +from maniskill2_vec_env import VecEnv, VecEnvObservationWrapper + + +def batch_isin(x: torch.Tensor, inds: torch.Tensor): + """A batch version of `torch.isin`. + Args: + x (torch.Tensor): [B, ...], integer + inds (torch.Tensor): [B, N], integer + Returns: + torch.Tensor: [B, ...], boolean + """ + + # # For-loop version + # out = [] + # for x_i, inds_i in zip(x.unbind(0), inds.unbind(0)): + # out.append(torch.isin(x_i, inds_i)) + # out = torch.stack(out, dim=0) + + bs = x.size(0) + max_ind, _ = torch.max(x.reshape(bs, -1), dim=-1) # [B] + # Acquire the maximum index of each sample in batch + offset = torch.cumsum(max_ind + 1, dim=0) + offset = torch.nn.functional.pad(offset[:-1], (1, 0), value=0) + # Add offset to avoid indexing collision + _shape = (bs,) + (1,) * (x.dim() - 1) + # Remap indices + _x = x + offset.view(_shape) + _inds = inds + offset.view(bs, 1) + return torch.isin(_x, _inds) + + +class VecRobotSegmentationObservationWrapper(VecEnvObservationWrapper): + """Add a binary mask for robot links.""" + + def __init__(self, venv: VecEnv, replace=True): + super().__init__(venv) + + from maniskill2_utils_wrappers_obs import ( + RobotSegmentationObservationWrapper, + ) + + self.observation_space = deepcopy(venv.observation_space) + RobotSegmentationObservationWrapper.init_observation_space( + self.observation_space, replace=replace + ) + self.replace = replace + + # Cache robot link ids + # NOTE(jigu): Assume robots are the same and thus can be batched + robot_link_ids = self.get_attr("robot_link_ids") + self.robot_link_ids = torch.tensor( + robot_link_ids, dtype=torch.int32, device=self.device + ) + + @torch.no_grad() + def update_robot_link_ids(self, indices=None): + robot_link_ids = self.get_attr("robot_link_ids", indices=indices) + robot_link_ids = torch.tensor(robot_link_ids, dtype=torch.int32) + robot_link_ids = robot_link_ids.to(device=self.device, non_blocking=True) + indices = self._get_indices(indices) + self.robot_link_ids[indices] = robot_link_ids + + def observation_image(self, observation: dict): + image_obs = observation["image"] + for cam_images in image_obs.values(): + if "Segmentation" not in cam_images: + continue + seg = cam_images["Segmentation"] # [B, H, W, 4] + # [B, H, W, 1] + robot_seg = batch_isin(seg[..., 1:2], self.robot_link_ids) + if self.replace: + cam_images.pop("Segmentation") + cam_images["robot_seg"] = robot_seg + return observation + + def observation_pointcloud(self, observation: dict): + pointcloud_obs = observation["pointcloud"] + if "Segmentation" not in pointcloud_obs: + return observation + seg = pointcloud_obs["Segmentation"] # [N, 4] + robot_seg = batch_isin(seg[..., 1:2], self.robot_link_ids) # [N, 1] + if self.replace: + pointcloud_obs.pop("Segmentation") + pointcloud_obs["robot_seg"] = robot_seg + return observation + + @torch.no_grad() + def observation(self, observation: dict): + if "image" in observation: + observation = self.observation_image(observation) + if "pointcloud" in observation: + observation = self.observation_pointcloud(observation) + return observation + + def reset_wait(self, indices=None, **kwargs): + obs = super().reset_wait(indices=indices, **kwargs) + self.update_robot_link_ids(indices=indices) + return self.observation(obs) \ No newline at end of file diff --git a/dexart-release/stable_baselines3/common/vec_env/stacked_observations.py b/dexart-release/stable_baselines3/common/vec_env/stacked_observations.py new file mode 100644 index 
0000000000000000000000000000000000000000..733b72833ea275eff50f09cd2a5a009e1884eab3 --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/stacked_observations.py @@ -0,0 +1,266 @@ +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +from gym import spaces + +from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first + + +class StackedObservations: + """ + Frame stacking wrapper for data. + + Dimension to stack over is either first (channels-first) or + last (channels-last), which is detected automatically using + ``common.preprocessing.is_image_space_channels_first`` if + observation is an image space. + + :param num_envs: number of environments + :param n_stack: Number of frames to stack + :param observation_space: Environment observation space. + :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. + If None, automatically detect channel to stack over in case of image observation or default to "last" (default). + """ + + def __init__( + self, + num_envs: int, + n_stack: int, + observation_space: spaces.Space, + channels_order: Optional[str] = None, + ): + + self.n_stack = n_stack + ( + self.channels_first, + self.stack_dimension, + self.stackedobs, + self.repeat_axis, + ) = self.compute_stacking(num_envs, n_stack, observation_space, channels_order) + super().__init__() + + @staticmethod + def compute_stacking( + num_envs: int, + n_stack: int, + observation_space: spaces.Box, + channels_order: Optional[str] = None, + ) -> Tuple[bool, int, np.ndarray, int]: + """ + Calculates the parameters in order to stack observations + + :param num_envs: Number of environments in the stack + :param n_stack: The number of observations to stack + :param observation_space: The observation space + :param channels_order: The order of the channels + :return: tuple of channels_first, stack_dimension, stackedobs, repeat_axis + """ + channels_first = False + if channels_order is None: + # Detect channel location automatically for images + if is_image_space(observation_space): + channels_first = is_image_space_channels_first(observation_space) + else: + # Default behavior for non-image space, stack on the last axis + channels_first = False + else: + assert channels_order in { + "last", + "first", + }, "`channels_order` must be one of following: 'last', 'first'" + + channels_first = channels_order == "first" + + # This includes the vec-env dimension (first) + stack_dimension = 1 if channels_first else -1 + repeat_axis = 0 if channels_first else -1 + low = np.repeat(observation_space.low, n_stack, axis=repeat_axis) + stackedobs = np.zeros((num_envs,) + low.shape, low.dtype) + return channels_first, stack_dimension, stackedobs, repeat_axis + + def stack_observation_space(self, observation_space: spaces.Box) -> spaces.Box: + """ + Given an observation space, returns a new observation space with stacked observations + + :return: New observation space with stacked dimensions + """ + low = np.repeat(observation_space.low, self.n_stack, axis=self.repeat_axis) + high = np.repeat(observation_space.high, self.n_stack, axis=self.repeat_axis) + return spaces.Box(low=low, high=high, dtype=observation_space.dtype) + + def reset(self, observation: np.ndarray) -> np.ndarray: + """ + Resets the stackedobs, adds the reset observation to the stack, and returns the stack + + :param observation: Reset observation + :return: The stacked reset observation + """ + self.stackedobs[...] 
= 0 + if self.channels_first: + self.stackedobs[:, -observation.shape[self.stack_dimension] :, ...] = observation + else: + self.stackedobs[..., -observation.shape[self.stack_dimension] :] = observation + return self.stackedobs + + def update( + self, + observations: np.ndarray, + dones: np.ndarray, + infos: List[Dict[str, Any]], + ) -> Tuple[np.ndarray, List[Dict[str, Any]]]: + """ + Adds the observations to the stack and uses the dones to update the infos. + + :param observations: numpy array of observations + :param dones: numpy array of done info + :param infos: numpy array of info dicts + :return: tuple of the stacked observations and the updated infos + """ + stack_ax_size = observations.shape[self.stack_dimension] + self.stackedobs = np.roll(self.stackedobs, shift=-stack_ax_size, axis=self.stack_dimension) + for i, done in enumerate(dones): + if done: + if "terminal_observation" in infos[i]: + old_terminal = infos[i]["terminal_observation"] + if self.channels_first: + new_terminal = np.concatenate( + (self.stackedobs[i, :-stack_ax_size, ...], old_terminal), + axis=0, # self.stack_dimension - 1, as there is not batch dim + ) + else: + new_terminal = np.concatenate( + (self.stackedobs[i, ..., :-stack_ax_size], old_terminal), + axis=self.stack_dimension, + ) + infos[i]["terminal_observation"] = new_terminal + else: + warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info") + self.stackedobs[i] = 0 + if self.channels_first: + self.stackedobs[:, -observations.shape[self.stack_dimension] :, ...] = observations + else: + self.stackedobs[..., -observations.shape[self.stack_dimension] :] = observations + return self.stackedobs, infos + + +class StackedDictObservations(StackedObservations): + """ + Frame stacking wrapper for dictionary data. + + Dimension to stack over is either first (channels-first) or + last (channels-last), which is detected automatically using + ``common.preprocessing.is_image_space_channels_first`` if + observation is an image space. + + :param num_envs: number of environments + :param n_stack: Number of frames to stack + :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. + If None, automatically detect channel to stack over in case of image observation or default to "last" (default). 
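+    :param observation_space: Dict observation space; every subspace must be a
+        ``spaces.Box`` (editorial addition, documenting the assert in ``__init__``).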
+ """ + + def __init__( + self, + num_envs: int, + n_stack: int, + observation_space: spaces.Dict, + channels_order: Optional[Union[str, Dict[str, str]]] = None, + ): + self.n_stack = n_stack + self.channels_first = {} + self.stack_dimension = {} + self.stackedobs = {} + self.repeat_axis = {} + + for key, subspace in observation_space.spaces.items(): + assert isinstance(subspace, spaces.Box), "StackedDictObservations only works with nested gym.spaces.Box" + if isinstance(channels_order, str) or channels_order is None: + subspace_channel_order = channels_order + else: + subspace_channel_order = channels_order[key] + ( + self.channels_first[key], + self.stack_dimension[key], + self.stackedobs[key], + self.repeat_axis[key], + ) = self.compute_stacking(num_envs, n_stack, subspace, subspace_channel_order) + + def stack_observation_space(self, observation_space: spaces.Dict) -> spaces.Dict: + """ + Returns the stacked verson of a Dict observation space + + :param observation_space: Dict observation space to stack + :return: stacked observation space + """ + spaces_dict = {} + for key, subspace in observation_space.spaces.items(): + low = np.repeat(subspace.low, self.n_stack, axis=self.repeat_axis[key]) + high = np.repeat(subspace.high, self.n_stack, axis=self.repeat_axis[key]) + spaces_dict[key] = spaces.Box(low=low, high=high, dtype=subspace.dtype) + return spaces.Dict(spaces=spaces_dict) + + def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Resets the stacked observations, adds the reset observation to the stack, and returns the stack + + :param observation: Reset observation + :return: Stacked reset observations + """ + for key, obs in observation.items(): + self.stackedobs[key][...] = 0 + if self.channels_first[key]: + self.stackedobs[key][:, -obs.shape[self.stack_dimension[key]] :, ...] = obs + else: + self.stackedobs[key][..., -obs.shape[self.stack_dimension[key]] :] = obs + return self.stackedobs + + def update( + self, + observations: Dict[str, np.ndarray], + dones: np.ndarray, + infos: List[Dict[str, Any]], + ) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]: + """ + Adds the observations to the stack and uses the dones to update the infos. + + :param observations: Dict of numpy arrays of observations + :param dones: numpy array of dones + :param infos: dict of infos + :return: tuple of the stacked observations and the updated infos + """ + for key in self.stackedobs.keys(): + stack_ax_size = observations[key].shape[self.stack_dimension[key]] + self.stackedobs[key] = np.roll( + self.stackedobs[key], + shift=-stack_ax_size, + axis=self.stack_dimension[key], + ) + + for i, done in enumerate(dones): + if done: + if "terminal_observation" in infos[i]: + old_terminal = infos[i]["terminal_observation"][key] + if self.channels_first[key]: + new_terminal = np.vstack( + ( + self.stackedobs[key][i, :-stack_ax_size, ...], + old_terminal, + ) + ) + else: + new_terminal = np.concatenate( + ( + self.stackedobs[key][i, ..., :-stack_ax_size], + old_terminal, + ), + axis=self.stack_dimension[key], + ) + infos[i]["terminal_observation"][key] = new_terminal + else: + warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info") + self.stackedobs[key][i] = 0 + if self.channels_first[key]: + self.stackedobs[key][:, -stack_ax_size:, ...] 
= observations[key] + else: + self.stackedobs[key][..., -stack_ax_size:] = observations[key] + return self.stackedobs, infos diff --git a/dexart-release/stable_baselines3/common/vec_env/subproc_vec_env.py b/dexart-release/stable_baselines3/common/vec_env/subproc_vec_env.py new file mode 100644 index 0000000000000000000000000000000000000000..a8a5f27964ce848ebd6ff96e4ecdb9d301c5d17d --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -0,0 +1,221 @@ +import multiprocessing as mp +from collections import OrderedDict +from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union + +import gym +import numpy as np + +from stable_baselines3.common.vec_env.base_vec_env import ( + CloudpickleWrapper, + VecEnv, + VecEnvIndices, + VecEnvObs, + VecEnvStepReturn, +) + + +def _worker( + remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper +) -> None: + # Import here to avoid a circular import + from stable_baselines3.common.env_util import is_wrapped + + parent_remote.close() + env = env_fn_wrapper.var() + while True: + try: + cmd, data = remote.recv() + if cmd == "step": + observation, reward, done, info = env.step(data) + if done: + # save final observation where user can get it, then reset + info["terminal_observation"] = observation + observation = env.reset() + remote.send((observation, reward, done, info)) + elif cmd == "seed": + remote.send(env.seed(data)) + elif cmd == "reset": + observation = env.reset() + remote.send(observation) + elif cmd == "render": + remote.send(env.render(data)) + elif cmd == "close": + env.close() + remote.close() + break + elif cmd == "get_spaces": + remote.send((env.observation_space, env.action_space)) + elif cmd == "env_method": + method = getattr(env, data[0]) + remote.send(method(*data[1], **data[2])) + elif cmd == "get_attr": + remote.send(getattr(env, data)) + elif cmd == "set_attr": + remote.send(setattr(env, data[0], data[1])) + elif cmd == "is_wrapped": + remote.send(is_wrapped(env, data)) + else: + raise NotImplementedError(f"`{cmd}` is not implemented in the worker") + except EOFError: + break + +class SubprocVecEnv(VecEnv): + """ + Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own + process, allowing significant speed up when the environment is computationally complex. + + For performance reasons, if your environment is not IO bound, the number of environments should not exceed the + number of logical cores on your CPU. + + .. warning:: + + Only 'forkserver' and 'spawn' start methods are thread-safe, + which is important when TensorFlow sessions or other non thread-safe + libraries are used in the parent (see issue #217). However, compared to + 'fork' they incur a small start-up cost and have restrictions on + global variables. With those methods, users must wrap the code in an + ``if __name__ == "__main__":`` block. + For more information, see the multiprocessing documentation. + + :param env_fns: Environments to run in subprocesses + :param start_method: method used to start the subprocesses. + Must be one of the methods returned by multiprocessing.get_all_start_methods(). + Defaults to 'forkserver' on available platforms, and 'spawn' otherwise. 
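+
+    Example (editorial sketch, assuming the standard Gym ``CartPole-v1``
+    registration)::
+
+        env = SubprocVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
+        obs = env.reset()  # each env steps in its own process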
+ """ + + def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[str] = None): + self.waiting = False + self.closed = False + n_envs = len(env_fns) + + if start_method is None: + # Fork is not a thread safe method (see issue #217) + # but is more user friendly (does not require to wrap the code in + # a `if __name__ == "__main__":`) + forkserver_available = "forkserver" in mp.get_all_start_methods() + start_method = "forkserver" if forkserver_available else "spawn" + ctx = mp.get_context(start_method) + + self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)]) + self.processes = [] + for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns): + args = (work_remote, remote, CloudpickleWrapper(env_fn)) + # daemon=True: if the src process crashes, we should not cause things to hang + process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error + process.start() + self.processes.append(process) + work_remote.close() + + self.remotes[0].send(("get_spaces", None)) + observation_space, action_space = self.remotes[0].recv() + VecEnv.__init__(self, len(env_fns), observation_space, action_space) + + def step_async(self, actions: np.ndarray) -> None: + for remote, action in zip(self.remotes, actions): + remote.send(("step", action)) + self.waiting = True + + def step_wait(self) -> VecEnvStepReturn: + results = [remote.recv() for remote in self.remotes] + self.waiting = False + obs, rews, dones, infos = zip(*results) + return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos + + def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + if seed is None: + seed = np.random.randint(0, 2**32 - 1) + for idx, remote in enumerate(self.remotes): + remote.send(("seed", seed + idx)) + return [remote.recv() for remote in self.remotes] + + def reset(self) -> VecEnvObs: + for remote in self.remotes: + remote.send(("reset", None)) + obs = [remote.recv() for remote in self.remotes] + return _flatten_obs(obs, self.observation_space) + + def close(self) -> None: + if self.closed: + return + if self.waiting: + for remote in self.remotes: + remote.recv() + for remote in self.remotes: + remote.send(("close", None)) + for process in self.processes: + process.join() + self.closed = True + + def get_images(self) -> Sequence[np.ndarray]: + for pipe in self.remotes: + # gather images from subprocesses + # `mode` will be taken into account later + pipe.send(("render", "rgb_array")) + imgs = [pipe.recv() for pipe in self.remotes] + return imgs + + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: + """Return attribute from vectorized environment (see base class).""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + remote.send(("get_attr", attr_name)) + return [remote.recv() for remote in target_remotes] + + def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: + """Set attribute inside vectorized environments (see base class).""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + remote.send(("set_attr", (attr_name, value))) + for remote in target_remotes: + remote.recv() + + def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: + """Call instance methods of vectorized environments.""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + 
remote.send(("env_method", (method_name, method_args, method_kwargs))) + return [remote.recv() for remote in target_remotes] + + def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: + """Check if worker environments are wrapped with a given wrapper""" + target_remotes = self._get_target_remotes(indices) + for remote in target_remotes: + remote.send(("is_wrapped", wrapper_class)) + return [remote.recv() for remote in target_remotes] + + def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]: + """ + Get the connection object needed to communicate with the wanted + envs that are in subprocesses. + + :param indices: refers to indices of envs. + :return: Connection object to communicate between processes. + """ + indices = self._get_indices(indices) + return [self.remotes[i] for i in indices] + + +def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: gym.spaces.Space) -> VecEnvObs: + """ + Flatten observations, depending on the observation space. + + :param obs: observations. + A list or tuple of observations, one per environment. + Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays. + :return: flattened observations. + A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays. + Each NumPy array has the environment index as its first axis. + """ + assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" + assert len(obs) > 0, "need observations from at least one environment" + + if isinstance(space, gym.spaces.Dict): + assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" + assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" + return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) + elif isinstance(space, gym.spaces.Tuple): + assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" + obs_len = len(space.spaces) + return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) + else: + return np.stack(obs) diff --git a/dexart-release/stable_baselines3/common/vec_env/util.py b/dexart-release/stable_baselines3/common/vec_env/util.py new file mode 100644 index 0000000000000000000000000000000000000000..ca590cb1c81221cc2860490b50e74ef59c9736ba --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/util.py @@ -0,0 +1,76 @@ +""" +Helpers for dealing with vectorized environments. +""" +from collections import OrderedDict +from typing import Any, Dict, List, Tuple + +import gym +import numpy as np + +from stable_baselines3.common.preprocessing import check_for_nested_spaces +from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs + + +def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Deep-copy a dict of numpy arrays. + + :param obs: a dict of numpy arrays. + :return: a dict of copied numpy arrays. + """ + assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'" + return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) + + +def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs: + """ + Convert an internal representation raw_obs into the appropriate type + specified by space. + + :param obs_space: an observation space. + :param obs_dict: a dict of numpy arrays. + :return: returns an observation of the same type as space. 
+ If space is Dict, function is identity; if space is Tuple, converts dict to Tuple; + otherwise, space is unstructured and returns the value raw_obs[None]. + """ + if isinstance(obs_space, gym.spaces.Dict): + return obs_dict + elif isinstance(obs_space, gym.spaces.Tuple): + assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space" + return tuple(obs_dict[i] for i in range(len(obs_space.spaces))) + else: + assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space" + return obs_dict[None] + + + def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]: + """ + Get dict-structured information about a gym.Space. + + Dict spaces are represented directly by their dict of subspaces. + Tuple spaces are converted into a dict with keys indexing into the tuple. + Unstructured spaces are represented by {None: obs_space}. + + :param obs_space: an observation space + :return: A tuple (keys, shapes, dtypes): + keys: a list of dict keys. + shapes: a dict mapping keys to shapes. + dtypes: a dict mapping keys to dtypes. + """ + check_for_nested_spaces(obs_space) + if isinstance(obs_space, gym.spaces.Dict): + assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" + subspaces = obs_space.spaces + elif isinstance(obs_space, gym.spaces.Tuple): + subspaces = {i: space for i, space in enumerate(obs_space.spaces)} + else: + assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'" + subspaces = {None: obs_space} + keys = [] + shapes = {} + dtypes = {} + for key, box in subspaces.items(): + keys.append(key) + shapes[key] = box.shape + dtypes[key] = box.dtype + return keys, shapes, dtypes diff --git a/dexart-release/stable_baselines3/common/vec_env/vec_check_nan.py b/dexart-release/stable_baselines3/common/vec_env/vec_check_nan.py new file mode 100644 index 0000000000000000000000000000000000000000..258f9c26bef5e355d4db3ec7e08410a676e3e709 --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/vec_check_nan.py @@ -0,0 +1,86 @@ +import warnings + +import numpy as np + +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper + + +class VecCheckNan(VecEnvWrapper): + """ + NaN and inf checking wrapper for vectorized environment; raises a warning by default, + allowing you to know where the NaN or inf originated. + + :param venv: the vectorized environment to wrap + :param raise_exception: Whether or not to raise a ValueError, instead of a UserWarning + :param warn_once: Whether or not to only warn once.
+ :param check_inf: Whether or not to check for +inf or -inf as well + """ + + def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True): + VecEnvWrapper.__init__(self, venv) + self.raise_exception = raise_exception + self.warn_once = warn_once + self.check_inf = check_inf + self._actions = None + self._observations = None + self._user_warned = False + + def step_async(self, actions: np.ndarray) -> None: + self._check_val(async_step=True, actions=actions) + + self._actions = actions + self.venv.step_async(actions) + + def step_wait(self) -> VecEnvStepReturn: + observations, rewards, news, infos = self.venv.step_wait() + + self._check_val(async_step=False, observations=observations, rewards=rewards, news=news) + + self._observations = observations + return observations, rewards, news, infos + + def reset(self) -> VecEnvObs: + observations = self.venv.reset() + self._actions = None + + self._check_val(async_step=False, observations=observations) + + self._observations = observations + return observations + + def _check_val(self, *, async_step: bool, **kwargs) -> None: + # if warn and warn once and have warned once: then stop checking + if not self.raise_exception and self.warn_once and self._user_warned: + return + + found = [] + for name, val in kwargs.items(): + has_nan = np.any(np.isnan(val)) + has_inf = self.check_inf and np.any(np.isinf(val)) + if has_inf: + found.append((name, "inf")) + if has_nan: + found.append((name, "nan")) + + if found: + self._user_warned = True + msg = "" + for i, (name, type_val) in enumerate(found): + msg += f"found {type_val} in {name}" + if i != len(found) - 1: + msg += ", " + + msg += ".\r\nOriginated from the " + + if not async_step: + if self._actions is None: + msg += "environment observation (at reset)" + else: + msg += f"environment, Last given value was: \r\n\taction={self._actions}" + else: + msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}" + + if self.raise_exception: + raise ValueError(msg) + else: + warnings.warn(msg, UserWarning) diff --git a/dexart-release/stable_baselines3/common/vec_env/vec_extract_dict_obs.py b/dexart-release/stable_baselines3/common/vec_env/vec_extract_dict_obs.py new file mode 100644 index 0000000000000000000000000000000000000000..8582b7a308c7d8695b8b93061e4119d5cdcf1f5b --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/vec_extract_dict_obs.py @@ -0,0 +1,24 @@ +import numpy as np + +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper + + +class VecExtractDictObs(VecEnvWrapper): + """ + A vectorized wrapper for extracting dictionary observations. 
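+ For example (the key name is illustrative), ``VecExtractDictObs(venv, key="rgb")`` wraps a venv whose Dict observation space contains an ``rgb`` entry, so that ``reset()`` and ``step_wait()`` return only that array.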
+ + :param venv: The vectorized environment + :param key: The key of the dictionary observation + """ + + def __init__(self, venv: VecEnv, key: str): + self.key = key + super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key]) + + def reset(self) -> np.ndarray: + obs = self.venv.reset() + return obs[self.key] + + def step_wait(self) -> VecEnvStepReturn: + obs, reward, done, info = self.venv.step_wait() + return obs[self.key], reward, done, info diff --git a/dexart-release/stable_baselines3/common/vec_env/vec_frame_stack.py b/dexart-release/stable_baselines3/common/vec_env/vec_frame_stack.py new file mode 100644 index 0000000000000000000000000000000000000000..e06d5125e0f4b666af20a1b86ef53c63b635c1b9 --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -0,0 +1,64 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +from gym import spaces + +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper +from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations + + +class VecFrameStack(VecEnvWrapper): + """ + Frame stacking wrapper for vectorized environment. Designed for image observations. + + Uses the StackedObservations class, or StackedDictObservations, depending on the observation space + + :param venv: the vectorized environment to wrap + :param n_stack: Number of frames to stack + :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. + If None, automatically detect channel to stack over in case of image observation or default to "last" (default). + Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces + """ + + def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None): + self.venv = venv + self.n_stack = n_stack + + wrapped_obs_space = venv.observation_space + + if isinstance(wrapped_obs_space, spaces.Box): + assert not isinstance( + channels_order, dict + ), f"Expected None or string for channels_order but received {channels_order}" + self.stackedobs = StackedObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) + + elif isinstance(wrapped_obs_space, spaces.Dict): + self.stackedobs = StackedDictObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) + + else: + raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces") + + observation_space = self.stackedobs.stack_observation_space(wrapped_obs_space) + VecEnvWrapper.__init__(self, venv, observation_space=observation_space) + + def step_wait( + self, + ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]: + + observations, rewards, dones, infos = self.venv.step_wait() + + observations, infos = self.stackedobs.update(observations, dones, infos) + + return observations, rewards, dones, infos + + def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: + """ + Reset all environments + """ + observation = self.venv.reset() # pytype:disable=annotation-type-mismatch + + observation = self.stackedobs.reset(observation) + return observation + + def close(self) -> None: + self.venv.close() diff --git a/dexart-release/stable_baselines3/common/vec_env/vec_monitor.py b/dexart-release/stable_baselines3/common/vec_env/vec_monitor.py new file mode 100644 index 
0000000000000000000000000000000000000000..ddc099a12d048eb9c4658f76569939f565d38c81 --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/vec_monitor.py @@ -0,0 +1,100 @@ +import time +import warnings +from typing import Optional, Tuple + +import numpy as np + +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper + + +class VecMonitor(VecEnvWrapper): + """ + A vectorized monitor wrapper for *vectorized* Gym environments; + it is used to record the episode reward, length, time and other data. + + Some environments like `openai/procgen <https://github.com/openai/procgen>`_ + or `gym3 <https://github.com/openai/gym3>`_ directly initialize the + vectorized environments, without giving us a chance to use the ``Monitor`` + wrapper. So this class simply does the job of the ``Monitor`` wrapper on + a vectorized level. + + :param venv: The vectorized environment + :param filename: the location to save a log file, can be None for no log + :param info_keywords: extra information to log, from the information return of env.step() + """ + + def __init__( + self, + venv: VecEnv, + filename: Optional[str] = None, + info_keywords: Tuple[str, ...] = (), + ): + # Avoid circular import + from stable_baselines3.common.monitor import Monitor, ResultsWriter + + # This check is not valid for special `VecEnv` + # like the ones created by Procgen, that do not completely follow + # the `VecEnv` interface + try: + is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0] + except AttributeError: + is_wrapped_with_monitor = False + + if is_wrapped_with_monitor: + warnings.warn( + "The environment is already wrapped with a `Monitor` wrapper " + "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be " + "overwritten by the `VecMonitor` ones.", + UserWarning, + ) + + VecEnvWrapper.__init__(self, venv) + self.episode_returns = None + self.episode_lengths = None + self.episode_count = 0 + self.t_start = time.time() + + env_id = None + if hasattr(venv, "spec") and venv.spec is not None: + env_id = venv.spec.id + + if filename: + self.results_writer = ResultsWriter( + filename, header={"t_start": self.t_start, "env_id": env_id}, extra_keys=info_keywords + ) + else: + self.results_writer = None + self.info_keywords = info_keywords + + def reset(self) -> VecEnvObs: + obs = self.venv.reset() + self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) + self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) + return obs + + def step_wait(self) -> VecEnvStepReturn: + obs, rewards, dones, infos = self.venv.step_wait() + self.episode_returns += rewards + self.episode_lengths += 1 + new_infos = list(infos[:]) + for i in range(len(dones)): + if dones[i]: + info = infos[i].copy() + episode_return = self.episode_returns[i] + episode_length = self.episode_lengths[i] + episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)} + for key in self.info_keywords: + episode_info[key] = info[key] + info["episode"] = episode_info + self.episode_count += 1 + self.episode_returns[i] = 0 + self.episode_lengths[i] = 0 + if self.results_writer: + self.results_writer.write_row(episode_info) + new_infos[i] = info + return obs, rewards, dones, new_infos + + def close(self) -> None: + if self.results_writer: + self.results_writer.close() + return self.venv.close() diff --git a/dexart-release/stable_baselines3/common/vec_env/vec_normalize.py b/dexart-release/stable_baselines3/common/vec_env/vec_normalize.py new file mode 100644 index 
0000000000000000000000000000000000000000..f3ee588aba47e49691d5fb9ae51e097e6bc1bb3e --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/vec_normalize.py @@ -0,0 +1,296 @@ +import pickle +import warnings +from copy import deepcopy +from typing import Any, Dict, List, Optional, Union + +import gym +import numpy as np + +from stable_baselines3.common import utils +from stable_baselines3.common.running_mean_std import RunningMeanStd +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper + + +class VecNormalize(VecEnvWrapper): + """ + A moving average, normalizing wrapper for vectorized environments, + with support for saving/loading the moving averages. + + :param venv: the vectorized environment to wrap + :param training: Whether or not to update the moving average + :param norm_obs: Whether to normalize observations or not (default: True) + :param norm_reward: Whether to normalize rewards or not (default: True) + :param clip_obs: Max absolute value for observation + :param clip_reward: Max absolute value for discounted reward + :param gamma: discount factor + :param epsilon: To avoid division by zero + :param norm_obs_keys: Which keys from observation dict to normalize. + If not specified, all keys will be normalized. + """ + + def __init__( + self, + venv: VecEnv, + training: bool = True, + norm_obs: bool = True, + norm_reward: bool = True, + clip_obs: float = 10.0, + clip_reward: float = 10.0, + gamma: float = 0.99, + epsilon: float = 1e-8, + norm_obs_keys: Optional[List[str]] = None, + ): + VecEnvWrapper.__init__(self, venv) + + self.norm_obs = norm_obs + self.norm_obs_keys = norm_obs_keys + # Check observation spaces + if self.norm_obs: + self._sanity_checks() + + if isinstance(self.observation_space, gym.spaces.Dict): + self.obs_spaces = self.observation_space.spaces + self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys} + else: + self.obs_spaces = None + self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) + + self.ret_rms = RunningMeanStd(shape=()) + self.clip_obs = clip_obs + self.clip_reward = clip_reward + # Returns: discounted rewards + self.returns = np.zeros(self.num_envs) + self.gamma = gamma + self.epsilon = epsilon + self.training = training + self.norm_obs = norm_obs + self.norm_reward = norm_reward + self.old_obs = np.array([]) + self.old_reward = np.array([]) + + def _sanity_checks(self) -> None: + """ + Check that the observations that are going to be normalized are of the correct type (spaces.Box). + """ + if isinstance(self.observation_space, gym.spaces.Dict): + # By default, we normalize all keys + if self.norm_obs_keys is None: + self.norm_obs_keys = list(self.observation_space.spaces.keys()) + # Check that all keys are of type Box + for obs_key in self.norm_obs_keys: + if not isinstance(self.observation_space.spaces[obs_key], gym.spaces.Box): + raise ValueError( + f"VecNormalize only supports `gym.spaces.Box` observation spaces but {obs_key} " + f"is of type {self.observation_space.spaces[obs_key]}. " + "You should probably explicitly pass the observation keys " + "that should be normalized via the `norm_obs_keys` parameter."
+ ) + + elif isinstance(self.observation_space, gym.spaces.Box): + if self.norm_obs_keys is not None: + raise ValueError("`norm_obs_keys` param is applicable only with `gym.spaces.Dict` observation spaces") + + else: + raise ValueError( + "VecNormalize only supports `gym.spaces.Box` and `gym.spaces.Dict` observation spaces, " + f"not {self.observation_space}" + ) + + def __getstate__(self) -> Dict[str, Any]: + """ + Gets state for pickling. + + Excludes self.venv, as in general VecEnv's may not be pickleable.""" + state = self.__dict__.copy() + # these attributes are not pickleable + del state["venv"] + del state["class_attributes"] + # these attributes depend on the above and so we would prefer not to pickle + del state["returns"] + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + """ + Restores pickled state. + + User must call set_venv() after unpickling before using. + + :param state:""" + # Backward compatibility + if "norm_obs_keys" not in state and isinstance(state["observation_space"], gym.spaces.Dict): + state["norm_obs_keys"] = list(state["observation_space"].spaces.keys()) + self.__dict__.update(state) + assert "venv" not in state + self.venv = None + + def set_venv(self, venv: VecEnv) -> None: + """ + Sets the vector environment to wrap to venv. + + Also sets attributes derived from this such as `num_env`. + + :param venv: + """ + if self.venv is not None: + raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.") + VecEnvWrapper.__init__(self, venv) + + # Check only that the observation_space match + utils.check_for_correct_spaces(venv, self.observation_space, venv.action_space) + self.returns = np.zeros(self.num_envs) + + def step_wait(self) -> VecEnvStepReturn: + """ + Apply sequence of actions to sequence of environments + actions -> (observations, rewards, dones) + + where ``dones`` is a boolean vector indicating whether each element is new. + """ + obs, rewards, dones, infos = self.venv.step_wait() + self.old_obs = obs + self.old_reward = rewards + + if self.training and self.norm_obs: + if isinstance(obs, dict) and isinstance(self.obs_rms, dict): + for key in self.obs_rms.keys(): + self.obs_rms[key].update(obs[key]) + else: + self.obs_rms.update(obs) + + obs = self.normalize_obs(obs) + + if self.training: + self._update_reward(rewards) + rewards = self.normalize_reward(rewards) + + # Normalize the terminal observations + for idx, done in enumerate(dones): + if not done: + continue + if "terminal_observation" in infos[idx]: + infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"]) + + self.returns[dones] = 0 + return obs, rewards, dones, infos + + def _update_reward(self, reward: np.ndarray) -> None: + """Update reward normalization statistics.""" + self.returns = self.returns * self.gamma + reward + self.ret_rms.update(self.returns) + + def _normalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray: + """ + Helper to normalize observation. + :param obs: + :param obs_rms: associated statistics + :return: normalized observation + """ + return np.clip((obs - obs_rms.mean) / np.sqrt(obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs) + + def _unnormalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray: + """ + Helper to unnormalize observation. 
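+ Applies ``obs * sqrt(var + epsilon) + mean``, the inverse of ``_normalize_obs`` (the clipping applied during normalization is not undone).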
+ :param obs: + :param obs_rms: associated statistics + :return: unnormalized observation + """ + return (obs * np.sqrt(obs_rms.var + self.epsilon)) + obs_rms.mean + + def normalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]: + """ + Normalize observations using this VecNormalize's observations statistics. + Calling this method does not update statistics. + """ + # Avoid modifying by reference the original object + obs_ = deepcopy(obs) + if self.norm_obs: + if isinstance(obs, dict) and isinstance(self.obs_rms, dict): + # Only normalize the specified keys + for key in self.norm_obs_keys: + obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32) + else: + obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32) + return obs_ + + def normalize_reward(self, reward: np.ndarray) -> np.ndarray: + """ + Normalize rewards using this VecNormalize's rewards statistics. + Calling this method does not update statistics. + """ + if self.norm_reward: + reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward) + return reward + + def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]: + # Avoid modifying by reference the original object + obs_ = deepcopy(obs) + if self.norm_obs: + if isinstance(obs, dict) and isinstance(self.obs_rms, dict): + for key in self.norm_obs_keys: + obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key]) + else: + obs_ = self._unnormalize_obs(obs, self.obs_rms) + return obs_ + + def unnormalize_reward(self, reward: np.ndarray) -> np.ndarray: + if self.norm_reward: + return reward * np.sqrt(self.ret_rms.var + self.epsilon) + return reward + + def get_original_obs(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: + """ + Returns an unnormalized version of the observations from the most recent + step or reset. + """ + return deepcopy(self.old_obs) + + def get_original_reward(self) -> np.ndarray: + """ + Returns an unnormalized version of the rewards from the most recent step. + """ + return self.old_reward.copy() + + def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: + """ + Reset all environments + :return: first observation of the episode + """ + obs = self.venv.reset() + self.old_obs = obs + self.returns = np.zeros(self.num_envs) + if self.training and self.norm_obs: + if isinstance(obs, dict) and isinstance(self.obs_rms, dict): + for key in self.obs_rms.keys(): + self.obs_rms[key].update(obs[key]) + else: + self.obs_rms.update(obs) + return self.normalize_obs(obs) + + @staticmethod + def load(load_path: str, venv: VecEnv) -> "VecNormalize": + """ + Loads a saved VecNormalize object. + + :param load_path: the path to load from. + :param venv: the VecEnv to wrap. + :return: + """ + with open(load_path, "rb") as file_handler: + vec_normalize = pickle.load(file_handler) + vec_normalize.set_venv(venv) + return vec_normalize + + def save(self, save_path: str) -> None: + """ + Save current VecNormalize object with + all running statistics and settings (e.g. clip_obs) + + :param save_path: The path to save to + """ + with open(save_path, "wb") as file_handler: + pickle.dump(self, file_handler) + + @property + def ret(self) -> np.ndarray: + warnings.warn("`VecNormalize` `ret` attribute is deprecated. 
Please use `returns` instead.", DeprecationWarning) + return self.returns diff --git a/dexart-release/stable_baselines3/common/vec_env/vec_transpose.py b/dexart-release/stable_baselines3/common/vec_env/vec_transpose.py new file mode 100644 index 0000000000000000000000000000000000000000..b6b0ad832f71a876628aa022645e1078940dc30e --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/vec_transpose.py @@ -0,0 +1,113 @@ +from copy import deepcopy +from typing import Dict, Union + +import numpy as np +from gym import spaces + +from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper + + +class VecTransposeImage(VecEnvWrapper): + """ + Re-order channels, from HxWxC to CxHxW. + It is required for PyTorch convolution layers. + + :param venv: + :param skip: Skip this wrapper if needed as we rely on heuristic to apply it or not, + which may result in unwanted behavior, see GH issue #671. + """ + + def __init__(self, venv: VecEnv, skip: bool = False): + assert is_image_space(venv.observation_space) or isinstance( + venv.observation_space, spaces.dict.Dict + ), "The observation space must be an image or dictionary observation space" + + self.skip = skip + # Do nothing + if skip: + super().__init__(venv) + return + + if isinstance(venv.observation_space, spaces.dict.Dict): + self.image_space_keys = [] + observation_space = deepcopy(venv.observation_space) + for key, space in observation_space.spaces.items(): + if is_image_space(space): + # Keep track of which keys should be transposed later + self.image_space_keys.append(key) + observation_space.spaces[key] = self.transpose_space(space, key) + else: + observation_space = self.transpose_space(venv.observation_space) + super().__init__(venv, observation_space=observation_space) + + @staticmethod + def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box: + """ + Transpose an observation space (re-order channels). + + :param observation_space: + :param key: In case of dictionary space, the key of the observation space. + :return: + """ + # Sanity checks + assert is_image_space(observation_space), "The observation space must be an image" + assert not is_image_space_channels_first( + observation_space + ), f"The observation space {key} must follow the channel last convention" + height, width, channels = observation_space.shape + new_shape = (channels, height, width) + return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype) + + @staticmethod + def transpose_image(image: np.ndarray) -> np.ndarray: + """ + Transpose an image or batch of images (re-order channels). + + :param image: + :return: + """ + if len(image.shape) == 3: + return np.transpose(image, (2, 0, 1)) + return np.transpose(image, (0, 3, 1, 2)) + + def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]: + """ + Transpose (if needed) and return new observations. 
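+ For example, a channel-last batch of shape ``(n_envs, H, W, C)`` comes back as ``(n_envs, C, H, W)``; for dict observations, only the image keys are transposed.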
+ + :param observations: + :return: Transposed observations + """ + # Do nothing + if self.skip: + return observations + + if isinstance(observations, dict): + # Avoid modifying the original object in place + observations = deepcopy(observations) + for k in self.image_space_keys: + observations[k] = self.transpose_image(observations[k]) + else: + observations = self.transpose_image(observations) + return observations + + def step_wait(self) -> VecEnvStepReturn: + observations, rewards, dones, infos = self.venv.step_wait() + + # Transpose the terminal observations + for idx, done in enumerate(dones): + if not done: + continue + if "terminal_observation" in infos[idx]: + infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"]) + + return self.transpose_observations(observations), rewards, dones, infos + + def reset(self) -> Union[np.ndarray, Dict]: + """ + Reset all environments + """ + return self.transpose_observations(self.venv.reset()) + + def close(self) -> None: + self.venv.close() diff --git a/dexart-release/stable_baselines3/common/vec_env/vec_video_recorder.py b/dexart-release/stable_baselines3/common/vec_env/vec_video_recorder.py new file mode 100644 index 0000000000000000000000000000000000000000..70d74ebe4c981fd0c1182bee97ff175802bfadb3 --- /dev/null +++ b/dexart-release/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -0,0 +1,113 @@ +import os +from typing import Callable + +from gym.wrappers.monitoring import video_recorder + +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper +from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv +from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv + + +class VecVideoRecorder(VecEnvWrapper): + """ + Wraps a VecEnv or VecEnvWrapper object to record rendered images as an mp4 video. + It requires ffmpeg or avconv to be installed on the machine. + + :param venv: + :param video_folder: Where to save videos + :param record_video_trigger: Function that defines when to start recording. + The function takes the current number of steps, + and returns whether we should start recording or not.
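+ For example, ``lambda step: step % 2000 == 0`` (an illustrative trigger) starts a new recording every 2000 steps.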
+ :param video_length: Length of recorded videos + :param name_prefix: Prefix to the video name + """ + + def __init__( + self, + venv: VecEnv, + video_folder: str, + record_video_trigger: Callable[[int], bool], + video_length: int = 200, + name_prefix: str = "rl-video", + ): + + VecEnvWrapper.__init__(self, venv) + + self.env = venv + # Temp variable to retrieve metadata + temp_env = venv + + # Unwrap to retrieve metadata dict + # that will be used by gym recorder + while isinstance(temp_env, VecEnvWrapper): + temp_env = temp_env.venv + + if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv): + metadata = temp_env.get_attr("metadata")[0] + else: + metadata = temp_env.metadata + + self.env.metadata = metadata + + self.record_video_trigger = record_video_trigger + self.video_recorder = None + + self.video_folder = os.path.abspath(video_folder) + # Create output folder if needed + os.makedirs(self.video_folder, exist_ok=True) + + self.name_prefix = name_prefix + self.step_id = 0 + self.video_length = video_length + + self.recording = False + self.recorded_frames = 0 + + def reset(self) -> VecEnvObs: + obs = self.venv.reset() + self.start_video_recorder() + return obs + + def start_video_recorder(self) -> None: + self.close_video_recorder() + + video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}" + base_path = os.path.join(self.video_folder, video_name) + self.video_recorder = video_recorder.VideoRecorder( + env=self.env, base_path=base_path, metadata={"step_id": self.step_id} + ) + + self.video_recorder.capture_frame() + self.recorded_frames = 1 + self.recording = True + + def _video_enabled(self) -> bool: + return self.record_video_trigger(self.step_id) + + def step_wait(self) -> VecEnvStepReturn: + obs, rews, dones, infos = self.venv.step_wait() + + self.step_id += 1 + if self.recording: + self.video_recorder.capture_frame() + self.recorded_frames += 1 + if self.recorded_frames > self.video_length: + print(f"Saving video to {self.video_recorder.path}") + self.close_video_recorder() + elif self._video_enabled(): + self.start_video_recorder() + + return obs, rews, dones, infos + + def close_video_recorder(self) -> None: + if self.recording: + self.video_recorder.close() + self.recording = False + self.recorded_frames = 1 + + def close(self) -> None: + VecEnvWrapper.close(self) + self.close_video_recorder() + + def __del__(self): + self.close() diff --git a/dexart-release/stable_baselines3/networks/pretrain_nets.py b/dexart-release/stable_baselines3/networks/pretrain_nets.py new file mode 100644 index 0000000000000000000000000000000000000000..ff286a3ea19b0e2f96ba9404bc16062bb404c8fb --- /dev/null +++ b/dexart-release/stable_baselines3/networks/pretrain_nets.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn + + +class PointNet(nn.Module): # actually pointnet + def __init__(self, point_channel=3, output_dim=256): + # NOTE: we require the output dim to be 256, in order to match the pretrained weights + super(PointNet, self).__init__() + + print(f'PointNetSmall') + + in_channel = point_channel + mlp_out_dim = 256 + self.local_mlp = nn.Sequential( + nn.Linear(in_channel, 64), + nn.GELU(), + nn.Linear(64, mlp_out_dim), + ) + self.reset_parameters_() + + def reset_parameters_(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x): + ''' + x: [B, N, 3] + ''' + # pc = x[0].cpu().detach().numpy() + # 
Local + x = self.local_mlp(x) + # global max pooling + x = torch.max(x, dim=1)[0] + return x + + +class PointNetMedium(nn.Module): # actually pointnet + def __init__(self, point_channel=3, output_dim=256): + # NOTE: we require the output dim to be 256, in order to match the pretrained weights + super(PointNetMedium, self).__init__() + + print(f'PointNetMedium') + + in_channel = point_channel + mlp_out_dim = 256 + self.local_mlp = nn.Sequential( + nn.Linear(in_channel, 64), + nn.GELU(), + nn.Linear(64, 64), + nn.GELU(), + nn.Linear(64, 128), + nn.GELU(), + nn.Linear(128, mlp_out_dim), + ) + self.reset_parameters_() + + def reset_parameters_(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x): + ''' + x: [B, N, 3] + ''' + # Local + x = self.local_mlp(x) + # global max pooling + x = torch.max(x, dim=1)[0] + return x + + +class PointNetLarge(nn.Module): # actually pointnet + def __init__(self, point_channel=3, output_dim=256): + # NOTE: we require the output dim to be 256, in order to match the pretrained weights + super(PointNetLarge, self).__init__() + + print(f'PointNetLarge') + + in_channel = point_channel + mlp_out_dim = 256 + self.local_mlp = nn.Sequential( + nn.Linear(in_channel, 64), + nn.GELU(), + nn.Linear(64, 64), + nn.GELU(), + nn.Linear(64, 128), + nn.GELU(), + nn.Linear(128, 128), + nn.GELU(), + nn.Linear(128, 256), + nn.GELU(), + nn.Linear(256, mlp_out_dim), + ) + + self.reset_parameters_() + + def reset_parameters_(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x): + ''' + x: [B, N, 3] + ''' + # Local + x = self.local_mlp(x) + # global max pooling + x = torch.max(x, dim=1)[0] + return x diff --git a/dexart-release/stable_baselines3/ppo/__init__.py b/dexart-release/stable_baselines3/ppo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c23fc9c6d2f327befeb1b1b2df66847f3bff62 --- /dev/null +++ b/dexart-release/stable_baselines3/ppo/__init__.py @@ -0,0 +1,2 @@ +from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy +from stable_baselines3.ppo.ppo import PPO diff --git a/dexart-release/stable_baselines3/ppo/policies.py b/dexart-release/stable_baselines3/ppo/policies.py new file mode 100644 index 0000000000000000000000000000000000000000..fb7afaef13be27ca853b7ef4087864f469a1e3dc --- /dev/null +++ b/dexart-release/stable_baselines3/ppo/policies.py @@ -0,0 +1,7 @@ +# This file is here just to define MlpPolicy/CnnPolicy +# that work for PPO +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy + +MlpPolicy = ActorCriticPolicy +CnnPolicy = ActorCriticCnnPolicy +MultiInputPolicy = MultiInputActorCriticPolicy diff --git a/dexart-release/stable_baselines3/ppo/ppo.py b/dexart-release/stable_baselines3/ppo/ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..4fac910ccf3eba684015ec379a844fb4ffc3df09 --- /dev/null +++ b/dexart-release/stable_baselines3/ppo/ppo.py @@ -0,0 +1,358 @@ +import warnings +from typing import Any, Dict, Optional, Type, Union + +import numpy as np +import torch as th +from gym import spaces +from torch.nn import functional as F + +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.policies import 
ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, \ + MultiInputActorCriticPolicy +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import explained_variance, get_schedule_fn, update_learning_rate + + +class AdaptiveScheduler: + def __init__(self, kl_threshold, min_lr, max_lr, init_lr): + super().__init__() + self.min_lr = min_lr + self.max_lr = max_lr + self.kl_threshold = kl_threshold + self.current_lr = init_lr + + def update(self, kl_dist): + lr = self.current_lr + if kl_dist > (2.0 * self.kl_threshold): + lr = max(self.current_lr / 1.5, self.min_lr) + if kl_dist < (0.5 * self.kl_threshold): + lr = min(self.current_lr * 1.5, self.max_lr) + self.current_lr = lr + return lr + + +class PPO(OnPolicyAlgorithm): + """ + Proximal Policy Optimization algorithm (PPO) (clip version) + + Paper: https://arxiv.org/abs/1707.06347 + Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/) + https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and + Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines) + + Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) + :param n_steps: The number of steps to run for each environment per update + (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel) + NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization) + See https://github.com/pytorch/pytorch/issues/29372 + :param batch_size: Minibatch size + :param n_epochs: Number of epochs when optimizing the surrogate loss + :param gamma: Discount factor + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + :param clip_range: Clipping parameter, it can be a function of the current progress + remaining (from 1 to 0). + :param clip_range_vf: Clipping parameter for the value function, + it can be a function of the current progress remaining (from 1 to 0). + This is a parameter specific to the OpenAI implementation. If None is passed (default), + no clipping will be done on the value function. + IMPORTANT: this clipping depends on the reward scaling. + :param normalize_advantage: Whether or not to normalize the advantage + :param ent_coef: Entropy coefficient for the loss calculation + :param vf_coef: Value function coefficient for the loss calculation + :param max_grad_norm: The maximum value for the gradient clipping + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) + instead of action noise exploration (default: False) + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE + Default: -1 (only sample at the beginning of the rollout) + :param target_kl: Limit the KL divergence between updates, + because the clipping is not enough to prevent large updates + see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) + By default, there is no limit on the kl div. + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param create_eval_env: Whether to create a second environment that will be + used for evaluating the agent periodically.
(Only available when passing string for the environment) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. + :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + policy_aliases: Dict[str, Type[BasePolicy]] = { + "MlpPolicy": ActorCriticPolicy, + "PointCloudPolicy": ActorCriticPolicy, + "CnnPolicy": ActorCriticCnnPolicy, + "MultiInputPolicy": MultiInputActorCriticPolicy, + } + + def __init__( + self, + policy: Union[str, Type[ActorCriticPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 3e-4, + n_steps: int = 2048, + batch_size: int = 64, + n_epochs: int = 10, + gamma: float = 0.99, + gae_lambda: float = 0.95, + clip_range: Union[float, Schedule] = 0.2, + clip_range_vf: Union[None, float, Schedule] = None, + normalize_advantage: bool = True, + ent_coef: float = 0.0, + vf_coef: float = 0.5, + max_grad_norm: float = 0.5, + use_sde: bool = False, + sde_sample_freq: int = -1, + target_kl: Optional[float] = None, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + adaptive_kl: float = 0.02, + min_lr=1e-4, + max_lr=1e-3, + ): + + super().__init__( + policy, + env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + tensorboard_log=tensorboard_log, + policy_kwargs=policy_kwargs, + verbose=verbose, + device=device, + create_eval_env=create_eval_env, + seed=seed, + _init_setup_model=False, + supported_action_spaces=( + spaces.Box, + spaces.Discrete, + spaces.MultiDiscrete, + spaces.MultiBinary, + ), + ) + + # Sanity check, otherwise it will lead to noisy gradient and NaN + # because of the advantage normalization + if normalize_advantage: + assert ( + batch_size > 1 + ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440" + + if self.env is not None: + # Check that `n_steps * n_envs > 1` to avoid NaN + # when doing advantage normalization + buffer_size = self.env.num_envs * self.n_steps + assert ( + buffer_size > 1 + ), f"`n_steps * n_envs` must be greater than 1. 
Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" + # Check that the rollout buffer size is a multiple of the mini-batch size + untruncated_batches = buffer_size // batch_size + if buffer_size % batch_size > 0: + warnings.warn( + f"You have specified a mini-batch size of {batch_size}," + f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`," + f" after every {untruncated_batches} untruncated mini-batches," + f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n" + f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n" + f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})" + ) + self.batch_size = batch_size + self.n_epochs = n_epochs + self.clip_range = clip_range + self.clip_range_vf = clip_range_vf + self.normalize_advantage = normalize_advantage + self.target_kl = target_kl + self.kl_scheduler = AdaptiveScheduler(kl_threshold=adaptive_kl, min_lr=min_lr, max_lr=max_lr, + init_lr=learning_rate) + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + super()._setup_model() + + # Initialize schedules for policy/value clipping + self.clip_range = get_schedule_fn(self.clip_range) + if self.clip_range_vf is not None: + if isinstance(self.clip_range_vf, (float, int)): + assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping" + + self.clip_range_vf = get_schedule_fn(self.clip_range_vf) + + def train(self) -> None: + """ + Update policy using the currently gathered rollout buffer. + """ + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + # Update optimizer learning rate + self.logger.record("train/learning_rate", self.kl_scheduler.current_lr) + # Compute current clip range + clip_range = self.clip_range(self._current_progress_remaining) + # Optional: clip range for the value function + if self.clip_range_vf is not None: + clip_range_vf = self.clip_range_vf(self._current_progress_remaining) + + entropy_losses = [] + pg_losses, value_losses = [], [] + clip_fractions = [] + + continue_training = True + + # train for n_epochs epochs + num_early_stopping = 0 + num_batch_update = 0 + for epoch in range(self.n_epochs): + approx_kl_divs = [] + # Do a complete pass on the rollout buffer + for rollout_data in self.rollout_buffer.get(self.batch_size): + num_batch_update += 1 + actions = rollout_data.actions + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action from float to long + actions = rollout_data.actions.long().flatten() + + # Re-sample the noise matrix because the log_std has changed + if self.use_sde: + self.policy.reset_noise(self.batch_size) + + values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions) + values = values.flatten() + + # Calculate approximate form of reverse KL Divergence for early stopping + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + with th.no_grad(): + log_ratio = log_prob - rollout_data.old_log_prob + approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() + approx_kl_divs.append(approx_kl_div) + + n_iter = self._n_updates // self.n_epochs + scheduling = max(0.99 ** n_iter, 0.05) * 10 + if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl * scheduling: + # continue_training = 
False + num_early_stopping += 1 + # if self.verbose >= 1: + # print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") + continue + + # Normalize advantage + advantages = rollout_data.advantages + if self.normalize_advantage: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + # ratio between old and new policy, should be one at the first iteration + ratio = th.exp(log_prob - rollout_data.old_log_prob) + + # clipped surrogate loss + policy_loss_1 = advantages * ratio + policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) + policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() + + # Logging + pg_losses.append(policy_loss.item()) + clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() + clip_fractions.append(clip_fraction) + + if self.clip_range_vf is None: + # No clipping + values_pred = values + else: + # Clip the difference between old and new value + # NOTE: this depends on the reward scaling + values_pred = rollout_data.old_values + th.clamp( + values - rollout_data.old_values, -clip_range_vf, clip_range_vf + ) + # Value loss using the TD(gae_lambda) target + value_loss = F.mse_loss(rollout_data.returns, values_pred) + value_losses.append(value_loss.item()) + + # Entropy loss favors exploration + if entropy is None: + # Approximate entropy when no analytical form + entropy_loss = -th.mean(-log_prob) + else: + entropy_loss = -th.mean(entropy) + + entropy_losses.append(entropy_loss.item()) + + loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + + lr = self.kl_scheduler.update(approx_kl_div) + update_learning_rate(self.policy.optimizer, lr) + + # Optimization step + self.policy.optimizer.zero_grad() + loss.backward() + # Clip grad norm + th.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, self.policy.parameters()), self.max_grad_norm) + self.policy.optimizer.step() + + if not continue_training: + break + if self.verbose >= 1: + print(f"Early stopping / mini batch update: {num_early_stopping} / {num_batch_update}") + self._n_updates += self.n_epochs + explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) + # Logs + self.logger.record("train/entropy_loss", np.mean(entropy_losses)) + self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) + self.logger.record("train/value_loss", np.mean(value_losses)) + self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) + self.logger.record("train/clip_fraction", np.mean(clip_fractions)) + self.logger.record("train/explained_variance", explained_var) + self.logger.record("rollout/skipped_minibatch", num_early_stopping / num_batch_update) + if hasattr(self.policy, "log_std"): + self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) + + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/clip_range", clip_range) + if self.clip_range_vf is not None: + self.logger.record("train/clip_range_vf", clip_range_vf) + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "PPO", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + iter_start=0, + **kwargs + ) -> "PPO": + + return super().learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + eval_env=eval_env, 
eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, + eval_log_path=eval_log_path, + reset_num_timesteps=reset_num_timesteps, + iter_start=iter_start, + )
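
The KL-adaptive learning-rate rule added in ppo.py above is self-contained, so its behaviour can be checked in isolation. Below is a standalone sketch (not part of the diff) that replays the ``AdaptiveScheduler.update`` logic with the constructor defaults from this PPO (``adaptive_kl=0.02``, ``min_lr=1e-4``, ``max_lr=1e-3``, initial learning rate ``3e-4``):

    # Replay of the AdaptiveScheduler rule from ppo.py: shrink the lr by 1.5x when
    # the approximate KL overshoots twice the threshold, grow it by 1.5x when it
    # stays under half the threshold, and clamp the result to [min_lr, max_lr].
    class AdaptiveScheduler:
        def __init__(self, kl_threshold, min_lr, max_lr, init_lr):
            self.kl_threshold = kl_threshold
            self.min_lr = min_lr
            self.max_lr = max_lr
            self.current_lr = init_lr

        def update(self, kl_dist):
            lr = self.current_lr
            if kl_dist > 2.0 * self.kl_threshold:
                lr = max(self.current_lr / 1.5, self.min_lr)
            if kl_dist < 0.5 * self.kl_threshold:
                lr = min(self.current_lr * 1.5, self.max_lr)
            self.current_lr = lr
            return lr

    scheduler = AdaptiveScheduler(kl_threshold=0.02, min_lr=1e-4, max_lr=1e-3, init_lr=3e-4)
    for kl in (0.005, 0.005, 0.05, 0.05, 0.015):
        print(f"kl={kl:.3f} -> lr={scheduler.update(kl):.6f}")
    # KL below 0.01 raises the lr toward max_lr, KL above 0.04 lowers it toward
    # min_lr, and KL in [0.01, 0.04] leaves it unchanged.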