Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- .summary/0/events.out.tfevents.1718465455.koa03 +3 -0
- README.md +56 -0
- checkpoint_p0/best_000820672_1680736256_reward_1255.320.pth +3 -0
- checkpoint_p0/checkpoint_000976608_2000093184.pth +3 -0
- checkpoint_p0/checkpoint_000976624_2000125952.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000025328_51871744.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000052832_108199936.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000080576_165019648.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000108128_221446144.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000136192_278921216.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000164160_336199680.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000191744_392691712.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000220096_450756608.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000248000_507904000.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000276096_565444608.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000303856_622297088.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000331888_679706624.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000359664_736591872.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000387616_793837568.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000415744_851443712.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000444096_909508608.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000472512_967704576.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000500608_1025245184.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000528672_1082720256.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000556928_1140588544.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000585408_1198915584.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000614144_1257766912.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000642240_1315307520.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000670720_1373634560.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000698816_1431175168.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000726976_1488846848.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000755328_1546911744.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000783136_1603862528.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000811584_1662124032.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000839936_1720188928.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000868144_1777958912.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000896240_1835499520.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000924672_1893728256.pth +3 -0
- checkpoint_p0/milestones/checkpoint_000953008_1951760384.pth +3 -0
- config.json +167 -0
- git.diff +712 -0
- replay.mp4 +3 -0
- sf_log.txt +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
replay.mp4 filter=lfs diff=lfs merge=lfs -text
|
.summary/0/events.out.tfevents.1718465455.koa03
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e94cf53b8f6360348849b631b4c38a3b6f7f36c0e9a04db91a35f5b7b9ccd542
|
| 3 |
+
size 18285797
|
README.md
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: sample-factory
|
| 3 |
+
tags:
|
| 4 |
+
- deep-reinforcement-learning
|
| 5 |
+
- reinforcement-learning
|
| 6 |
+
- sample-factory
|
| 7 |
+
model-index:
|
| 8 |
+
- name: APPO
|
| 9 |
+
results:
|
| 10 |
+
- task:
|
| 11 |
+
type: reinforcement-learning
|
| 12 |
+
name: reinforcement-learning
|
| 13 |
+
dataset:
|
| 14 |
+
name: atari_airraid
|
| 15 |
+
type: atari_airraid
|
| 16 |
+
metrics:
|
| 17 |
+
- type: mean_reward
|
| 18 |
+
value: 465.00 +/- 182.76
|
| 19 |
+
name: mean_reward
|
| 20 |
+
verified: false
|
| 21 |
+
---
|
| 22 |
+
|
| 23 |
+
A(n) **APPO** model trained on the **atari_airraid** environment.
|
| 24 |
+
|
| 25 |
+
This model was trained using Sample-Factory 2.0: https://github.com/alex-petrenko/sample-factory.
|
| 26 |
+
Documentation for how to use Sample-Factory can be found at https://www.samplefactory.dev/
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
## Downloading the model
|
| 30 |
+
|
| 31 |
+
After installing Sample-Factory, download the model with:
|
| 32 |
+
```
|
| 33 |
+
python -m sample_factory.huggingface.load_from_hub -r ksridhar/atari_2B_atari_airraid_1111
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
## Using the model
|
| 38 |
+
|
| 39 |
+
To run the model after download, use the `enjoy` script corresponding to this environment:
|
| 40 |
+
```
|
| 41 |
+
python -m <path.to.enjoy.module> --algo=APPO --env=atari_airraid --train_dir=./train_dir --experiment=atari_2B_atari_airraid_1111
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
You can also upload models to the Hugging Face Hub using the same script with the `--push_to_hub` flag.
|
| 46 |
+
See https://www.samplefactory.dev/10-huggingface/huggingface/ for more details
|
| 47 |
+
|
| 48 |
+
## Training with this model
|
| 49 |
+
|
| 50 |
+
To continue training with this model, use the `train` script corresponding to this environment:
|
| 51 |
+
```
|
| 52 |
+
python -m <path.to.train.module> --algo=APPO --env=atari_airraid --train_dir=./train_dir --experiment=atari_2B_atari_airraid_1111 --restart_behavior=resume --train_for_env_steps=10000000000
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
Note, you may have to adjust `--train_for_env_steps` to a suitably high number as the experiment will resume at the number of steps it concluded at.
|
| 56 |
+
|
checkpoint_p0/best_000820672_1680736256_reward_1255.320.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:04ea5891f52bccc613ae420f2383c694c8a4583d64cce037fe061a08240982b3
|
| 3 |
+
size 20722280
|
checkpoint_p0/checkpoint_000976608_2000093184.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3f5b7f2c104719944467b95f32e263dff2734fcfce0a7940f053de310ca23fb9
|
| 3 |
+
size 20722628
|
checkpoint_p0/checkpoint_000976624_2000125952.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e83c5aba7700b73f5f86d2763f21eadcaee8e1daabfd945d07def7a62c2fdfab
|
| 3 |
+
size 20722628
|
checkpoint_p0/milestones/checkpoint_000025328_51871744.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c70ea6c746b1099395d74f65b51a553989d2373949cda55bdb936f6d339f163a
|
| 3 |
+
size 20723568
|
checkpoint_p0/milestones/checkpoint_000052832_108199936.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:90153c065eaf8a16774217a7ce5a67685dd5fc0af1d6f8ce55b475a50deea4fe
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000080576_165019648.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:48af7f9d07dc3e13cd2d01cff53a5d4c142afb6e5d5cb46c3f8fcbf78c529f52
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000108128_221446144.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ed2dbe5193dcea57fbc0d9dce3e6d3bd959bb251e232176daa539bfc93a24530
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000136192_278921216.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ba07de6c1b32632354f3d77b63a47c28e412bfed43ca5ef82c3e33fa2616785e
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000164160_336199680.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cf9e64c7241e64da17332f5779f522e2efd38d87cf5847eeb914f1b86bb90eeb
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000191744_392691712.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fcce6b301593984a3c055b3fbe8f21f3823b1b3233c9a405a5c2c62168f8d45f
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000220096_450756608.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:60e5d3544717f7983b4abe5a622035a4949d20a3a07b0a6d2732d54e15004567
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000248000_507904000.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:04bef3f4845d5afb30cb49f9bc48dd421b1ae398e2e249b56dfeda5b1f089906
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000276096_565444608.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d4eb1a15333413a30b47229e6e8e1236096c8fb60dc20560f0762f5d5a7ee491
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000303856_622297088.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8ed4b042c87e40ee5edcc31499c946d83b3206d68b572009c7f0b9b1bb9e51e8
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000331888_679706624.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:506f2656ecd9f241a5175470d7502dea2df70e4b5f9c0fb911219fe041c906e8
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000359664_736591872.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e56a0ed00e07cb870607d3ac6727ae2d1c952cfff4156fbabf41954a816271fe
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000387616_793837568.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c4206c3d4735f6aada46266d53b372afef0ecf8aee1516538fdde4128e73ca35
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000415744_851443712.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cdbeb1c1d4facca0691907133f028f7a395d7a4631e024bb962fd5eae9e0473e
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000444096_909508608.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:908e6f59f7f70604da8a49c0effa07c915a4c8181b2229bbf5060448eab34237
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000472512_967704576.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3acfe6effa0eb87d9be6c64ded5e46f73548ab35f2b9324282a9fab72952cdd9
|
| 3 |
+
size 20723626
|
checkpoint_p0/milestones/checkpoint_000500608_1025245184.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c0562aa2dd28df15165bbc10ac172196e667e69694225bb419790a0204872b89
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000528672_1082720256.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0a359cb85026316b46cec52c2b3072bb3f4c6d0c8f8c04b81652d53c320962ab
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000556928_1140588544.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:67e509d0cba38cab4cb931ce1b06b3c0dcd997802b1cd0e2cf62bf42bba8c964
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000585408_1198915584.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:123f1825e68d6007a227a2809108a82997c6140e78cd4b441124d5de77909c72
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000614144_1257766912.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:58fed7f6964ad4469525aa4840aab5182a39bd3fa981c5d0c85131980dff948b
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000642240_1315307520.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2c2b24a591716ea07f740d151b7b912eb8da99d438b7421c1c90e458950ff079
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000670720_1373634560.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:06d76f4a0427517572114f56ad297b0820be092421150f56f0b045289b88fba0
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000698816_1431175168.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:67eedd32ef45079fb6063fa17b3573d4cfa3756ff2b850d2ed9a7cfa3ff2e9f8
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000726976_1488846848.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f799ebc4f68d484712629e50e5e50b61a5cf1c228487f35565bef8fad96877ef
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000755328_1546911744.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:acf482fd0d35a9e9466747e0902a0113179aa3153d6457d9069b62d5006a9af1
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000783136_1603862528.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a72d7d19a447c00346e260ca9efe1737505c1bdd124bc713d2fb5f600d758db2
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000811584_1662124032.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2d129fb12044a642bcf9cce81c2a72773e839a07e7d83b400c28ffa6f6e77dda
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000839936_1720188928.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:04aa797855e355aa77fc037b122b93c36ef03add6458a5964950c787ec16b547
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000868144_1777958912.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f2e77ad1f2d084ca4b8a0560788be4a2c00b2215895a53f0b1ce94ee419dc2f
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000896240_1835499520.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8d45adf6d55ca09b5f89e4178e6eae4c7844bf2e42f0c19985040c7e348a4f6a
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000924672_1893728256.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a8057f3255e306fda08291c8e367d5adc5bf1baea5f0da09fb14f76e3d035666
|
| 3 |
+
size 20723684
|
checkpoint_p0/milestones/checkpoint_000953008_1951760384.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0a07fc48118e729c1125f70c78b65c102fff6c9714e8cc37b8455fca3da50fa3
|
| 3 |
+
size 20723684
|
config.json
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"help": false,
|
| 3 |
+
"algo": "APPO",
|
| 4 |
+
"env": "atari_airraid",
|
| 5 |
+
"experiment": "atari_2B_atari_airraid_1111",
|
| 6 |
+
"train_dir": "train_dir",
|
| 7 |
+
"restart_behavior": "resume",
|
| 8 |
+
"device": "gpu",
|
| 9 |
+
"seed": 1111,
|
| 10 |
+
"num_policies": 1,
|
| 11 |
+
"async_rl": true,
|
| 12 |
+
"serial_mode": false,
|
| 13 |
+
"batched_sampling": true,
|
| 14 |
+
"num_batches_to_accumulate": 2,
|
| 15 |
+
"worker_num_splits": 1,
|
| 16 |
+
"policy_workers_per_policy": 1,
|
| 17 |
+
"max_policy_lag": 1000,
|
| 18 |
+
"num_workers": 4,
|
| 19 |
+
"num_envs_per_worker": 1,
|
| 20 |
+
"batch_size": 1024,
|
| 21 |
+
"num_batches_per_epoch": 8,
|
| 22 |
+
"num_epochs": 2,
|
| 23 |
+
"rollout": 64,
|
| 24 |
+
"recurrence": 1,
|
| 25 |
+
"shuffle_minibatches": false,
|
| 26 |
+
"gamma": 0.99,
|
| 27 |
+
"reward_scale": 1.0,
|
| 28 |
+
"reward_clip": 1000.0,
|
| 29 |
+
"value_bootstrap": false,
|
| 30 |
+
"normalize_returns": true,
|
| 31 |
+
"exploration_loss_coeff": 0.0004677351413,
|
| 32 |
+
"value_loss_coeff": 0.5,
|
| 33 |
+
"kl_loss_coeff": 0.0,
|
| 34 |
+
"exploration_loss": "entropy",
|
| 35 |
+
"gae_lambda": 0.95,
|
| 36 |
+
"ppo_clip_ratio": 0.1,
|
| 37 |
+
"ppo_clip_value": 1.0,
|
| 38 |
+
"with_vtrace": false,
|
| 39 |
+
"vtrace_rho": 1.0,
|
| 40 |
+
"vtrace_c": 1.0,
|
| 41 |
+
"optimizer": "adam",
|
| 42 |
+
"adam_eps": 1e-05,
|
| 43 |
+
"adam_beta1": 0.9,
|
| 44 |
+
"adam_beta2": 0.999,
|
| 45 |
+
"max_grad_norm": 0.0,
|
| 46 |
+
"learning_rate": 0.0003033891184,
|
| 47 |
+
"lr_schedule": "linear_decay",
|
| 48 |
+
"lr_schedule_kl_threshold": 0.008,
|
| 49 |
+
"lr_adaptive_min": 1e-06,
|
| 50 |
+
"lr_adaptive_max": 0.01,
|
| 51 |
+
"obs_subtract_mean": 0.0,
|
| 52 |
+
"obs_scale": 255.0,
|
| 53 |
+
"normalize_input": true,
|
| 54 |
+
"normalize_input_keys": [
|
| 55 |
+
"obs"
|
| 56 |
+
],
|
| 57 |
+
"decorrelate_experience_max_seconds": 1,
|
| 58 |
+
"decorrelate_envs_on_one_worker": true,
|
| 59 |
+
"actor_worker_gpus": [],
|
| 60 |
+
"set_workers_cpu_affinity": true,
|
| 61 |
+
"force_envs_single_thread": false,
|
| 62 |
+
"default_niceness": 0,
|
| 63 |
+
"log_to_file": true,
|
| 64 |
+
"experiment_summaries_interval": 3,
|
| 65 |
+
"flush_summaries_interval": 30,
|
| 66 |
+
"stats_avg": 100,
|
| 67 |
+
"summaries_use_frameskip": true,
|
| 68 |
+
"heartbeat_interval": 20,
|
| 69 |
+
"heartbeat_reporting_interval": 180,
|
| 70 |
+
"train_for_env_steps": 2000000000,
|
| 71 |
+
"train_for_seconds": 3600000,
|
| 72 |
+
"save_every_sec": 120,
|
| 73 |
+
"keep_checkpoints": 2,
|
| 74 |
+
"load_checkpoint_kind": "latest",
|
| 75 |
+
"save_milestones_sec": 1200,
|
| 76 |
+
"save_best_every_sec": 5,
|
| 77 |
+
"save_best_metric": "reward",
|
| 78 |
+
"save_best_after": 100000,
|
| 79 |
+
"benchmark": false,
|
| 80 |
+
"encoder_mlp_layers": [
|
| 81 |
+
512,
|
| 82 |
+
512
|
| 83 |
+
],
|
| 84 |
+
"encoder_conv_architecture": "convnet_atari",
|
| 85 |
+
"encoder_conv_mlp_layers": [
|
| 86 |
+
512
|
| 87 |
+
],
|
| 88 |
+
"use_rnn": false,
|
| 89 |
+
"rnn_size": 512,
|
| 90 |
+
"rnn_type": "gru",
|
| 91 |
+
"rnn_num_layers": 1,
|
| 92 |
+
"decoder_mlp_layers": [],
|
| 93 |
+
"nonlinearity": "relu",
|
| 94 |
+
"policy_initialization": "orthogonal",
|
| 95 |
+
"policy_init_gain": 1.0,
|
| 96 |
+
"actor_critic_share_weights": true,
|
| 97 |
+
"adaptive_stddev": false,
|
| 98 |
+
"continuous_tanh_scale": 0.0,
|
| 99 |
+
"initial_stddev": 1.0,
|
| 100 |
+
"use_env_info_cache": false,
|
| 101 |
+
"env_gpu_actions": false,
|
| 102 |
+
"env_gpu_observations": true,
|
| 103 |
+
"env_frameskip": 4,
|
| 104 |
+
"env_framestack": 4,
|
| 105 |
+
"pixel_format": "CHW",
|
| 106 |
+
"use_record_episode_statistics": true,
|
| 107 |
+
"episode_counter": false,
|
| 108 |
+
"with_wandb": false,
|
| 109 |
+
"wandb_user": null,
|
| 110 |
+
"wandb_project": "sample_factory",
|
| 111 |
+
"wandb_group": null,
|
| 112 |
+
"wandb_job_type": "SF",
|
| 113 |
+
"wandb_tags": [],
|
| 114 |
+
"with_pbt": false,
|
| 115 |
+
"pbt_mix_policies_in_one_env": true,
|
| 116 |
+
"pbt_period_env_steps": 5000000,
|
| 117 |
+
"pbt_start_mutation": 20000000,
|
| 118 |
+
"pbt_replace_fraction": 0.3,
|
| 119 |
+
"pbt_mutation_rate": 0.15,
|
| 120 |
+
"pbt_replace_reward_gap": 0.1,
|
| 121 |
+
"pbt_replace_reward_gap_absolute": 1e-06,
|
| 122 |
+
"pbt_optimize_gamma": false,
|
| 123 |
+
"pbt_target_objective": "true_objective",
|
| 124 |
+
"pbt_perturb_min": 1.1,
|
| 125 |
+
"pbt_perturb_max": 1.5,
|
| 126 |
+
"env_agents": 512,
|
| 127 |
+
"command_line": "--seed=1111 --experiment=atari_2B_atari_airraid_1111 --env=atari_airraid --train_for_seconds=3600000 --algo=APPO --gamma=0.99 --num_workers=4 --num_envs_per_worker=1 --worker_num_splits=1 --env_agents=512 --benchmark=False --max_grad_norm=0.0 --decorrelate_experience_max_seconds=1 --encoder_conv_architecture=convnet_atari --encoder_conv_mlp_layers 512 --nonlinearity=relu --num_policies=1 --normalize_input=True --normalize_input_keys obs --normalize_returns=True --async_rl=True --batched_sampling=True --train_for_env_steps=2000000000 --save_milestones_sec=1200 --train_dir train_dir --rollout 64 --exploration_loss_coeff 0.0004677351413 --num_epochs 2 --batch_size 1024 --num_batches_per_epoch 8 --learning_rate 0.0003033891184",
|
| 128 |
+
"cli_args": {
|
| 129 |
+
"algo": "APPO",
|
| 130 |
+
"env": "atari_airraid",
|
| 131 |
+
"experiment": "atari_2B_atari_airraid_1111",
|
| 132 |
+
"train_dir": "train_dir",
|
| 133 |
+
"seed": 1111,
|
| 134 |
+
"num_policies": 1,
|
| 135 |
+
"async_rl": true,
|
| 136 |
+
"batched_sampling": true,
|
| 137 |
+
"worker_num_splits": 1,
|
| 138 |
+
"num_workers": 4,
|
| 139 |
+
"num_envs_per_worker": 1,
|
| 140 |
+
"batch_size": 1024,
|
| 141 |
+
"num_batches_per_epoch": 8,
|
| 142 |
+
"num_epochs": 2,
|
| 143 |
+
"rollout": 64,
|
| 144 |
+
"gamma": 0.99,
|
| 145 |
+
"normalize_returns": true,
|
| 146 |
+
"exploration_loss_coeff": 0.0004677351413,
|
| 147 |
+
"max_grad_norm": 0.0,
|
| 148 |
+
"learning_rate": 0.0003033891184,
|
| 149 |
+
"normalize_input": true,
|
| 150 |
+
"normalize_input_keys": [
|
| 151 |
+
"obs"
|
| 152 |
+
],
|
| 153 |
+
"decorrelate_experience_max_seconds": 1,
|
| 154 |
+
"train_for_env_steps": 2000000000,
|
| 155 |
+
"train_for_seconds": 3600000,
|
| 156 |
+
"save_milestones_sec": 1200,
|
| 157 |
+
"benchmark": false,
|
| 158 |
+
"encoder_conv_architecture": "convnet_atari",
|
| 159 |
+
"encoder_conv_mlp_layers": [
|
| 160 |
+
512
|
| 161 |
+
],
|
| 162 |
+
"nonlinearity": "relu",
|
| 163 |
+
"env_agents": 512
|
| 164 |
+
},
|
| 165 |
+
"git_hash": "e259c57b8c7aa9c7f541e9efd1316f8e6f97a6db",
|
| 166 |
+
"git_repo_name": "https://github.com/kaustubhsridhar/jat_regent.git"
|
| 167 |
+
}
|
git.diff
ADDED
|
@@ -0,0 +1,712 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
diff --git a/README.md b/README.md
|
| 2 |
+
index e51a12b..a6e1ca1 100644
|
| 3 |
+
--- a/README.md
|
| 4 |
+
+++ b/README.md
|
| 5 |
+
@@ -21,6 +21,21 @@ conda activate jat
|
| 6 |
+
pip install -e .[dev]
|
| 7 |
+
```
|
| 8 |
+
|
| 9 |
+
+## REGENT fork of sample-factory: Installation
|
| 10 |
+
+Following [this install ink](https://www.samplefactory.dev/01-get-started/installation/) but for the fork:
|
| 11 |
+
+```shell
|
| 12 |
+
+git clone https://github.com/kaustubhsridhar/sample-factory.git
|
| 13 |
+
+cd sample-factory
|
| 14 |
+
+pip install -e .[dev,mujoco,atari,envpool,vizdoom]
|
| 15 |
+
+```
|
| 16 |
+
+
|
| 17 |
+
+# Regent fork of sample-factory: Train Unseen Env Policies and Generate Datasets
|
| 18 |
+
+Train policies using envpool's atari:
|
| 19 |
+
+```shell
|
| 20 |
+
+bash scripts_sample-factory/train_unseen_atari.sh
|
| 21 |
+
+```
|
| 22 |
+
+Note that the training command inside the above script was obtained from the config files of Ed Beeching's Atari 57 models on Huggingface. An example is [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/blob/main/cfg.json#L124). See my discussion [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/discussions/2).
|
| 23 |
+
+
|
| 24 |
+
## PREV Installation
|
| 25 |
+
|
| 26 |
+
To get started with JAT, follow these steps:
|
| 27 |
+
@@ -155,12 +170,21 @@ python -u scripts_jat_regent/eval_RandP.py --task ${TASK} &> outputs/RandP/${TAS
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
### REGENT Analyze data
|
| 31 |
+
+Necessary:
|
| 32 |
+
```shell
|
| 33 |
+
-python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt &
|
| 34 |
+
-
|
| 35 |
+
python -u examples_regent/analyze_rows_tokenized.py &> examples_regent/analyze_rows_tokenized.txt &
|
| 36 |
+
+```
|
| 37 |
+
|
| 38 |
+
+Already ran and output dict in code:
|
| 39 |
+
+```shell
|
| 40 |
+
python -u examples_regent/get_dim_all_vector_tasks.py &> examples_regent/get_dim_all_vector_tasks.txt &
|
| 41 |
+
+
|
| 42 |
+
+python -u examples_regent/count_rows_to_consider.py &> examples_regent/count_rows_to_consider.txt &
|
| 43 |
+
+```
|
| 44 |
+
+
|
| 45 |
+
+Optional:
|
| 46 |
+
+```shell
|
| 47 |
+
+python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt &
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## PREV Dataset
|
| 51 |
+
diff --git a/jat_regent/RandP.py b/jat_regent/RandP.py
|
| 52 |
+
deleted file mode 100644
|
| 53 |
+
index b2bd8bf..0000000
|
| 54 |
+
--- a/jat_regent/RandP.py
|
| 55 |
+
+++ /dev/null
|
| 56 |
+
@@ -1,38 +0,0 @@
|
| 57 |
+
-import warnings
|
| 58 |
+
-from dataclasses import dataclass
|
| 59 |
+
-from typing import List, Optional, Tuple, Union
|
| 60 |
+
-
|
| 61 |
+
-import numpy as np
|
| 62 |
+
-import torch
|
| 63 |
+
-import torch.nn.functional as F
|
| 64 |
+
-from gymnasium import spaces
|
| 65 |
+
-from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn
|
| 66 |
+
-from transformers import GPTNeoModel, GPTNeoPreTrainedModel
|
| 67 |
+
-from transformers.modeling_outputs import ModelOutput
|
| 68 |
+
-from transformers.models.vit.modeling_vit import ViTPatchEmbeddings
|
| 69 |
+
-
|
| 70 |
+
-from jat.configuration_jat import JatConfig
|
| 71 |
+
-from jat.processing_jat import JatProcessor
|
| 72 |
+
-
|
| 73 |
+
-
|
| 74 |
+
-class RandP():
|
| 75 |
+
- def __init__(self, dataset) -> None:
|
| 76 |
+
- self.steps = 0
|
| 77 |
+
- # create an index for retrieval in vector obs envs (OR) collect all images in Atari
|
| 78 |
+
-
|
| 79 |
+
- def reset_rl(self):
|
| 80 |
+
- self.steps = 0
|
| 81 |
+
-
|
| 82 |
+
- def get_next_action(
|
| 83 |
+
- self,
|
| 84 |
+
- processor: JatProcessor,
|
| 85 |
+
- continuous_observation: Optional[List[float]] = None,
|
| 86 |
+
- discrete_observation: Optional[List[int]] = None,
|
| 87 |
+
- text_observation: Optional[str] = None,
|
| 88 |
+
- image_observation: Optional[np.ndarray] = None,
|
| 89 |
+
- action_space: Union[spaces.Box, spaces.Discrete] = None,
|
| 90 |
+
- reward: Optional[float] = None,
|
| 91 |
+
- deterministic: bool = False,
|
| 92 |
+
- context_window: Optional[int] = None,
|
| 93 |
+
- ):
|
| 94 |
+
- pass
|
| 95 |
+
|
| 96 |
+
diff --git a/jat_regent/modelling_jat_regent.py b/jat_regent/modelling_jat_regent.py
|
| 97 |
+
deleted file mode 100644
|
| 98 |
+
index e69de29..0000000
|
| 99 |
+
diff --git a/jat_regent/utils.py b/jat_regent/utils.py
|
| 100 |
+
index 56bfb44..36f6cca 100644
|
| 101 |
+
--- a/jat_regent/utils.py
|
| 102 |
+
+++ b/jat_regent/utils.py
|
| 103 |
+
@@ -8,23 +8,35 @@ from tqdm import tqdm
|
| 104 |
+
from autofaiss import build_index
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
+UNSEEN_TASK_NAMES = { # Total -- atari: 57, metaworld: 50, babyai: 39, mujoco: 11
|
| 108 |
+
+
|
| 109 |
+
+}
|
| 110 |
+
+
|
| 111 |
+
def myprint(str):
|
| 112 |
+
- # check if first character of string is a newline character
|
| 113 |
+
- if str[0] == '\n':
|
| 114 |
+
- str_without_newline = str[1:]
|
| 115 |
+
+ # check if first characters of string are newline character
|
| 116 |
+
+ num_newlines = 0
|
| 117 |
+
+ while str[num_newlines] == '\n':
|
| 118 |
+
print()
|
| 119 |
+
- else:
|
| 120 |
+
- str_without_newline = str
|
| 121 |
+
+ num_newlines += 1
|
| 122 |
+
+ str_without_newline = str[num_newlines:]
|
| 123 |
+
print(f'{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}: {str_without_newline}')
|
| 124 |
+
|
| 125 |
+
def is_png_img(item):
|
| 126 |
+
return isinstance(item, PngImagePlugin.PngImageFile)
|
| 127 |
+
|
| 128 |
+
+def get_last_row_for_1M_states(task):
|
| 129 |
+
+ last_row_idx = {'atari-alien': 14134, 'atari-amidar': 14319, 'atari-assault': 14427, 'atari-asterix': 14456, 'atari-asteroids': 14348, 'atari-atlantis': 14325, 'atari-bankheist': 14167, 'atari-battlezone': 13981, 'atari-beamrider': 13442, 'atari-berzerk': 13534, 'atari-bowling': 14110, 'atari-boxing': 14542, 'atari-breakout': 13474, 'atari-centipede': 14196, 'atari-choppercommand': 13397, 'atari-crazyclimber': 14026, 'atari-defender': 13504, 'atari-demonattack': 13499, 'atari-doubledunk': 14292, 'atari-enduro': 13260, 'atari-fishingderby': 14073, 'atari-freeway': 14016, 'atari-frostbite': 14075, 'atari-gopher': 13143, 'atari-gravitar': 14405, 'atari-hero': 14044, 'atari-icehockey': 14017, 'atari-jamesbond': 12678, 'atari-kangaroo': 14248, 'atari-krull': 14204, 'atari-kungfumaster': 14030, 'atari-montezumarevenge': 14219, 'atari-mspacman': 14120, 'atari-namethisgame': 13575, 'atari-phoenix': 13539, 'atari-pitfall': 14287, 'atari-pong': 14151, 'atari-privateeye': 14105, 'atari-qbert': 14026, 'atari-riverraid': 14275, 'atari-roadrunner': 14127, 'atari-robotank': 14079, 'atari-seaquest': 14097, 'atari-skiing': 14708, 'atari-solaris': 14199, 'atari-spaceinvaders': 12652, 'atari-stargunner': 13822, 'atari-surround': 13840, 'atari-tennis': 14062, 'atari-timepilot': 13896, 'atari-tutankham': 13121, 'atari-upndown': 13504, 'atari-venture': 14260, 'atari-videopinball': 14272, 'atari-wizardofwor': 13920, 'atari-yarsrevenge': 13981, 'atari-zaxxon': 13833, 'babyai-action-obj-door': 95000, 'babyai-blocked-unlock-pickup': 29279, 'babyai-boss-level-no-unlock': 12087, 'babyai-boss-level': 12101, 'babyai-find-obj-s5': 32974, 'babyai-go-to-door': 95000, 'babyai-go-to-imp-unlock': 9286, 'babyai-go-to-local': 95000, 'babyai-go-to-obj-door': 95000, 'babyai-go-to-obj': 95000, 'babyai-go-to-red-ball-grey': 95000, 'babyai-go-to-red-ball-no-dists': 95000, 'babyai-go-to-red-ball': 95000, 'babyai-go-to-red-blue-ball': 95000, 'babyai-go-to-seq': 13744, 'babyai-go-to': 18974, 'babyai-key-corridor': 9014, 'babyai-mini-boss-level': 38119, 'babyai-move-two-across-s8n9': 24505, 'babyai-one-room-s8': 95000, 'babyai-open-door': 95000, 'babyai-open-doors-order-n4': 95000, 'babyai-open-red-door': 95000, 'babyai-open-two-doors': 73291, 'babyai-open': 32559, 'babyai-pickup-above': 34084, 'babyai-pickup-dist': 89640, 'babyai-pickup-loc': 95000, 'babyai-pickup': 18670, 'babyai-put-next-local': 83187, 'babyai-put-next': 56986, 'babyai-synth-loc': 21605, 'babyai-synth-seq': 13049, 'babyai-synth': 19409, 'babyai-unblock-pickup': 17881, 'babyai-unlock-local': 71186, 'babyai-unlock-pickup': 50883, 'babyai-unlock-to-unlock': 23062, 'babyai-unlock': 11734, 'metaworld-assembly': 10000, 'metaworld-basketball': 10000, 'metaworld-bin-picking': 10000, 'metaworld-box-close': 10000, 'metaworld-button-press-topdown-wall': 10000, 'metaworld-button-press-topdown': 10000, 'metaworld-button-press-wall': 10000, 'metaworld-button-press': 10000, 'metaworld-coffee-button': 10000, 'metaworld-coffee-pull': 10000, 'metaworld-coffee-push': 10000, 'metaworld-dial-turn': 10000, 'metaworld-disassemble': 10000, 'metaworld-door-close': 10000, 'metaworld-door-lock': 10000, 'metaworld-door-open': 10000, 'metaworld-door-unlock': 10000, 'metaworld-drawer-close': 10000, 'metaworld-drawer-open': 10000, 'metaworld-faucet-close': 10000, 'metaworld-faucet-open': 10000, 'metaworld-hammer': 10000, 'metaworld-hand-insert': 10000, 'metaworld-handle-press-side': 10000, 'metaworld-handle-press': 10000, 'metaworld-handle-pull-side': 10000, 'metaworld-handle-pull': 10000, 'metaworld-lever-pull': 10000, 'metaworld-peg-insert-side': 10000, 'metaworld-peg-unplug-side': 10000, 'metaworld-pick-out-of-hole': 10000, 'metaworld-pick-place-wall': 10000, 'metaworld-pick-place': 10000, 'metaworld-plate-slide-back-side': 10000, 'metaworld-plate-slide-back': 10000, 'metaworld-plate-slide-side': 10000, 'metaworld-plate-slide': 10000, 'metaworld-push-back': 10000, 'metaworld-push-wall': 10000, 'metaworld-push': 10000, 'metaworld-reach-wall': 10000, 'metaworld-reach': 10000, 'metaworld-shelf-place': 10000, 'metaworld-soccer': 10000, 'metaworld-stick-pull': 10000, 'metaworld-stick-push': 10000, 'metaworld-sweep-into': 10000, 'metaworld-sweep': 10000, 'metaworld-window-close': 10000, 'metaworld-window-open': 10000, 'mujoco-ant': 4023, 'mujoco-doublependulum': 4002, 'mujoco-halfcheetah': 4000, 'mujoco-hopper': 4931, 'mujoco-humanoid': 4119, 'mujoco-pendulum': 4959, 'mujoco-pusher': 9000, 'mujoco-reacher': 9000, 'mujoco-standup': 4000, 'mujoco-swimmer': 4000, 'mujoco-walker': 4101}
|
| 130 |
+
+ return last_row_idx[task]
|
| 131 |
+
+
|
| 132 |
+
+def get_last_row_for_100k_states(task):
|
| 133 |
+
+ last_row_idx = {'atari-alien': 3135, 'atari-amidar': 3142, 'atari-assault': 3132, 'atari-asterix': 3181, 'atari-asteroids': 3127, 'atari-atlantis': 3128, 'atari-bankheist': 3156, 'atari-battlezone': 3136, 'atari-beamrider': 3131, 'atari-berzerk': 3127, 'atari-bowling': 3148, 'atari-boxing': 3227, 'atari-breakout': 3128, 'atari-centipede': 3176, 'atari-choppercommand': 3144, 'atari-crazyclimber': 3134, 'atari-defender': 3127, 'atari-demonattack': 3127, 'atari-doubledunk': 3175, 'atari-enduro': 3126, 'atari-fishingderby': 3155, 'atari-freeway': 3131, 'atari-frostbite': 3146, 'atari-gopher': 3128, 'atari-gravitar': 3202, 'atari-hero': 3144, 'atari-icehockey': 3138, 'atari-jamesbond': 3131, 'atari-kangaroo': 3160, 'atari-krull': 3162, 'atari-kungfumaster': 3143, 'atari-montezumarevenge': 3168, 'atari-mspacman': 3143, 'atari-namethisgame': 3131, 'atari-phoenix': 3127, 'atari-pitfall': 3131, 'atari-pong': 3160, 'atari-privateeye': 3158, 'atari-qbert': 3136, 'atari-riverraid': 3157, 'atari-roadrunner': 3150, 'atari-robotank': 3133, 'atari-seaquest': 3138, 'atari-skiing': 3271, 'atari-solaris': 3129, 'atari-spaceinvaders': 3128, 'atari-stargunner': 3129, 'atari-surround': 3143, 'atari-tennis': 3129, 'atari-timepilot': 3132, 'atari-tutankham': 3127, 'atari-upndown': 3127, 'atari-venture': 3148, 'atari-videopinball': 3130, 'atari-wizardofwor': 3138, 'atari-yarsrevenge': 3129, 'atari-zaxxon': 3133, 'babyai-action-obj-door': 15923, 'babyai-blocked-unlock-pickup': 2919, 'babyai-boss-level-no-unlock': 1217, 'babyai-boss-level': 1159, 'babyai-find-obj-s5': 3345, 'babyai-go-to-door': 18875, 'babyai-go-to-imp-unlock': 923, 'babyai-go-to-local': 18724, 'babyai-go-to-obj-door': 16472, 'babyai-go-to-obj': 20197, 'babyai-go-to-red-ball-grey': 16953, 'babyai-go-to-red-ball-no-dists': 20165, 'babyai-go-to-red-ball': 18730, 'babyai-go-to-red-blue-ball': 16934, 'babyai-go-to-seq': 1439, 'babyai-go-to': 1964, 'babyai-key-corridor': 900, 'babyai-mini-boss-level': 3789, 'babyai-move-two-across-s8n9': 2462, 'babyai-one-room-s8': 16994, 'babyai-open-door': 13565, 'babyai-open-doors-order-n4': 9706, 'babyai-open-red-door': 21185, 'babyai-open-two-doors': 7348, 'babyai-open': 3331, 'babyai-pickup-above': 3392, 'babyai-pickup-dist': 19693, 'babyai-pickup-loc': 16405, 'babyai-pickup': 1806, 'babyai-put-next-local': 8303, 'babyai-put-next': 5703, 'babyai-synth-loc': 2183, 'babyai-synth-seq': 1316, 'babyai-synth': 1964, 'babyai-unblock-pickup': 1886, 'babyai-unlock-local': 7118, 'babyai-unlock-pickup': 5107, 'babyai-unlock-to-unlock': 2309, 'babyai-unlock': 1177, 'metaworld-assembly': 1000, 'metaworld-basketball': 1000, 'metaworld-bin-picking': 1000, 'metaworld-box-close': 1000, 'metaworld-button-press-topdown-wall': 1000, 'metaworld-button-press-topdown': 1000, 'metaworld-button-press-wall': 1000, 'metaworld-button-press': 1000, 'metaworld-coffee-button': 1000, 'metaworld-coffee-pull': 1000, 'metaworld-coffee-push': 1000, 'metaworld-dial-turn': 1000, 'metaworld-disassemble': 1000, 'metaworld-door-close': 1000, 'metaworld-door-lock': 1000, 'metaworld-door-open': 1000, 'metaworld-door-unlock': 1000, 'metaworld-drawer-close': 1000, 'metaworld-drawer-open': 1000, 'metaworld-faucet-close': 1000, 'metaworld-faucet-open': 1000, 'metaworld-hammer': 1000, 'metaworld-hand-insert': 1000, 'metaworld-handle-press-side': 1000, 'metaworld-handle-press': 1000, 'metaworld-handle-pull-side': 1000, 'metaworld-handle-pull': 1000, 'metaworld-lever-pull': 1000, 'metaworld-peg-insert-side': 1000, 'metaworld-peg-unplug-side': 1000, 'metaworld-pick-out-of-hole': 1000, 'metaworld-pick-place-wall': 1000, 'metaworld-pick-place': 1000, 'metaworld-plate-slide-back-side': 1000, 'metaworld-plate-slide-back': 1000, 'metaworld-plate-slide-side': 1000, 'metaworld-plate-slide': 1000, 'metaworld-push-back': 1000, 'metaworld-push-wall': 1000, 'metaworld-push': 1000, 'metaworld-reach-wall': 1000, 'metaworld-reach': 1000, 'metaworld-shelf-place': 1000, 'metaworld-soccer': 1000, 'metaworld-stick-pull': 1000, 'metaworld-stick-push': 1000, 'metaworld-sweep-into': 1000, 'metaworld-sweep': 1000, 'metaworld-window-close': 1000, 'metaworld-window-open': 1000, 'mujoco-ant': 401, 'mujoco-doublependulum': 401, 'mujoco-halfcheetah': 400, 'mujoco-hopper': 491, 'mujoco-humanoid': 415, 'mujoco-pendulum': 495, 'mujoco-pusher': 1000, 'mujoco-reacher': 2000, 'mujoco-standup': 400, 'mujoco-swimmer': 400, 'mujoco-walker': 407}
|
| 134 |
+
+ return last_row_idx[task]
|
| 135 |
+
+
|
| 136 |
+
def get_obs_dim(task):
|
| 137 |
+
assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco")
|
| 138 |
+
|
| 139 |
+
all_obs_dims={'babyai-action-obj-door': 212, 'babyai-blocked-unlock-pickup': 212, 'babyai-boss-level-no-unlock': 212, 'babyai-boss-level': 212, 'babyai-find-obj-s5': 212, 'babyai-go-to-door': 212, 'babyai-go-to-imp-unlock': 212, 'babyai-go-to-local': 212, 'babyai-go-to-obj-door': 212, 'babyai-go-to-obj': 212, 'babyai-go-to-red-ball-grey': 212, 'babyai-go-to-red-ball-no-dists': 212, 'babyai-go-to-red-ball': 212, 'babyai-go-to-red-blue-ball': 212, 'babyai-go-to-seq': 212, 'babyai-go-to': 212, 'babyai-key-corridor': 212, 'babyai-mini-boss-level': 212, 'babyai-move-two-across-s8n9': 212, 'babyai-one-room-s8': 212, 'babyai-open-door': 212, 'babyai-open-doors-order-n4': 212, 'babyai-open-red-door': 212, 'babyai-open-two-doors': 212, 'babyai-open': 212, 'babyai-pickup-above': 212, 'babyai-pickup-dist': 212, 'babyai-pickup-loc': 212, 'babyai-pickup': 212, 'babyai-put-next-local': 212, 'babyai-put-next': 212, 'babyai-synth-loc': 212, 'babyai-synth-seq': 212, 'babyai-synth': 212, 'babyai-unblock-pickup': 212, 'babyai-unlock-local': 212, 'babyai-unlock-pickup': 212, 'babyai-unlock-to-unlock': 212, 'babyai-unlock': 212, 'metaworld-assembly': 39, 'metaworld-basketball': 39, 'metaworld-bin-picking': 39, 'metaworld-box-close': 39, 'metaworld-button-press-topdown-wall': 39, 'metaworld-button-press-topdown': 39, 'metaworld-button-press-wall': 39, 'metaworld-button-press': 39, 'metaworld-coffee-button': 39, 'metaworld-coffee-pull': 39, 'metaworld-coffee-push': 39, 'metaworld-dial-turn': 39, 'metaworld-disassemble': 39, 'metaworld-door-close': 39, 'metaworld-door-lock': 39, 'metaworld-door-open': 39, 'metaworld-door-unlock': 39, 'metaworld-drawer-close': 39, 'metaworld-drawer-open': 39, 'metaworld-faucet-close': 39, 'metaworld-faucet-open': 39, 'metaworld-hammer': 39, 'metaworld-hand-insert': 39, 'metaworld-handle-press-side': 39, 'metaworld-handle-press': 39, 'metaworld-handle-pull-side': 39, 'metaworld-handle-pull': 39, 'metaworld-lever-pull': 39, 'metaworld-peg-insert-side': 39, 'metaworld-peg-unplug-side': 39, 'metaworld-pick-out-of-hole': 39, 'metaworld-pick-place-wall': 39, 'metaworld-pick-place': 39, 'metaworld-plate-slide-back-side': 39, 'metaworld-plate-slide-back': 39, 'metaworld-plate-slide-side': 39, 'metaworld-plate-slide': 39, 'metaworld-push-back': 39, 'metaworld-push-wall': 39, 'metaworld-push': 39, 'metaworld-reach-wall': 39, 'metaworld-reach': 39, 'metaworld-shelf-place': 39, 'metaworld-soccer': 39, 'metaworld-stick-pull': 39, 'metaworld-stick-push': 39, 'metaworld-sweep-into': 39, 'metaworld-sweep': 39, 'metaworld-window-close': 39, 'metaworld-window-open': 39, 'mujoco-ant': 27, 'mujoco-doublependulum': 11, 'mujoco-halfcheetah': 17, 'mujoco-hopper': 11, 'mujoco-humanoid': 376, 'mujoco-pendulum': 4, 'mujoco-pusher': 23, 'mujoco-reacher': 11, 'mujoco-standup': 376, 'mujoco-swimmer': 8, 'mujoco-walker': 17}
|
| 140 |
+
- return all_obs_dims[task]
|
| 141 |
+
+ return (all_obs_dims[task],)
|
| 142 |
+
|
| 143 |
+
def get_act_dim(task):
|
| 144 |
+
assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco")
|
| 145 |
+
@@ -36,141 +48,188 @@ def get_act_dim(task):
|
| 146 |
+
elif task.startswith("mujoco"):
|
| 147 |
+
all_act_dims={'mujoco-ant': 8, 'mujoco-doublependulum': 1, 'mujoco-halfcheetah': 6, 'mujoco-hopper': 3, 'mujoco-humanoid': 17, 'mujoco-pendulum': 1, 'mujoco-pusher': 7, 'mujoco-reacher': 2, 'mujoco-standup': 17, 'mujoco-swimmer': 2, 'mujoco-walker': 6}
|
| 148 |
+
return all_act_dims[task]
|
| 149 |
+
-
|
| 150 |
+
-def process_row_atari(attn_mask, row_of_obs, task):
|
| 151 |
+
- """
|
| 152 |
+
- Example for selection with bools:
|
| 153 |
+
- >>> a = np.array([0,1,2,3,4,5])
|
| 154 |
+
- >>> b = np.array([1,0,0,0,0,1]).astype(bool)
|
| 155 |
+
- >>> a[b]
|
| 156 |
+
- array([0, 5])
|
| 157 |
+
- """
|
| 158 |
+
- attn_mask = np.array(attn_mask).astype(bool)
|
| 159 |
+
|
| 160 |
+
- row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs])
|
| 161 |
+
- row_of_obs = row_of_obs[attn_mask]
|
| 162 |
+
+def get_task_info(task):
|
| 163 |
+
+ rew_key = 'rewards'
|
| 164 |
+
+ attn_key = 'attention_mask'
|
| 165 |
+
+ if task.startswith("atari"):
|
| 166 |
+
+ obs_key = 'image_observations'
|
| 167 |
+
+ act_key = 'discrete_actions'
|
| 168 |
+
+ B = 32 # half of 54
|
| 169 |
+
+ obs_dim = (3, 4*84, 84)
|
| 170 |
+
+ elif task.startswith("babyai"):
|
| 171 |
+
+ obs_key = 'discrete_observations' # also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset)
|
| 172 |
+
+ act_key = 'discrete_actions'
|
| 173 |
+
+ B = 256 # half of 512
|
| 174 |
+
+ obs_dim = get_obs_dim(task)
|
| 175 |
+
+ elif task.startswith("metaworld") or task.startswith("mujoco"):
|
| 176 |
+
+ obs_key = 'continuous_observations'
|
| 177 |
+
+ act_key = 'continuous_actions'
|
| 178 |
+
+ B = 256
|
| 179 |
+
+ obs_dim = get_obs_dim(task)
|
| 180 |
+
+
|
| 181 |
+
+ return rew_key, attn_key, obs_key, act_key, B, obs_dim
|
| 182 |
+
+
|
| 183 |
+
+def process_row_of_obs_atari_full_without_mask(row_of_obs):
|
| 184 |
+
+
|
| 185 |
+
+ if not isinstance(row_of_obs, torch.Tensor):
|
| 186 |
+
+ row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs])
|
| 187 |
+
row_of_obs = row_of_obs * 0.5 + 0.5 # denormalize from [-1, 1] to [0, 1]
|
| 188 |
+
- assert row_of_obs.shape == (sum(attn_mask), 84, 4, 84)
|
| 189 |
+
+ assert row_of_obs.shape == (len(row_of_obs), 84, 4, 84)
|
| 190 |
+
row_of_obs = row_of_obs.permute(0, 2, 1, 3) # (*, 4, 84, 84)
|
| 191 |
+
- row_of_obs = row_of_obs.reshape(sum(attn_mask), 4*84, 84) # put side-by-side
|
| 192 |
+
+ row_of_obs = row_of_obs.reshape(len(row_of_obs), 4*84, 84) # put side-by-side
|
| 193 |
+
row_of_obs = row_of_obs.unsqueeze(1).repeat(1, 3, 1, 1) # repeat for 3 channels
|
| 194 |
+
- assert row_of_obs.shape == (sum(attn_mask), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension
|
| 195 |
+
-
|
| 196 |
+
- return attn_mask, row_of_obs
|
| 197 |
+
+ assert row_of_obs.shape == (len(row_of_obs), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension
|
| 198 |
+
+
|
| 199 |
+
+ return row_of_obs
|
| 200 |
+
|
| 201 |
+
-def process_row_vector(attn_mask, row_of_obs, task, return_numpy=False):
|
| 202 |
+
- attn_mask = np.array(attn_mask).astype(bool)
|
| 203 |
+
+def collect_all_atari_data(dataset, all_row_idxs=None):
|
| 204 |
+
+ if all_row_idxs is None:
|
| 205 |
+
+ all_row_idxs = list(range(len(dataset['train'])))
|
| 206 |
+
|
| 207 |
+
- row_of_obs = np.array(row_of_obs)
|
| 208 |
+
- if not return_numpy:
|
| 209 |
+
- row_of_obs = torch.tensor(row_of_obs)
|
| 210 |
+
- row_of_obs = row_of_obs[attn_mask]
|
| 211 |
+
- assert row_of_obs.shape == (sum(attn_mask), get_obs_dim(task))
|
| 212 |
+
-
|
| 213 |
+
- return attn_mask, row_of_obs
|
| 214 |
+
-
|
| 215 |
+
-def retrieve_atari(row_of_obs, # query: (row_B, 3, 4*84, 84)
|
| 216 |
+
- dataset, # to retrieve from
|
| 217 |
+
- all_rows_to_consider, # rows to consider
|
| 218 |
+
- num_to_retrieve, # top-k
|
| 219 |
+
+ all_rows_of_obs = []
|
| 220 |
+
+ all_attn_masks = []
|
| 221 |
+
+ for row_idx in tqdm(all_row_idxs):
|
| 222 |
+
+ datarow = dataset['train'][row_idx]
|
| 223 |
+
+ row_of_obs = process_row_of_obs_atari_full_without_mask(datarow['image_observations'])
|
| 224 |
+
+ attn_mask = np.array(datarow['attention_mask']).astype(bool)
|
| 225 |
+
+ all_rows_of_obs.append(row_of_obs) # appending tensor
|
| 226 |
+
+ all_attn_masks.append(attn_mask) # appending np array
|
| 227 |
+
+ all_rows_of_obs = torch.stack(all_rows_of_obs, dim=0) # stacking tensors
|
| 228 |
+
+ all_attn_masks = np.stack(all_attn_masks, axis=0) # concatenating np arrays
|
| 229 |
+
+ assert (all_rows_of_obs.shape == (len(all_row_idxs), 32, 3, 4*84, 84) and
|
| 230 |
+
+ all_attn_masks.shape == (len(all_row_idxs), 32))
|
| 231 |
+
+ return all_attn_masks, all_rows_of_obs
|
| 232 |
+
+
|
| 233 |
+
+def collect_all_data(dataset, task, obs_key):
|
| 234 |
+
+ last_row_idx = get_last_row_for_100k_states(task)
|
| 235 |
+
+ all_row_idxs = list(range(last_row_idx))
|
| 236 |
+
+ if task.startswith("atari"):
|
| 237 |
+
+ myprint("Collecting all Atari images and Atari attention masks...")
|
| 238 |
+
+ all_attn_masks_OG, all_rows_of_obs_OG = collect_all_atari_data(dataset, all_row_idxs)
|
| 239 |
+
+ else:
|
| 240 |
+
+ datarows = dataset['train'][all_row_idxs]
|
| 241 |
+
+ all_rows_of_obs_OG = np.array(datarows[obs_key])
|
| 242 |
+
+ all_attn_masks_OG = np.array(datarows['attention_mask']).astype(bool)
|
| 243 |
+
+ return all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs
|
| 244 |
+
+
|
| 245 |
+
+def collect_subset(all_rows_of_obs_OG,
|
| 246 |
+
+ all_attn_masks_OG,
|
| 247 |
+
+ all_rows_to_consider,
|
| 248 |
+
+ kwargs
|
| 249 |
+
+ ):
|
| 250 |
+
+ """
|
| 251 |
+
+ Function to collect subset of data given all_rows_to_consider, reshape it, create all_indices and return.
|
| 252 |
+
+ Used in both retrieve_atari() and retrieve_vector() --> build_index_vector().
|
| 253 |
+
+ """
|
| 254 |
+
+ myprint(f'\n\n\n' + ('-'*100) + f'Collecting subset...')
|
| 255 |
+
+ # read kwargs
|
| 256 |
+
+ B, task, obs_dim = kwargs['B'], kwargs['task'], kwargs['obs_dim']
|
| 257 |
+
+
|
| 258 |
+
+ # take subset based on all_rows_to_consider
|
| 259 |
+
+ myprint(f'Taking subset of data based on all_rows_to_consider...')
|
| 260 |
+
+ all_processed_rows_of_obs = all_rows_of_obs_OG[all_rows_to_consider]
|
| 261 |
+
+ all_attn_masks = all_attn_masks_OG[all_rows_to_consider]
|
| 262 |
+
+ assert (all_processed_rows_of_obs.shape == (len(all_rows_to_consider), B, *obs_dim) and
|
| 263 |
+
+ all_attn_masks.shape == (len(all_rows_to_consider), B))
|
| 264 |
+
+
|
| 265 |
+
+ # reshape
|
| 266 |
+
+ myprint(f'Reshaping data...')
|
| 267 |
+
+ all_attn_masks = all_attn_masks.reshape(-1)
|
| 268 |
+
+ all_processed_rows_of_obs = all_processed_rows_of_obs.reshape(-1, *obs_dim)
|
| 269 |
+
+ all_processed_rows_of_obs = all_processed_rows_of_obs[all_attn_masks]
|
| 270 |
+
+ assert (all_attn_masks.shape == (len(all_rows_to_consider) * B,) and
|
| 271 |
+
+ all_processed_rows_of_obs.shape == (np.sum(all_attn_masks), *obs_dim))
|
| 272 |
+
+
|
| 273 |
+
+ # collect indices of data
|
| 274 |
+
+ myprint(f'Collecting indices of data...')
|
| 275 |
+
+ all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)])
|
| 276 |
+
+ all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s
|
| 277 |
+
+ assert all_indices.shape == (np.sum(all_attn_masks), 2)
|
| 278 |
+
+
|
| 279 |
+
+ myprint(f'{all_indices.shape=}, {all_processed_rows_of_obs.shape=}')
|
| 280 |
+
+ myprint(('-'*100) + '\n\n\n')
|
| 281 |
+
+ return all_indices, all_processed_rows_of_obs
|
| 282 |
+
+
|
| 283 |
+
+def retrieve_atari(row_of_obs, # query: (xbdim, 3, 4*84, 84) / (xdim *obs_dim)
|
| 284 |
+
+ all_processed_rows_of_obs,
|
| 285 |
+
+ all_indices,
|
| 286 |
+
+ num_to_retrieve,
|
| 287 |
+
kwargs
|
| 288 |
+
- ):
|
| 289 |
+
+ ):
|
| 290 |
+
+ """
|
| 291 |
+
+ Retrieval for Atari with images, ssim distance, and on GPU.
|
| 292 |
+
+ """
|
| 293 |
+
assert isinstance(row_of_obs, torch.Tensor)
|
| 294 |
+
|
| 295 |
+
# read kwargs # Note: B = len of row
|
| 296 |
+
- B, attn_key, obs_key, device, task, batch_size_retrieval = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval']
|
| 297 |
+
+ B, device, batch_size_retrieval = kwargs['B'], kwargs['device'], kwargs['batch_size_retrieval']
|
| 298 |
+
|
| 299 |
+
# batch size of row_of_obs which can be <= B since we process before calling this function
|
| 300 |
+
- row_B = row_of_obs.shape[0]
|
| 301 |
+
-
|
| 302 |
+
+ xbdim = row_of_obs.shape[0]
|
| 303 |
+
+
|
| 304 |
+
+ # collect subset of data that we can retrieve from
|
| 305 |
+
+ ydim = all_processed_rows_of_obs.shape[0]
|
| 306 |
+
+
|
| 307 |
+
# first argument for ssim
|
| 308 |
+
- repeated_row_og = row_of_obs.repeat_interleave(B, dim=0).to(device)
|
| 309 |
+
- assert repeated_row_og.shape == (row_B*B, 3, 4*84, 84)
|
| 310 |
+
+ xbatch = row_of_obs.repeat_interleave(batch_size_retrieval, dim=0).to(device)
|
| 311 |
+
+ assert xbatch.shape == (xbdim * batch_size_retrieval, 3, 4*84, 84)
|
| 312 |
+
|
| 313 |
+
- # iterate over all other rows
|
| 314 |
+
+ # iterate over data that we can retrieve from in batches
|
| 315 |
+
all_ssim = []
|
| 316 |
+
- all_indices = []
|
| 317 |
+
- total = 0
|
| 318 |
+
- for other_row_idx in tqdm(all_rows_to_consider):
|
| 319 |
+
- other_attn_mask, other_row_of_obs = process_row_atari(dataset['train'][other_row_idx][attn_key], dataset['train'][other_row_idx][obs_key])
|
| 320 |
+
-
|
| 321 |
+
- # batch size of other_row_of_obs
|
| 322 |
+
- other_row_B = other_row_of_obs.shape[0]
|
| 323 |
+
- total += other_row_B
|
| 324 |
+
-
|
| 325 |
+
- # first argument for ssim: RECHECK
|
| 326 |
+
- if other_row_B < B: # when other row has less observations than expected
|
| 327 |
+
- repeated_row = row_of_obs.repeat_interleave(other_row_B, dim=0).to(device)
|
| 328 |
+
- elif other_row_B == B: # otherwise just use the one created before the for loop
|
| 329 |
+
- repeated_row = repeated_row_og
|
| 330 |
+
- assert repeated_row.shape == (row_B*other_row_B, 3, 4*84, 84)
|
| 331 |
+
-
|
| 332 |
+
+ for j in range(0, ydim, batch_size_retrieval):
|
| 333 |
+
# second argument for ssim
|
| 334 |
+
- repeated_other_row = other_row_of_obs.repeat(row_B, 1, 1, 1).to(device)
|
| 335 |
+
- assert repeated_other_row.shape == (row_B*other_row_B, 3, 4*84, 84)
|
| 336 |
+
+ ybatch = all_processed_rows_of_obs[j:j+batch_size_retrieval]
|
| 337 |
+
+ ybdim = ybatch.shape[0]
|
| 338 |
+
+ ybatch = ybatch.repeat(xbdim, 1, 1, 1).to(device)
|
| 339 |
+
+ assert ybatch.shape == (ybdim * xbdim, 3, 4*84, 84)
|
| 340 |
+
+
|
| 341 |
+
+ if ybdim < batch_size_retrieval: # for last batch
|
| 342 |
+
+ xbatch = row_of_obs.repeat_interleave(ybdim, dim=0).to(device)
|
| 343 |
+
+ assert xbatch.shape == (xbdim * ybdim, 3, 4*84, 84)
|
| 344 |
+
|
| 345 |
+
# compare via ssim and updated all_ssim
|
| 346 |
+
- ssim_score = ssim(repeated_row, repeated_other_row, data_range=1.0, size_average=False)
|
| 347 |
+
- ssim_score = ssim_score.reshape(row_B, other_row_B)
|
| 348 |
+
+ ssim_score = ssim(xbatch, ybatch, data_range=1.0, size_average=False)
|
| 349 |
+
+ ssim_score = ssim_score.reshape(xbdim, ybdim)
|
| 350 |
+
all_ssim.append(ssim_score)
|
| 351 |
+
|
| 352 |
+
- # update all_indices
|
| 353 |
+
- all_indices.extend([[other_row_idx, i] for i in range(other_row_B)])
|
| 354 |
+
-
|
| 355 |
+
# concat
|
| 356 |
+
all_ssim = torch.cat(all_ssim, dim=1)
|
| 357 |
+
- assert all_ssim.shape == (row_B, total)
|
| 358 |
+
+ assert all_ssim.shape == (xbdim, ydim)
|
| 359 |
+
|
| 360 |
+
- all_indices = np.array(all_indices)
|
| 361 |
+
- assert all_indices.shape == (total, 2)
|
| 362 |
+
+ assert all_indices.shape == (ydim, 2)
|
| 363 |
+
|
| 364 |
+
# get top-k indices
|
| 365 |
+
topk_values, topk_indices = torch.topk(all_ssim, num_to_retrieve, dim=1, largest=True)
|
| 366 |
+
topk_indices = topk_indices.cpu().numpy()
|
| 367 |
+
- assert topk_indices.shape == (row_B, num_to_retrieve)
|
| 368 |
+
+ assert topk_indices.shape == (xbdim, num_to_retrieve)
|
| 369 |
+
|
| 370 |
+
# convert topk indices to indices in the dataset
|
| 371 |
+
- retrieved_indices = np.array(all_indices[topk_indices])
|
| 372 |
+
- assert retrieved_indices.shape == (row_B, num_to_retrieve, 2)
|
| 373 |
+
-
|
| 374 |
+
- # pad the above to expected B
|
| 375 |
+
- if row_B < B:
|
| 376 |
+
- retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0)
|
| 377 |
+
- assert retrieved_indices.shape == (B, num_to_retrieve, 2)
|
| 378 |
+
+ retrieved_indices = all_indices[topk_indices]
|
| 379 |
+
+ assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2)
|
| 380 |
+
|
| 381 |
+
return retrieved_indices
|
| 382 |
+
|
| 383 |
+
-def build_index_vector(all_rows_of_obs_og,
|
| 384 |
+
- all_attn_masks_og,
|
| 385 |
+
+def build_index_vector(all_rows_of_obs_OG,
|
| 386 |
+
+ all_attn_masks_OG,
|
| 387 |
+
all_rows_to_consider,
|
| 388 |
+
kwargs
|
| 389 |
+
- ):
|
| 390 |
+
+ ):
|
| 391 |
+
+ """
|
| 392 |
+
+ Builds FAISS index for vector observation environments.
|
| 393 |
+
+ """
|
| 394 |
+
# read kwargs # Note: B = len of row
|
| 395 |
+
- B, attn_key, obs_key, device, task, batch_size_retrieval, nb_cores_autofaiss = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval'], kwargs['nb_cores_autofaiss']
|
| 396 |
+
- obs_dim = get_obs_dim(task)
|
| 397 |
+
+ nb_cores_autofaiss = kwargs['nb_cores_autofaiss']
|
| 398 |
+
|
| 399 |
+
- # take subset based on all_rows_to_consider
|
| 400 |
+
- myprint(f'Taking subset')
|
| 401 |
+
- all_rows_of_obs = all_rows_of_obs_og[all_rows_to_consider]
|
| 402 |
+
- all_attn_masks = all_attn_masks_og[all_rows_to_consider]
|
| 403 |
+
- assert (all_rows_of_obs.shape == (len(all_rows_to_consider), B, obs_dim) and
|
| 404 |
+
- all_attn_masks.shape == (len(all_rows_to_consider), B))
|
| 405 |
+
-
|
| 406 |
+
- # reshape
|
| 407 |
+
- all_attn_masks = all_attn_masks.reshape(-1)
|
| 408 |
+
- all_rows_of_obs = all_rows_of_obs.reshape(-1, obs_dim)
|
| 409 |
+
- all_rows_of_obs = all_rows_of_obs[all_attn_masks]
|
| 410 |
+
- assert all_rows_of_obs.shape == (np.sum(all_attn_masks), obs_dim)
|
| 411 |
+
+ # take subset based on all_rows_to_consider, reshape, and save indices of data
|
| 412 |
+
+ all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG, all_attn_masks_OG, all_rows_to_consider, kwargs)
|
| 413 |
+
|
| 414 |
+
- # save indices of data to retrieve from
|
| 415 |
+
- myprint(f'Saving indices of data to retrieve from')
|
| 416 |
+
- all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)])
|
| 417 |
+
- all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s
|
| 418 |
+
- assert all_indices.shape == (np.sum(all_attn_masks), 2)
|
| 419 |
+
+ # make sure input to build_index is float, otherwise you will get reading temp file error
|
| 420 |
+
+ all_processed_rows_of_obs = all_processed_rows_of_obs.astype(float)
|
| 421 |
+
|
| 422 |
+
# build index
|
| 423 |
+
- myprint(f'Building index...')
|
| 424 |
+
- knn_index, knn_index_infos = build_index(embeddings=all_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader!
|
| 425 |
+
+ myprint(('-'*100) + 'Building index...')
|
| 426 |
+
+ knn_index, knn_index_infos = build_index(embeddings=all_processed_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader!
|
| 427 |
+
save_on_disk=False,
|
| 428 |
+
min_nearest_neighbors_to_retrieve=20, # default: 20
|
| 429 |
+
max_index_query_time_ms=10, # default: 10
|
| 430 |
+
@@ -179,34 +238,32 @@ def build_index_vector(all_rows_of_obs_og,
|
| 431 |
+
metric_type='l2',
|
| 432 |
+
nb_cores=nb_cores_autofaiss, # default: None # "The number of cores to use, by default will use all cores" as seen in https://criteo.github.io/autofaiss/getting_started/quantization.html#the-build-index-command
|
| 433 |
+
)
|
| 434 |
+
+ myprint(('-'*100) + '\n\n\n')
|
| 435 |
+
|
| 436 |
+
- return knn_index, all_indices
|
| 437 |
+
+ return all_indices, knn_index
|
| 438 |
+
|
| 439 |
+
-def retrieve_vector(row_of_obs, # query: (row_B, dim)
|
| 440 |
+
- dataset, # to retrieve from
|
| 441 |
+
- all_rows_to_consider, # rows to consider
|
| 442 |
+
- num_to_retrieve, # top-k
|
| 443 |
+
+def retrieve_vector(row_of_obs, # query: (xbdim, *obs_dim)
|
| 444 |
+
+ knn_index,
|
| 445 |
+
+ all_indices,
|
| 446 |
+
+ num_to_retrieve,
|
| 447 |
+
kwargs
|
| 448 |
+
- ):
|
| 449 |
+
+ ):
|
| 450 |
+
+ """
|
| 451 |
+
+ Retrieval for vector observation environments.
|
| 452 |
+
+ """
|
| 453 |
+
assert isinstance(row_of_obs, np.ndarray)
|
| 454 |
+
|
| 455 |
+
# read few kwargs
|
| 456 |
+
B = kwargs['B']
|
| 457 |
+
|
| 458 |
+
# batch size of row_of_obs which can be <= B since we process before calling this function
|
| 459 |
+
- row_B = row_of_obs.shape[0]
|
| 460 |
+
+ xbdim = row_of_obs.shape[0]
|
| 461 |
+
|
| 462 |
+
- # read dataset_tuple
|
| 463 |
+
- all_rows_of_obs, all_attn_masks = dataset
|
| 464 |
+
-
|
| 465 |
+
- # create index and all_indices
|
| 466 |
+
- knn_index, all_indices = build_index_vector(all_rows_of_obs, all_attn_masks, all_rows_to_consider, kwargs)
|
| 467 |
+
-
|
| 468 |
+
# retrieve
|
| 469 |
+
myprint(f'Retrieving...')
|
| 470 |
+
topk_indices, _ = knn_index.search(row_of_obs, 10 * num_to_retrieve)
|
| 471 |
+
topk_indices = topk_indices.astype(int)
|
| 472 |
+
- assert topk_indices.shape == (row_B, 10 * num_to_retrieve)
|
| 473 |
+
+ assert topk_indices.shape == (xbdim, 10 * num_to_retrieve)
|
| 474 |
+
|
| 475 |
+
# remove -1s and crop to num_to_retrieve
|
| 476 |
+
try:
|
| 477 |
+
@@ -219,16 +276,10 @@ def retrieve_vector(row_of_obs, # query: (row_B, dim)
|
| 478 |
+
print(f'-------------------------------------------------------------------------------------------------------------------------------------------')
|
| 479 |
+
print(f'Leaving some -1s in topk_indices and continuing')
|
| 480 |
+
topk_indices = np.array([indices[:num_to_retrieve] for indices in topk_indices])
|
| 481 |
+
- assert topk_indices.shape == (row_B, num_to_retrieve)
|
| 482 |
+
+ assert topk_indices.shape == (xbdim, num_to_retrieve)
|
| 483 |
+
|
| 484 |
+
# convert topk indices to indices in the dataset
|
| 485 |
+
retrieved_indices = all_indices[topk_indices]
|
| 486 |
+
- assert retrieved_indices.shape == (row_B, num_to_retrieve, 2)
|
| 487 |
+
-
|
| 488 |
+
- # pad the above to expected B
|
| 489 |
+
- if row_B < B:
|
| 490 |
+
- retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0)
|
| 491 |
+
- assert retrieved_indices.shape == (B, num_to_retrieve, 2)
|
| 492 |
+
+ assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2)
|
| 493 |
+
|
| 494 |
+
- myprint(f'Returning')
|
| 495 |
+
return retrieved_indices
|
| 496 |
+
|
| 497 |
+
diff --git a/scripts_regent/eval_RandP.py b/scripts_regent/eval_RandP.py
|
| 498 |
+
index 07e545c..146b347 100755
|
| 499 |
+
--- a/scripts_regent/eval_RandP.py
|
| 500 |
+
+++ b/scripts_regent/eval_RandP.py
|
| 501 |
+
@@ -15,9 +15,10 @@ from transformers import AutoModelForCausalLM, AutoProcessor, HfArgumentParser
|
| 502 |
+
|
| 503 |
+
from jat.eval.rl import TASK_NAME_TO_ENV_ID, make
|
| 504 |
+
from jat.utils import normalize, push_to_hub, save_video_grid
|
| 505 |
+
-from jat_regent.RandP import RandP
|
| 506 |
+
+from jat_regent.modeling_RandP import RandP
|
| 507 |
+
from datasets import load_from_disk
|
| 508 |
+
from datasets.config import HF_DATASETS_CACHE
|
| 509 |
+
+from jat_regent.utils import myprint
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
@dataclass
|
| 513 |
+
@@ -70,6 +71,7 @@ def eval_rl(model, processor, task, eval_args):
|
| 514 |
+
scores = []
|
| 515 |
+
frames = []
|
| 516 |
+
for episode in tqdm(range(eval_args.num_episodes), desc=task, unit="episode", leave=False):
|
| 517 |
+
+ myprint(('-'*100) + f'{episode=}')
|
| 518 |
+
observation, _ = env.reset()
|
| 519 |
+
reward = None
|
| 520 |
+
rewards = []
|
| 521 |
+
@@ -96,6 +98,7 @@ def eval_rl(model, processor, task, eval_args):
|
| 522 |
+
frames.append(np.array(env.render(), dtype=np.uint8))
|
| 523 |
+
|
| 524 |
+
scores.append(sum(rewards))
|
| 525 |
+
+ myprint(('-'*100) + '\n\n\n')
|
| 526 |
+
env.close()
|
| 527 |
+
|
| 528 |
+
raw_mean, raw_std = np.mean(scores), np.std(scores)
|
| 529 |
+
@@ -145,7 +148,9 @@ def main():
|
| 530 |
+
tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)])
|
| 531 |
+
|
| 532 |
+
device = torch.device("cpu") if eval_args.use_cpu else get_default_device()
|
| 533 |
+
- processor = None
|
| 534 |
+
+ processor = AutoProcessor.from_pretrained(
|
| 535 |
+
+ 'jat-project/jat', cache_dir=None, trust_remote_code=True
|
| 536 |
+
+ )
|
| 537 |
+
|
| 538 |
+
evaluations = {}
|
| 539 |
+
video_list = []
|
| 540 |
+
@@ -153,14 +158,18 @@ def main():
|
| 541 |
+
|
| 542 |
+
for task in tqdm(tasks, desc="Evaluation", unit="task", leave=True):
|
| 543 |
+
if task in TASK_NAME_TO_ENV_ID.keys():
|
| 544 |
+
+ myprint(('-'*100) + f'{task=}')
|
| 545 |
+
dataset = load_from_disk(f'{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}')
|
| 546 |
+
- model = RandP(dataset)
|
| 547 |
+
+ model = RandP(task,
|
| 548 |
+
+ dataset,
|
| 549 |
+
+ device,)
|
| 550 |
+
scores, frames, fps = eval_rl(model, processor, task, eval_args)
|
| 551 |
+
evaluations[task] = scores
|
| 552 |
+
# Save the video
|
| 553 |
+
if eval_args.save_video:
|
| 554 |
+
video_list.append(frames)
|
| 555 |
+
input_fps.append(fps)
|
| 556 |
+
+ myprint(('-'*100) + '\n\n\n')
|
| 557 |
+
else:
|
| 558 |
+
warnings.warn(f"Task {task} is not supported.")
|
| 559 |
+
|
| 560 |
+
diff --git a/scripts_regent/offline_retrieval_jat_regent.py b/scripts_regent/offline_retrieval_jat_regent.py
|
| 561 |
+
index c83d259..aad678a 100644
|
| 562 |
+
--- a/scripts_regent/offline_retrieval_jat_regent.py
|
| 563 |
+
+++ b/scripts_regent/offline_retrieval_jat_regent.py
|
| 564 |
+
@@ -8,7 +8,7 @@ import time
|
| 565 |
+
from datetime import datetime
|
| 566 |
+
from datasets import load_from_disk
|
| 567 |
+
from datasets.config import HF_DATASETS_CACHE
|
| 568 |
+
-from jat_regent.utils import myprint, process_row_atari, process_row_vector, retrieve_atari, retrieve_vector
|
| 569 |
+
+from jat_regent.utils import myprint, get_task_info, collect_all_data, process_row_of_obs_atari_full_without_mask, retrieve_atari, retrieve_vector, collect_subset, build_index_vector
|
| 570 |
+
import logging
|
| 571 |
+
logging.basicConfig(level=logging.DEBUG)
|
| 572 |
+
|
| 573 |
+
@@ -17,7 +17,8 @@ def main():
|
| 574 |
+
parser = argparse.ArgumentParser(description='Build RAAGENT sequence indices')
|
| 575 |
+
parser.add_argument('--task', type=str, default='atari-alien', help='Task name')
|
| 576 |
+
parser.add_argument('--num_to_retrieve', type=int, default=100, help='Number of states/windows to retrieve')
|
| 577 |
+
- parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector observation environments')
|
| 578 |
+
+ parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector obs envs')
|
| 579 |
+
+ parser.add_argument('--batch_size_retrieval', type=int, default=1024, help='Batch size for retrieval in atari')
|
| 580 |
+
args = parser.parse_args()
|
| 581 |
+
|
| 582 |
+
# load dataset, map, device, for task
|
| 583 |
+
@@ -25,77 +26,83 @@ def main():
|
| 584 |
+
dataset_path = f"{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}"
|
| 585 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 586 |
+
|
| 587 |
+
- rew_key = 'rewards'
|
| 588 |
+
- attn_key = 'attention_mask'
|
| 589 |
+
- if task.startswith("atari"):
|
| 590 |
+
- obs_key = 'image_observations'
|
| 591 |
+
- act_key = 'discrete_actions'
|
| 592 |
+
- len_row_tokenized_known = 32 # half of 54
|
| 593 |
+
- process_row_fn = process_row_atari
|
| 594 |
+
- retrieve_fn = retrieve_atari
|
| 595 |
+
- elif task.startswith("babyai"):
|
| 596 |
+
- obs_key = 'discrete_observations'# also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset)
|
| 597 |
+
- act_key = 'discrete_actions'
|
| 598 |
+
- len_row_tokenized_known = 256 # half of 512
|
| 599 |
+
- process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True)
|
| 600 |
+
- retrieve_fn = retrieve_vector
|
| 601 |
+
- elif task.startswith("metaworld") or task.startswith("mujoco"):
|
| 602 |
+
- obs_key = 'continuous_observations'
|
| 603 |
+
- act_key = 'continuous_actions'
|
| 604 |
+
- len_row_tokenized_known = 256
|
| 605 |
+
- process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True)
|
| 606 |
+
- retrieve_fn = retrieve_vector
|
| 607 |
+
+ rew_key, attn_key, obs_key, act_key, B, obs_dim = get_task_info(task)
|
| 608 |
+
|
| 609 |
+
dataset = load_from_disk(dataset_path)
|
| 610 |
+
with open(f"{dataset_path}/map_from_rows_to_episodes_for_tokenized.json", 'r') as f:
|
| 611 |
+
map_from_rows_to_episodes_for_tokenized = json.load(f)
|
| 612 |
+
|
| 613 |
+
# setup kwargs
|
| 614 |
+
- len_dataset = len(dataset['train'])
|
| 615 |
+
- B = len_row_tokenized_known
|
| 616 |
+
kwargs = {'B': B,
|
| 617 |
+
- 'attn_key':attn_key,
|
| 618 |
+
- 'obs_key':obs_key,
|
| 619 |
+
- 'device':device,
|
| 620 |
+
- 'task':task,
|
| 621 |
+
- 'batch_size_retrieval':None,
|
| 622 |
+
- 'nb_cores_autofaiss':None if task.startswith("atari") else args.nb_cores_autofaiss,
|
| 623 |
+
- }
|
| 624 |
+
+ 'obs_dim': obs_dim,
|
| 625 |
+
+ 'attn_key': attn_key,
|
| 626 |
+
+ 'obs_key': obs_key,
|
| 627 |
+
+ 'device': device,
|
| 628 |
+
+ 'task': task,
|
| 629 |
+
+ 'batch_size_retrieval': args.batch_size_retrieval,
|
| 630 |
+
+ 'nb_cores_autofaiss': None if task.startswith("atari") else args.nb_cores_autofaiss,
|
| 631 |
+
+ }
|
| 632 |
+
|
| 633 |
+
# collect all observations in a single array (this takes some time) for vector observation environments
|
| 634 |
+
- if not task.startswith("atari"):
|
| 635 |
+
- myprint("Collecting all observations/attn_masks in a single array")
|
| 636 |
+
- all_rows_of_obs = np.array(dataset['train'][obs_key])
|
| 637 |
+
- all_attn_masks = np.array(dataset['train'][attn_key]).astype(bool)
|
| 638 |
+
+ myprint("Collecting all observations/attn_masks")
|
| 639 |
+
+ all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs = collect_all_data(dataset, task, obs_key)
|
| 640 |
+
|
| 641 |
+
# iterate over rows
|
| 642 |
+
all_retrieved_indices = []
|
| 643 |
+
- for row_idx in range(len_dataset):
|
| 644 |
+
- myprint(f"\nProcessing row {row_idx}/{len_dataset}")
|
| 645 |
+
+ for row_idx in all_row_idxs:
|
| 646 |
+
+ myprint(f"\nProcessing row {row_idx}/{len(all_row_idxs)}")
|
| 647 |
+
current_ep = map_from_rows_to_episodes_for_tokenized[str(row_idx)]
|
| 648 |
+
|
| 649 |
+
- attn_mask, row_of_obs = process_row_fn(dataset['train'][row_idx][attn_key], dataset['train'][row_idx][obs_key], task)
|
| 650 |
+
+ # get row_of_obs and attn_mask
|
| 651 |
+
+ datarow = dataset['train'][row_idx]
|
| 652 |
+
+ attn_mask = np.array(datarow[attn_key]).astype(bool)
|
| 653 |
+
+ if task.startswith("atari"):
|
| 654 |
+
+ row_of_obs = process_row_of_obs_atari_full_without_mask(datarow[obs_key])
|
| 655 |
+
+ else:
|
| 656 |
+
+ row_of_obs = np.array(datarow[obs_key])
|
| 657 |
+
+ row_of_obs = row_of_obs[attn_mask]
|
| 658 |
+
+ assert row_of_obs.shape == (np.sum(attn_mask), *obs_dim)
|
| 659 |
+
|
| 660 |
+
# compare with rows from all but the current episode
|
| 661 |
+
- all_other_rows = [idx for idx in range(len_dataset) if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep]
|
| 662 |
+
+ all_other_row_idxs = [idx for idx in all_row_idxs if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep]
|
| 663 |
+
|
| 664 |
+
# do the retrieval
|
| 665 |
+
- retrieved_indices = retrieve_fn(row_of_obs=row_of_obs,
|
| 666 |
+
- dataset=dataset if task.startswith("atari") else (all_rows_of_obs, all_attn_masks),
|
| 667 |
+
- all_rows_to_consider=all_other_rows,
|
| 668 |
+
- num_to_retrieve=args.num_to_retrieve,
|
| 669 |
+
- kwargs=kwargs,
|
| 670 |
+
- )
|
| 671 |
+
+ if task.startswith("atari"):
|
| 672 |
+
+ all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG=all_rows_of_obs_OG,
|
| 673 |
+
+ all_attn_masks_OG=all_attn_masks_OG,
|
| 674 |
+
+ all_rows_to_consider=all_row_idxs,
|
| 675 |
+
+ kwargs=kwargs)
|
| 676 |
+
+ retrieved_indices = retrieve_atari(row_of_obs=row_of_obs,
|
| 677 |
+
+ all_processed_rows_of_obs=all_processed_rows_of_obs,
|
| 678 |
+
+ all_indices=all_indices,
|
| 679 |
+
+ num_to_retrieve=args.num_to_retrieve,
|
| 680 |
+
+ kwargs=kwargs)
|
| 681 |
+
+ else:
|
| 682 |
+
+ all_indices, knn_index = build_index_vector(all_rows_of_obs_OG=all_rows_of_obs_OG,
|
| 683 |
+
+ all_attn_masks_OG=all_attn_masks_OG,
|
| 684 |
+
+ all_rows_to_consider=all_other_row_idxs,
|
| 685 |
+
+ kwargs=kwargs)
|
| 686 |
+
+ retrieved_indices = retrieve_vector(row_of_obs=row_of_obs,
|
| 687 |
+
+ knn_index=knn_index,
|
| 688 |
+
+ all_indices=all_indices,
|
| 689 |
+
+ num_to_retrieve=args.num_to_retrieve,
|
| 690 |
+
+ kwargs=kwargs)
|
| 691 |
+
+
|
| 692 |
+
+ # pad the above to expected B
|
| 693 |
+
+ xbdim = row_of_obs.shape[0]
|
| 694 |
+
+ if xbdim < B:
|
| 695 |
+
+ retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-xbdim, args.num_to_retrieve, 2), dtype=int)], axis=0)
|
| 696 |
+
+ assert retrieved_indices.shape == (B, args.num_to_retrieve, 2)
|
| 697 |
+
|
| 698 |
+
# collect retrieved indices
|
| 699 |
+
all_retrieved_indices.append(retrieved_indices)
|
| 700 |
+
|
| 701 |
+
# concat
|
| 702 |
+
all_retrieved_indices = np.stack(all_retrieved_indices, axis=0)
|
| 703 |
+
- assert all_retrieved_indices.shape == (len_dataset, B, args.num_to_retrieve, 2)
|
| 704 |
+
+ assert all_retrieved_indices.shape == (len(all_row_idxs), B, args.num_to_retrieve, 2)
|
| 705 |
+
|
| 706 |
+
# save arrays as bin for easy memmap access and faster loading
|
| 707 |
+
- all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len_dataset}_{B}_{args.num_to_retrieve}_2.bin")
|
| 708 |
+
+ all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len(all_row_idxs)}_{B}_{args.num_to_retrieve}_2.bin")
|
| 709 |
+
|
| 710 |
+
if __name__ == "__main__":
|
| 711 |
+
main()
|
| 712 |
+
|
replay.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:07be059a9429a473b5b17baa3708722b914b49c1e81c2c57e350ea5acb4339b7
|
| 3 |
+
size 1295354
|
sf_log.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|