Cohom committed (verified) · Commit 1552a43 · Parent(s): 5cdd150

Add files using upload-large-folder tool
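The commit message names the upload-large-folder tool. For context, a minimal sketch of what such an upload looks like with the `huggingface_hub` Python client; the repo id and local folder path below are illustrative assumptions, not taken from this commit:

```python
# Minimal sketch of a large-folder upload via huggingface_hub.
# Repo id and local path are assumptions for illustration.
from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` by default
api.upload_large_folder(
    repo_id="openvla/openvla-7b-prismatic",  # assumed target repo
    repo_type="model",
    folder_path="./openvla-7b-prismatic",    # assumed local checkpoint folder
)
```

The tool chunks and resumes uploads, which is why it is the practical route for a ~30 GB checkpoint like the one added in this commit.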
openvla-7b-prismatic/.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
openvla-7b-prismatic/README.md ADDED
@@ -0,0 +1,41 @@
+ ---
+ library_name: transformers
+ tags:
+ - robotics
+ - vla
+ - image-text-to-text
+ - multimodal
+ - pretraining
+ license: mit
+ language:
+ - en
+ pipeline_tag: image-text-to-text
+ ---
+
+ # OpenVLA 7B (Prismatic-Compatible Version)
+
+ <b>This is the same model as the [OpenVLA 7B model](https://huggingface.co/openvla/openvla-7b), except that this checkpoint is in a format that is
+ compatible with the training script from the original [Prismatic VLMs project codebase](https://github.com/TRI-ML/prismatic-vlms), which the OpenVLA
+ team built on top of to develop the OpenVLA model. See details for the OpenVLA 7B model here: https://huggingface.co/openvla/openvla-7b</b>
+
+ This Prismatic-compatible checkpoint may be useful if you wish to <b>fully fine-tune</b> OpenVLA (all 7.5 billion parameters) via native PyTorch Fully
+ Sharded Data Parallel (FSDP) using the Prismatic VLMs training script. If you instead wish to do Parameter-Efficient Fine-Tuning via LoRA, you
+ can use the OpenVLA checkpoint linked above, which is compatible with the Hugging Face `transformers` library. We recommend fine-tuning via LoRA if
+ you do not have sufficient compute to fully fine-tune a 7B-parameter model (e.g., multiple A100/H100 GPUs).
+
+ ## Usage Instructions
+
+ See the [OpenVLA GitHub README](https://github.com/openvla/openvla/blob/main/README.md) for instructions on how to use this checkpoint for full fine-tuning.
+
+ ## Citation
+
+ **BibTeX:**
+
+ ```bibtex
+ @article{kim24openvla,
+     title={OpenVLA: An Open-Source Vision-Language-Action Model},
+     author={{Moo Jin} Kim and Karl Pertsch and Siddharth Karamcheti and Ted Xiao and Ashwin Balakrishna and Suraj Nair and Rafael Rafailov and Ethan Foster and Grace Lam and Pannag Sanketi and Quan Vuong and Thomas Kollar and Benjamin Burchfiel and Russ Tedrake and Dorsa Sadigh and Sergey Levine and Percy Liang and Chelsea Finn},
+     journal={arXiv preprint arXiv:2406.09246},
+     year={2024}
+ }
+ ```
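For the LoRA route the README points at, the Hugging Face-format sibling checkpoint loads through the standard `transformers` auto classes. A minimal sketch, mirroring the usage shown on the openvla/openvla-7b model card (`trust_remote_code=True` is needed because OpenVLA ships custom modeling code with the repo):

```python
# Minimal sketch: load the transformers-compatible OpenVLA checkpoint
# (the sibling repo this model card links to), e.g. as a base for LoRA.
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor

processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b",
    torch_dtype=torch.bfloat16,  # bf16 keeps the 7.5B model within a single A100/H100
    trust_remote_code=True,
).to("cuda:0")
```

The Prismatic-format checkpoint in this repo, by contrast, is consumed directly by the Prismatic VLMs training script and is not meant to be loaded through `transformers`.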
openvla-7b-prismatic/checkpoints/step-295000-epoch-40-loss=0.2200.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2c2497cd9e0ecced65e54b0771172f22e3ed64d0c0af339e094349715d3b3602
+ size 30165309772
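This is a Git LFS pointer for the full ~30 GB training checkpoint; the weights themselves live in LFS storage. A minimal sketch of fetching and inspecting it with `huggingface_hub` and PyTorch; the repo id is an assumption for illustration, and the Prismatic training script, not this snippet, is the intended consumer:

```python
# Minimal sketch: download the LFS-backed checkpoint and peek at its contents.
# The repo id below is an assumption; the filename matches this commit's diff.
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="openvla/openvla-7b-prismatic",  # assumed repo id
    filename="checkpoints/step-295000-epoch-40-loss=0.2200.pt",
)
state = torch.load(ckpt_path, map_location="cpu")  # ~30 GB on disk; needs ample RAM
print(type(state), list(state)[:5] if isinstance(state, dict) else None)
```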
openvla-7b-prismatic/config.json ADDED
@@ -0,0 +1,41 @@
+ {
+   "data_root_dir": "/scr/user/data",
+   "hf_token": ".hf_token",
+   "pretrained_checkpoint": "",
+   "resume_epoch": null,
+   "resume_step": null,
+   "run_id": "prism-dinosiglip-224px+mx-oxe-magic-soup-plus+n8+b32+x7",
+   "run_id_note": null,
+   "run_root_dir": "./runs",
+   "save_interval": 2500,
+   "seed": 7,
+   "stage": "vla-full-train",
+   "trackers": [
+     "jsonl",
+     "wandb"
+   ],
+   "vla": {
+     "base_vlm": "prism-dinosiglip-224px+7b",
+     "data_mix": "oxe_magic_soup_plus_minus",
+     "enable_gradient_checkpointing": true,
+     "enable_mixed_precision_training": true,
+     "epochs": 1000,
+     "expected_world_size": 64,
+     "freeze_vision_backbone": false,
+     "global_batch_size": 2048,
+     "learning_rate": 2e-05,
+     "lr_scheduler_type": "constant",
+     "max_grad_norm": 1.0,
+     "max_steps": null,
+     "per_device_batch_size": 32,
+     "reduce_in_full_precision": true,
+     "shuffle_buffer_size": 256000,
+     "train_strategy": "fsdp-full-shard",
+     "type": "prism-dinosiglip-224px+mx-oxe-magic-soup-plus",
+     "vla_id": "prism-dinosiglip-224px+mx-oxe-magic-soup-plus",
+     "warmup_ratio": 0.0,
+     "weight_decay": 0.0
+   },
+   "wandb_entity": "",
+   "wandb_project": ""
+ }
openvla-7b-prismatic/config.yml ADDED
@@ -0,0 +1,37 @@
+ data_root_dir: /scr/user/data
+ hf_token: .hf_token
+ pretrained_checkpoint: null
+ resume_epoch: null
+ resume_step: null
+ run_id: prism-dinosiglip-224px+mx-oxe-magic-soup-plus+n8+b32+x7
+ run_id_note: null
+ run_root_dir: ./runs
+ save_interval: 2500
+ seed: 7
+ stage: vla-full-train
+ trackers:
+ - jsonl
+ - wandb
+ vla:
+   base_vlm: prism-dinosiglip-224px+7b
+   data_mix: oxe_magic_soup_plus_minus
+   enable_gradient_checkpointing: true
+   enable_mixed_precision_training: true
+   epochs: 1000
+   expected_world_size: 64
+   freeze_vision_backbone: false
+   global_batch_size: 2048
+   learning_rate: 2.0e-05
+   lr_scheduler_type: constant
+   max_grad_norm: 1.0
+   max_steps: null
+   per_device_batch_size: 32
+   reduce_in_full_precision: true
+   shuffle_buffer_size: 256000
+   train_strategy: fsdp-full-shard
+   type: prism-dinosiglip-224px+mx-oxe-magic-soup-plus
+   vla_id: prism-dinosiglip-224px+mx-oxe-magic-soup-plus
+   warmup_ratio: 0.0
+   weight_decay: 0.0
+ wandb_entity: null
+ wandb_project: null
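The config (mirrored in `config.json` and `config.yml`) pins the parallelism arithmetic: with `expected_world_size: 64` GPUs and `per_device_batch_size: 32`, the `global_batch_size: 2048` is met exactly, implying no gradient accumulation. A quick sanity check; the accumulation formula is the usual convention, not something stated in the file itself:

```python
# Sanity-check the batch-size arithmetic implied by config.yml.
expected_world_size = 64     # GPUs participating in the FSDP run
per_device_batch_size = 32   # per-GPU micro-batch
global_batch_size = 2048     # target effective batch

# Standard convention: global = world_size * per_device * grad_accumulation.
grad_accumulation = global_batch_size // (expected_world_size * per_device_batch_size)
assert expected_world_size * per_device_batch_size * grad_accumulation == global_batch_size
print(grad_accumulation)  # -> 1: each optimizer step uses one micro-batch per GPU
```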
openvla-7b-prismatic/dataset_statistics.json ADDED
@@ -0,0 +1,3127 @@
+ {
+   "fractal20220817_data": {
+     "action": {
+       "mean": [0.006987582892179489, 0.006265917327255011, -0.01262515690177679, 0.04333311319351196, -0.005756212864071131, 0.0009130256366916001, 0.5354204773902893],
+       "std": [0.0692116990685463, 0.05970962345600128, 0.07353084534406662, 0.15610496699810028, 0.13164450228214264, 0.14593800902366638, 0.497110515832901],
+       "max": [2.9984593391418457, 22.09052848815918, 2.7507524490356445, 1.570636510848999, 1.5321086645126343, 1.5691522359848022, 1.0],
+       "min": [-2.0204520225524902, -5.497899532318115, -2.031663417816162, -1.569917917251587, -1.569892168045044, -1.570419430732727, 0.0],
+       "q01": [-0.22453527510166169, -0.14820013284683228, -0.231589707583189, -0.3517994859814644, -0.4193011274933815, -0.43643461108207704, 0.0],
+       "q99": [0.17824687153100965, 0.14938379630446405, 0.21842354819178575, 0.5892666035890578, 0.35272657424211445, 0.44796681255102094, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 3786400,
+     "num_trajectories": 87212
+   },
+   "kuka": {
+     "action": {
+       "mean": [-0.0004668905457947403, 0.00040138536132872105, -0.001280792523175478, 0.0, 0.0, -0.03722453489899635, 0.4131543040275574],
+       "std": [0.02083250693976879, 0.02915887162089348, 0.06422865390777588, 0.0, 0.0, 0.14224295318126678, 0.49086448550224304],
+       "max": [0.1697135865688324, 0.2777623236179352, 0.43710532784461975, 0.0, 0.0, 1.9684287309646606, 1.0],
+       "min": [-0.159867063164711, -0.2892282009124756, -0.2795473635196686, 0.0, 0.0, -1.9875637292861938, 0.0],
+       "q01": [-0.06619441494345665, -0.08713878810405731, -0.15083016991615295, 0.0, 0.0, -0.5415697038173676, 0.0],
+       "q99": [0.06601839080452929, 0.08732476785779003, 0.18168179214000715, 0.0, 0.0, 0.2923380345106127, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 2455879,
+     "num_trajectories": 209880
+   },
+   "bridge_orig": {
+     "action": {
+       "mean": [0.0002334194869035855, 0.00013004911306779832, -0.00012762474943883717, -0.0001556558854645118, -0.0004039328487124294, 0.00023557482927571982, 0.5764579176902771],
+       "std": [0.009765930473804474, 0.013689135201275349, 0.012667362578213215, 0.028534092009067535, 0.030637972056865692, 0.07691419124603271, 0.4973701536655426],
+       "max": [0.41691166162490845, 0.25864794850349426, 0.21218234300613403, 3.122201919555664, 1.8618112802505493, 6.280478477478027, 1.0],
+       "min": [-0.4007510244846344, -0.13874775171279907, -0.22553899884223938, -3.2010786533355713, -1.8618112802505493, -6.279075622558594, 0.0],
+       "q01": [-0.02872725307941437, -0.04170349963009357, -0.026093858778476715, -0.08092105075716972, -0.09288699507713317, -0.20718276381492615, 0.0],
+       "q99": [0.028309678435325586, 0.040855254605412394, 0.040161586627364146, 0.08192047759890528, 0.07792850524187081, 0.20382574498653397, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 2135463,
+     "num_trajectories": 60064
+   },
+   "taco_play": {
+     "action": {
+       "mean": [-0.003845922416076064, 0.009671456180512905, 0.012780580669641495, -0.005403771996498108, -0.009606587700545788, -0.002480733208358288, 0.4263913035392761],
+       "std": [0.23254038393497467, 0.36298269033432007, 0.28692901134490967, 0.2617705166339874, 0.2438892275094986, 0.5216503143310547, 0.4946896731853485],
+       "max": [1.4915844202041626, 2.1842432022094727, 2.6836395263671875, 5.035226821899414, 2.665864944458008, 4.250768661499023, 1.0],
+       "min": [-4.242457866668701, -3.192805051803589, -1.3371467590332031, -4.202683448791504, -2.6722638607025146, -3.3467135429382324, 0.0],
+       "q01": [-0.7106140398979186, -1.056944659948349, -0.5878450274467468, -0.7682853937149048, -0.7180147767066956, -1.5527938604354858, 0.0],
+       "q99": [0.6482916426658629, 1.0051310062408447, 0.9480248689651489, 0.6926478147506714, 0.6351067513227462, 1.628010264635086, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 237798,
+     "num_trajectories": 3603
+   },
+   "jaco_play": {
+     "action": {
+       "mean": [0.0009658430935814977, -0.00580078037455678, -0.00395062193274498, 0.0, 0.0, 0.0, 0.34934908151626587],
+       "std": [0.12235074490308762, 0.09678777307271957, 0.11155334860086441, 0.0, 0.0, 0.0, 0.4768252968788147],
+       "max": [0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.0, 0.0, 0.0, 1.0],
+       "min": [-0.20000000298023224, -0.20000000298023224, -0.20000000298023224, 0.0, 0.0, 0.0, 0.0],
+       "q01": [-0.20000000298023224, -0.20000000298023224, -0.20000000298023224, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.0, 0.0, 0.0, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 77965,
+     "num_trajectories": 1085
+   },
+   "berkeley_cable_routing": {
+     "action": {
+       "mean": [-0.07139874249696732, 0.023609008640050888, 0.10241943597793579, 0.0, 0.0, 0.049671024084091187, 0.0],
+       "std": [0.1815500408411026, 0.1810990273952484, 0.21220779418945312, 0.0, 0.0, 0.3475511968135834, 0.0],
+       "max": [0.9633283019065857, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
+       "min": [-0.9809081554412842, -0.9554349184036255, -0.9994775056838989, 0.0, 0.0, -1.0, 0.0],
+       "q01": [-0.5534318816661835, -0.4797285574674606, -0.5314934802055359, 0.0, 0.0, -0.8855219376087189, 0.0],
+       "q99": [0.42652835428714786, 0.5000944086909298, 0.639823433756829, 0.0, 0.0, 0.984243879914284, 0.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 42328,
+     "num_trajectories": 1647
+   },
+   "roboturk": {
+     "action": {
+       "mean": [0.0014448732836171985, -0.0015945249469950795, -0.0011753785656765103, 0.0023012510500848293, -0.0009382463176734746, -0.00011485807772260159, 0.5746025443077087],
+       "std": [0.04935386776924133, 0.0635455846786499, 0.061164740473032, 0.09553450345993042, 0.08420111238956451, 0.06517903506755829, 0.49452081322669983],
+       "max": [0.39124172925949097, 0.4601028263568878, 0.4870833456516266, 1.816888689994812, 1.8240282535552979, 1.4824820756912231, 1.0],
+       "min": [-0.6546999216079712, -0.6365841031074524, -0.4217723608016968, -1.6695482730865479, -1.8023357391357422, -1.4630827903747559, 0.0],
+       "q01": [-0.1342635464668274, -0.19996687173843383, -0.1482972100377083, -0.20720748245716095, -0.09676413893699647, -0.18075634717941286, 0.0],
+       "q99": [0.14956976801157001, 0.1805950567126275, 0.18841815620660796, 0.21615413755178453, 0.09457383215427405, 0.18543301910162005, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 187507,
+     "num_trajectories": 1995
+   },
+   "viola": {
+     "action": {
+       "mean": [0.04761844128370285, -0.029204415157437325, 0.05586736649274826, -0.002618510741740465, 0.006867344491183758, -0.01682133786380291, 0.7323777675628662],
+       "std": [0.39157867431640625, 0.4076525568962097, 0.40077948570251465, 0.10023996233940125, 0.0844319611787796, 0.10375042259693146, 0.44260647892951965],
+       "max": [1.0, 1.0, 1.0, 0.375, 0.36321428418159485, 0.375, 1.0],
+       "min": [-1.0, -1.0, -1.0, -0.375, -0.375, -0.375, 0.0],
+       "q01": [-0.9628571271896362, -1.0, -1.0, -0.26249998807907104, -0.21321429312229156, -0.3385714292526245, 0.0],
+       "q99": [0.9114285707473755, 0.868571400642395, 1.0, 0.2817857265472412, 0.2239285707473755, 0.3557142913341522, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 76324,
+     "num_trajectories": 150
+   },
+   "berkeley_autolab_ur5": {
+     "action": {
+       "mean": [0.0005683620693162084, 0.001217700308188796, -0.0005296372692100704, 0.00021029810886830091, 6.0695128922816366e-05, 0.001204986940138042, 0.6298308372497559],
+       "std": [0.0115329809486866, 0.007990492507815361, 0.009577835910022259, 0.009432995691895485, 0.016427582129836082, 0.011053967289626598, 0.48267969489097595],
+       "max": [0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.06666667014360428, 0.06666667014360428, 0.06666667014360428, 1.0],
+       "min": [-0.019999999552965164, -0.019999999552965164, -0.019999999552965164, -0.06666667014360428, -0.06666667014360428, -0.06666667014360428, 0.0],
+       "q01": [-0.019999999552965164, -0.019999999552965164, -0.019999999552965164, -0.02628571353852749, -0.06666667014360428, -0.03847619146108627, 0.0],
+       "q99": [0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.031809523701667786, 0.06666667014360428, 0.036571428179740906, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 97939,
+     "num_trajectories": 1000
+   },
+   "toto": {
+     "action": {
+       "mean": [0.38542115688323975, 0.007769413758069277, 0.3632740378379822, -0.6652036905288696, 0.1890396922826767, 0.03298724442720413, 0.0],
+       "std": [0.12211652100086212, 0.19378550350666046, 0.10178236663341522, 0.5725259184837341, 0.29884573817253113, 0.3259911835193634, 0.0],
+       "max": [0.6839867234230042, 0.4454185664653778, 0.7984078526496887, 2.120781660079956, 1.371164321899414, 1.4118704795837402, 0.0],
+       "min": [0.09922284632921219, -0.5180193781852722, 0.13791072368621826, -2.635117530822754, -1.0734480619430542, -1.9282547235488892, 0.0],
+       "q01": [0.1756722891330719, -0.3077590811252594, 0.235383919775486, -2.0908505964279174, -0.6191593289375306, -0.7488683319091797, 0.0],
+       "q99": [0.6136963081359863, 0.33704194784164443, 0.6681221985816956, 0.7422861719131538, 0.7955395007133507, 0.740464625358582, 0.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 325699,
+     "num_trajectories": 1003
+   },
+   "stanford_hydra_dataset_converted_externally_to_rlds": {
+     "action": {
+       "mean": [0.0007790001109242439, 0.00013707754260394722, -0.0002548607881180942, 0.0012903271708637476, -0.004751681815832853, 0.002692886395379901, 0.48855218291282654],
+       "std": [0.008022161200642586, 0.009131459519267082, 0.009574338793754578, 0.04122216999530792, 0.0384303517639637, 0.04606688767671585, 0.49976691603660583],
+       "max": [0.02499854564666748, 0.02499903365969658, 0.024999922141432762, 0.24974457919597626, 0.24997030198574066, 0.24999946355819702, 1.0],
+       "min": [-0.024999044835567474, -0.024999700486660004, -0.02499929815530777, -0.24993225932121277, -0.2499666064977646, -0.2499932497739792, 0.0],
+       "q01": [-0.019992006458342076, -0.02415412735193968, -0.022941758055239916, -0.11085530579090118, -0.12024572037160397, -0.13314770206809043, 0.0],
+       "q99": [0.022886231057345868, 0.022358838934451335, 0.02410089675337076, 0.12370114490389822, 0.11323311634361738, 0.18474749639630164, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 358234,
+     "num_trajectories": 570
+   },
+   "austin_buds_dataset_converted_externally_to_rlds": {
+     "action": {
+       "mean": [-0.07678354531526566, 0.0036849044263362885, 0.05644911900162697, 0.0, 0.0, 0.0, 0.3510494828224182],
+       "std": [0.6367740631103516, 0.37889179587364197, 0.47796326875686646, 0.0, 0.0, 0.0, 0.47721168398857117],
+       "max": [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0],
+       "min": [-1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [-1.0, -0.9599999785423279, -0.8714285492897034, 0.0, 0.0, 0.0, 0.0],
+       "q99": [1.0, 0.8600000143051147, 1.0, 0.0, 0.0, 0.0, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 34112,
+     "num_trajectories": 50
+   },
+   "nyu_franka_play_dataset_converted_externally_to_rlds": {
+     "action": {
+       "mean": [0.001021989737637341, -0.00012002651783404872, 0.00032894269679673016, 0.0015034361276775599, -0.002198522910475731, -0.001663230243138969, 0.7230083346366882],
+       "std": [0.01327415369451046, 0.013215910643339157, 0.012822109274566174, 0.2732451558113098, 0.057022541761398315, 0.039172880351543427, 0.44752755761146545],
+       "max": [0.06424188613891602, 0.07027634978294373, 0.06129661202430725, 6.281067848205566, 0.1967729926109314, 0.26377415657043457, 1.0],
+       "min": [-0.05952230095863342, -0.07232445478439331, -0.06730806827545166, -6.278434753417969, -0.21479034423828125, -0.3627619743347168, 0.0],
+       "q01": [-0.03199600875377655, -0.032861671447753905, -0.03368805110454559, -0.12080862045288086, -0.12175218224525451, -0.11370223641395569, 0.0],
+       "q99": [0.03101520001888276, 0.0373908892273903, 0.03646374464035038, 0.11764093399047852, 0.1258920183777809, 0.09366151213645942, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 44875,
+     "num_trajectories": 456
+   },
+   "furniture_bench_dataset_converted_externally_to_rlds": {
+     "action": {
+       "mean": [0.00014610752987209707, 0.0010830952087417245, 0.0006224989192560315, -0.003303206292912364, -0.0026880695950239897, 0.018242603167891502, 0.48854944109916687],
+       "std": [0.01610708422958851, 0.014891477301716805, 0.014014219865202904, 0.058274295181035995, 0.11417088657617569, 0.33479776978492737, 0.49991825222969055],
+       "max": [0.10000000149011612, 0.10000000149011612, 0.10000000149011612, 0.8651833534240723, 1.0909736156463623, 2.863185405731201, 1.0],
+       "min": [-0.10495579987764359, -0.10939455777406693, -0.10000000149011612, -0.971906840801239, -1.0475432872772217, -3.06000018119812, 0.0],
+       "q01": [-0.053988199681043625, -0.05049169331789017, -0.032499241530895236, -0.1953887003660202, -0.41674559473991396, -0.8886768388748169, 0.0],
+       "q99": [0.05414841488003723, 0.04965164884924884, 0.060055799782276154, 0.18231668293476103, 0.39867786407470646, 0.8772023963928218, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 3948057,
+     "num_trajectories": 5100
+   },
+   "ucsd_kitchen_dataset_converted_externally_to_rlds": {
+     "action": {
+       "mean": [410.37567138671875, 116.9518814086914, 192.35032653808594, -121.22441864013672, -33.84893035888672, 50.016136169433594, 0.741813600063324],
+       "std": [122.81494903564453, 108.8009033203125, 130.303466796875, 116.28205108642578, 27.621843338012695, 41.02094650268555, 0.43763357400894165],
+       "max": [678.0, 400.0, 507.0, 180.00001525878906, 6.000013828277588, 116.99998474121094, 1.0],
+       "min": [172.0, -166.0, -99.99999237060547, -180.00001525878906, -89.0, -96.00010681152344, 0.0],
+       "q01": [200.00001052856445, -102.31004211425781, -94.99993370056153, -180.00001525878906, -88.00001525878906, -38.999977111816406, 0.0],
+       "q99": [637.0, 368.30999999999995, 493.0, 180.00001525878906, 0.999983012676239, 105.00001525878906, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 3970,
+     "num_trajectories": 150
+   },
+   "austin_sailor_dataset_converted_externally_to_rlds": {
+     "action": {
+       "mean": [0.011825348250567913, 0.006461074110120535, 0.06023626774549484, 0.0, 0.0, 0.0016465914668515325, 0.5260950326919556],
+       "std": [0.46348899602890015, 0.41240179538726807, 0.411862850189209, 0.0, 0.0, 0.0578610822558403, 0.49894046783447266],
+       "max": [1.0, 1.0, 1.0, 0.0, 0.0, 0.375, 1.0],
+       "min": [-1.0, -1.0, -1.0, 0.0, 0.0, -0.375, 0.0],
+       "q01": [-1.0, -0.9828571677207947, -0.6000000238418579, 0.0, 0.0, -0.17249999940395355, 0.0],
+       "q99": [1.0, 0.9457142949104309, 1.0, 0.0, 0.0, 0.17892856895923615, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 353094,
+     "num_trajectories": 240
+   },
+   "austin_sirius_dataset_converted_externally_to_rlds": {
+     "action": {
+       "mean": [0.07747682929039001, 0.03195561468601227, 0.04244732856750488, 0.0, 0.0, -0.01603456400334835, 0.43260177969932556],
+       "std": [0.3906329572200775, 0.2998155355453491, 0.2782271206378937, 0.0, 0.0, 0.08120622485876083, 0.49528297781944275],
+       "max": [1.0002285242080688, 0.960608720779419, 1.105179786682129, 0.0, 0.0, 0.341785728931427, 1.0],
+       "min": [-1.0183025598526, -0.9800000190734863, -0.9774575233459473, 0.0, 0.0, -0.34607142210006714, 0.0],
+       "q01": [-0.780905865430832, -0.5667179036140442, -0.5254343223571777, 0.0, 0.0, -0.28495091378688814, 0.0],
+       "q99": [0.9569637751579284, 0.6971374487876891, 0.8124888157844541, 0.0, 0.0, 0.1971428543329239, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 279939,
+     "num_trajectories": 559
+   },
+   "dlr_edan_shared_control_converted_externally_to_rlds": {
+     "action": {
+       "mean": [0.006647810339927673, -0.0007657372043468058, 0.006522852927446365, 0.0011679717572405934, -0.006395625416189432, -0.011902998201549053, 0.6985887289047241],
+       "std": [0.021393608301877975, 0.01814231649041176, 0.03374375030398369, 0.01743541844189167, 0.03394376486539841, 0.04641875624656677, 0.4588589072227478],
+       "max": [0.18991442024707794, 0.0739002525806427, 0.18064819276332855, 0.0866486132144928, 0.13464981317520142, 0.16910280287265778, 1.0],
+       "min": [-0.10054297000169754, -0.08427435159683228, -0.13533438742160797, -0.17556548118591309, -0.18485672771930695, -0.2680685818195343, 0.0],
+       "q01": [-0.02987122368067503, -0.06013262912631035, -0.08286409199237824, -0.05924444157630205, -0.15986866518855095, -0.15636983573436739, 0.0],
+       "q99": [0.08832092039287087, 0.042126184627413736, 0.11311905644834042, 0.0643695573508739, 0.03941855944693088, 0.156646853685379, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 8928,
+     "num_trajectories": 104
+   },
+   "iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
+     "action": {
+       "mean": [0.5274372696876526, 0.02858201041817665, 0.18712575733661652, 1.2339589595794678, 0.03226623684167862, -1.4199490547180176, 0.5550631880760193],
+       "std": [0.08108345419168472, 0.1116757020354271, 0.07747554779052734, 2.8737246990203857, 0.02774704433977604, 2.7678682804107666, 0.49695101380348206],
+       "max": [0.6634981632232666, 0.23428471386432648, 0.4308285415172577, 3.1415927410125732, 0.13647015392780304, 3.141592502593994, 1.0],
+       "min": [0.3071657121181488, -0.29754969477653503, 0.06578229367733002, -3.1415927410125732, -0.04584203287959099, -3.141592502593994, 0.0],
+       "q01": [0.3148897051811218, -0.20317550599575043, 0.06785467118024827, -3.140952730178833, -0.029743434861302376, -3.141091251373291, 0.0],
+       "q99": [0.6472805738449097, 0.20846802592277527, 0.36855655312538155, 3.1409926891326903, 0.11424950212240226, 3.1410969257354737, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 146241,
+     "num_trajectories": 631
+   },
+   "utaustin_mutex": {
+     "action": {
+       "mean": [0.06176406890153885, -0.005005486309528351, 0.10216785222291946, -0.03314131125807762, 0.013895004987716675, -0.011317633092403412, 0.5038976669311523],
+       "std": [0.1875014752149582, 0.4468473494052887, 0.3792876601219177, 0.14097853004932404, 0.06453701853752136, 0.11765272170305252, 0.501045286655426],
+       "max": [1.0, 1.0, 1.0, 0.375, 0.375, 0.375, 1.0],
+       "min": [-1.0, -1.0, -1.0, -0.375, -0.375, -0.375, 0.0],
+       "q01": [-0.4285714328289032, -0.9800000190734863, -0.5571428537368774, -0.375, -0.15642857551574707, -0.335357129573822, 0.0],
+       "q99": [0.5914285778999329, 0.9714285731315613, 1.0, 0.3278571367263794, 0.207857146859169, 0.25607141852378845, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 361883,
+     "num_trajectories": 1500
+   },
+   "berkeley_fanuc_manipulation": {
+     "action": {
+       "mean": [0.0007744057802483439, -0.00031240080716088414, -0.0015001941937953234, -0.0007515158504247665, -0.00015832878125365824, 0.00014327642566058785, 0.699295699596405],
+       "std": [0.0034070091787725687, 0.0049921851605176926, 0.005344334989786148, 0.00759894959628582, 0.004081866703927517, 0.008568956516683102, 0.4586937427520752],
+       "max": [0.009999999776482582, 0.009999999776482582, 0.009999999776482582, 0.03490658476948738, 0.03490658476948738, 0.03490658476948738, 1.0],
+       "min": [-0.009999999776482582, -0.009999999776482582, -0.009999999776482582, -0.03490658476948738, -0.03490658476948738, -0.03490658476948738, 0.0],
+       "q01": [-0.009999999776482582, -0.009999999776482582, -0.009999999776482582, -0.03490658476948738, 0.0, -0.03490658476948738, 0.0],
+       "q99": [0.009999999776482582, 0.009999999776482582, 0.009999999776482582, 0.03490658476948738, 0.0, 0.03490658476948738, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 62613,
+     "num_trajectories": 415
+   },
+   "cmu_stretch": {
+     "action": {
+       "mean": [0.00036304505192674696, 0.0, 0.0016466958913952112, 0.0, 0.0, 0.0, 0.3987048268318176],
+       "std": [0.004081828519701958, 0.0, 0.0037743328139185905, 0.0, 0.0, 0.0, 0.48963725566864014],
+       "max": [0.02338407188653946, 0.0, 0.023404927924275398, 0.0, 0.0, 0.0, 1.0],
+       "min": [-0.019353797659277916, 0.0, -0.02019215188920498, 0.0, 0.0, 0.0, 0.0],
+       "q01": [-0.011175686959177256, 0.0, -0.0032206363626755773, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.014501785952597848, 0.0, 0.015056106168776728, 0.0, 0.0, 0.0, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 25016,
+     "num_trajectories": 135
+   },
+   "bc_z": {
+     "action": {
+       "mean": [-0.009958467446267605, 0.0008958321413956583, 0.004995597992092371, 0.00029755113064311445, -0.008735382929444313, -0.030693737789988518, 0.8344562649726868],
+       "std": [0.03053455986082554, 0.0231423731893301, 0.020641816779971123, 0.04155943542718887, 0.046427831053733826, 0.0769818127155304, 0.3610210120677948],
+       "max": [0.2165454924106598, 0.1251407265663147, 0.10772687941789627, 0.33544227480888367, 0.28117990493774414, 0.40614867210388184, 1.0],
+       "min": [-0.1677047461271286, -0.14630407094955444, -0.10066790133714676, -0.29421567916870117, -0.32101404666900635, -0.4635624885559082, 0.0],
+       "q01": [-0.09220654994249344, -0.06456145539879798, -0.049121275544166565, -0.11594625547528267, -0.14152548640966414, -0.2251061636209488, 0.0],
+       "q99": [0.07628866866230968, 0.058019736707210584, 0.052540797740221024, 0.11740604028105736, 0.11703975558280955, 0.16729306846857078, 1.0],
+       "mask": [true, true, true, true, true, true, false]
+     },
+     "proprio": {
+       "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+     },
+     "num_transitions": 6015535,
+     "num_trajectories": 43264
2876
+ },
2877
+ "fmb_dataset": {
2878
+ "action": {
2879
+ "mean": [
2880
+ 0.059029702097177505,
2881
+ -0.06476633995771408,
2882
+ -0.09787475317716599,
2883
+ 0.004325388930737972,
2884
+ 0.00028963794466108084,
2885
+ -0.04457257315516472,
2886
+ 0.7336440086364746
2887
+ ],
2888
+ "std": [
2889
+ 0.28809213638305664,
2890
+ 0.2820415794849396,
2891
+ 0.4626740515232086,
2892
+ 0.3266514539718628,
2893
+ 0.10842999070882797,
2894
+ 0.3440099358558655,
2895
+ 0.4435282051563263
2896
+ ],
2897
+ "max": [
2898
+ 1.399999976158142,
2899
+ 1.0,
2900
+ 1.399999976158142,
2901
+ 1.0,
2902
+ 1.0,
2903
+ 1.0,
2904
+ 1.0
2905
+ ],
2906
+ "min": [
2907
+ -1.399999976158142,
2908
+ -1.399999976158142,
2909
+ -1.0,
2910
+ -1.0,
2911
+ -1.0,
2912
+ -1.0,
2913
+ 0.0
2914
+ ],
2915
+ "q01": [
2916
+ -0.8257142901420593,
2917
+ -1.399999976158142,
2918
+ -1.0,
2919
+ -1.0,
2920
+ -0.3028571307659149,
2921
+ -1.0,
2922
+ 0.0
2923
+ ],
2924
+ "q99": [
2925
+ 1.0,
2926
+ 0.5257142782211304,
2927
+ 1.0,
2928
+ 1.0,
2929
+ 0.3400000035762787,
2930
+ 1.0,
2931
+ 1.0
2932
+ ],
2933
+ "mask": [
2934
+ true,
2935
+ true,
2936
+ true,
2937
+ true,
2938
+ true,
2939
+ true,
2940
+ false
2941
+ ]
2942
+ },
2943
+ "proprio": {
2944
+ "mean": [
2945
+ 0.0,
2946
+ 0.0,
2947
+ 0.0,
2948
+ 0.0,
2949
+ 0.0,
2950
+ 0.0,
2951
+ 0.0
2952
+ ],
2953
+ "std": [
2954
+ 0.0,
2955
+ 0.0,
2956
+ 0.0,
2957
+ 0.0,
2958
+ 0.0,
2959
+ 0.0,
2960
+ 0.0
2961
+ ],
2962
+ "max": [
2963
+ 0.0,
2964
+ 0.0,
2965
+ 0.0,
2966
+ 0.0,
2967
+ 0.0,
2968
+ 0.0,
2969
+ 0.0
2970
+ ],
2971
+ "min": [
2972
+ 0.0,
2973
+ 0.0,
2974
+ 0.0,
2975
+ 0.0,
2976
+ 0.0,
2977
+ 0.0,
2978
+ 0.0
2979
+ ],
2980
+ "q01": [
2981
+ 0.0,
2982
+ 0.0,
2983
+ 0.0,
2984
+ 0.0,
2985
+ 0.0,
2986
+ 0.0,
2987
+ 0.0
2988
+ ],
2989
+ "q99": [
2990
+ 0.0,
2991
+ 0.0,
2992
+ 0.0,
2993
+ 0.0,
2994
+ 0.0,
2995
+ 0.0,
2996
+ 0.0
2997
+ ]
2998
+ },
2999
+ "num_transitions": 1137459,
3000
+ "num_trajectories": 8612
3001
+ },
3002
+ "dobbe": {
3003
+ "action": {
3004
+ "mean": [
3005
+ -0.0001120665911003016,
3006
+ 0.0011229600058868527,
3007
+ -0.00010194431524723768,
3008
+ -7.371398532995954e-05,
3009
+ -0.00067531579406932,
3010
+ -5.6643435527803376e-05,
3011
+ 0.6318281888961792
3012
+ ],
3013
+ "std": [
3014
+ 0.04264938458800316,
3015
+ 0.04428559169173241,
3016
+ 0.12224084138870239,
3017
+ 0.005388413090258837,
3018
+ 0.011246449314057827,
3019
+ 0.006287882570177317,
3020
+ 0.39732322096824646
3021
+ ],
3022
+ "max": [
3023
+ 38.590423583984375,
3024
+ 17.932697296142578,
3025
+ 4.843764305114746,
3026
+ 1.4372116327285767,
3027
+ 0.4340403974056244,
3028
+ 1.2057193517684937,
3029
+ 0.9998947381973267
3030
+ ],
3031
+ "min": [
3032
+ -5.700923442840576,
3033
+ -21.605947494506836,
3034
+ -123.72489929199219,
3035
+ -1.7229845523834229,
3036
+ -0.4998578727245331,
3037
+ -0.8867913484573364,
3038
+ 1.4196479014572105e-06
3039
+ ],
3040
+ "q01": [
3041
+ -0.01119564864784479,
3042
+ -0.014266146533191203,
3043
+ -0.0071747214533388615,
3044
+ -0.009444301575422287,
3045
+ -0.03990109823644161,
3046
+ -0.017422311007976532,
3047
+ 4.003279136668425e-05
3048
+ ],
3049
+ "q99": [
3050
+ 0.01015154086053368,
3051
+ 0.017181577533483497,
3052
+ 0.007216989761218411,
3053
+ 0.010380979906767595,
3054
+ 0.03556173853576176,
3055
+ 0.018032474815845446,
3056
+ 0.9982578039169312
3057
+ ],
3058
+ "mask": [
3059
+ true,
3060
+ true,
3061
+ true,
3062
+ true,
3063
+ true,
3064
+ true,
3065
+ false
3066
+ ]
3067
+ },
3068
+ "proprio": {
3069
+ "mean": [
3070
+ 0.0,
3071
+ 0.0,
3072
+ 0.0,
3073
+ 0.0,
3074
+ 0.0,
3075
+ 0.0,
3076
+ 0.0
3077
+ ],
3078
+ "std": [
3079
+ 0.0,
3080
+ 0.0,
3081
+ 0.0,
3082
+ 0.0,
3083
+ 0.0,
3084
+ 0.0,
3085
+ 0.0
3086
+ ],
3087
+ "max": [
3088
+ 0.0,
3089
+ 0.0,
3090
+ 0.0,
3091
+ 0.0,
3092
+ 0.0,
3093
+ 0.0,
3094
+ 0.0
3095
+ ],
3096
+ "min": [
3097
+ 0.0,
3098
+ 0.0,
3099
+ 0.0,
3100
+ 0.0,
3101
+ 0.0,
3102
+ 0.0,
3103
+ 0.0
3104
+ ],
3105
+ "q01": [
3106
+ 0.0,
3107
+ 0.0,
3108
+ 0.0,
3109
+ 0.0,
3110
+ 0.0,
3111
+ 0.0,
3112
+ 0.0
3113
+ ],
3114
+ "q99": [
3115
+ 0.0,
3116
+ 0.0,
3117
+ 0.0,
3118
+ 0.0,
3119
+ 0.0,
3120
+ 0.0,
3121
+ 0.0
3122
+ ]
3123
+ },
3124
+ "num_transitions": 1139911,
3125
+ "num_trajectories": 5208
3126
+ }
3127
+ }
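
The `q01`/`q99` bounds and `mask` flags recorded in these statistics are what action un-normalization consumes: the model predicts each action dimension in [-1, 1], and every dimension with `mask = true` is mapped back to dataset units, while masked-out dimensions (here, the binary gripper) pass through unchanged. A minimal sketch of that mapping (a hypothetical `unnormalize` helper, not the repository's exact code):

```python
import numpy as np

def unnormalize(normalized_action, stats):
    """Map an action in [-1, 1] back to dataset units via the q01/q99 bounds."""
    low, high = np.asarray(stats["q01"]), np.asarray(stats["q99"])
    mask = np.asarray(stats["mask"])
    rescaled = 0.5 * (normalized_action + 1.0) * (high - low) + low
    # Dimensions with mask=False (e.g., the gripper) are returned as predicted.
    return np.where(mask, rescaled, normalized_action)
```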
openvla-7b-prismatic/run-metrics.jsonl ADDED
@@ -0,0 +1 @@
1
+ {"hparams": {"data_root_dir": "/scr/user/data", "hf_token": ".hf_token", "pretrained_checkpoint": "", "resume_epoch": null, "resume_step": null, "run_id": "prism-dinosiglip-224px+mx-oxe-magic-soup-plus+n8+b32+x7", "run_id_note": null, "run_root_dir": "./runs", "save_interval": 2500, "seed": 7, "stage": "vla-full-train", "trackers": ["jsonl", "wandb"], "vla": {"base_vlm": "prism-dinosiglip-224px+7b", "data_mix": "oxe_magic_soup_plus_minus", "enable_gradient_checkpointing": true, "enable_mixed_precision_training": true, "epochs": 1000, "expected_world_size": 64, "freeze_vision_backbone": false, "global_batch_size": 2048, "learning_rate": 2e-05, "lr_scheduler_type": "constant", "max_grad_norm": 1.0, "max_steps": null, "per_device_batch_size": 32, "reduce_in_full_precision": true, "shuffle_buffer_size": 256000, "train_strategy": "fsdp-full-shard", "type": "prism-dinosiglip-224px+mx-oxe-magic-soup-plus", "vla_id": "prism-dinosiglip-224px+mx-oxe-magic-soup-plus", "warmup_ratio": 0.0, "weight_decay": 0.0}, "wandb_entity": "", "wandb_project": ""}, "run_id": "prism-dinosiglip-224px+mx-oxe-magic-soup-plus+n8+b32+x7"}
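
The single JSONL record above carries the full training configuration for this run; a short sketch of reading it back (path assumed relative to the repository root):

```python
import json

# run-metrics.jsonl holds one JSON object per line; this run has exactly one record.
with open("openvla-7b-prismatic/run-metrics.jsonl") as f:
    run = json.loads(f.readline())

vla_hparams = run["hparams"]["vla"]
print(run["run_id"])                     # prism-dinosiglip-224px+mx-oxe-magic-soup-plus+n8+b32+x7
print(vla_hparams["global_batch_size"])  # 2048
print(vla_hparams["learning_rate"])      # 2e-05
```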
openvla-7b/.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
openvla-7b/README.md ADDED
@@ -0,0 +1,99 @@
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - robotics
5
+ - vla
6
+ - image-text-to-text
7
+ - multimodal
8
+ - pretraining
9
+ license: mit
10
+ language:
11
+ - en
12
+ pipeline_tag: image-text-to-text
13
+ ---
14
+
15
+ # OpenVLA 7B
16
+
17
+ OpenVLA 7B (`openvla-7b`) is an open vision-language-action model trained on 970K robot manipulation episodes from the [Open X-Embodiment](https://robotics-transformer-x.github.io/) dataset.
18
+ The model takes language instructions and camera images as input and generates robot actions. It supports controlling multiple robots out of the box and can be quickly adapted for new robot domains via (parameter-efficient) fine-tuning.
19
+
20
+ All OpenVLA checkpoints, as well as our [training codebase](https://github.com/openvla/openvla), are released under the MIT License.
21
+
22
+ For full details, please read [our paper](https://arxiv.org/abs/2406.09246) and see [our project page](https://openvla.github.io/).
23
+
24
+ ## Model Summary
25
+
26
+ - **Developed by:** The OpenVLA team, consisting of researchers from Stanford, UC Berkeley, Google DeepMind, and the Toyota Research Institute.
27
+ - **Model type:** Vision-language-action (language, image => robot actions)
28
+ - **Language(s) (NLP):** en
29
+ - **License:** MIT
30
+ - **Finetuned from:** [`prism-dinosiglip-224px`](https://github.com/TRI-ML/prismatic-vlms), a VLM trained from:
31
+ + **Vision Backbone**: DINOv2 ViT-L/14 and SigLIP ViT-So400M/14
32
+ + **Language Model**: Llama-2
33
+ - **Pretraining Dataset:** [Open X-Embodiment](https://robotics-transformer-x.github.io/) -- specific component datasets can be found [here](https://github.com/openvla/openvla).
34
+ - **Repository:** [https://github.com/openvla/openvla](https://github.com/openvla/openvla)
35
+ - **Paper:** [OpenVLA: An Open-Source Vision-Language-Action Model](https://arxiv.org/abs/2406.09246)
36
+ - **Project Page & Videos:** [https://openvla.github.io/](https://openvla.github.io/)
37
+
38
+ ## Uses
39
+
40
+ OpenVLA models take a language instruction and a camera image of a robot workspace as input, and predict (normalized) robot actions consisting of 7-DoF end-effector deltas
41
+ of the form (x, y, z, roll, pitch, yaw, gripper). To execute on an actual robot platform, actions need to be *un-normalized* using statistics computed on a per-robot,
42
+ per-dataset basis. See [our repository](https://github.com/openvla/openvla) for more information.
43
+
44
+ OpenVLA models can be used zero-shot to control robots for specific combinations of embodiments and domains seen in the Open-X pretraining mixture (e.g., for
45
+ [BridgeV2 environments with a Widow-X robot](https://rail-berkeley.github.io/bridgedata/)). They can also be efficiently *fine-tuned* for new tasks and robot setups
46
+ given minimal demonstration data; [see here](https://github.com/openvla/openvla/blob/main/scripts/finetune.py).
47
+
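
As a rough illustration of the LoRA route (a sketch assuming the `peft` library; the rank/alpha values below are placeholders, and the supported recipe lives in the fine-tuning script linked above):

```python
from peft import LoraConfig, get_peft_model

# Placeholder hyperparameters -- consult the OpenVLA fine-tuning script for real settings.
lora_config = LoraConfig(r=32, lora_alpha=16, lora_dropout=0.0, target_modules="all-linear")
vla = get_peft_model(vla, lora_config)  # `vla` loaded as in the Getting Started snippet below
vla.print_trainable_parameters()
```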
48
+ **Out-of-Scope:** OpenVLA models do not generalize zero-shot to new (unseen) robot embodiments or to setups that are not represented in the pretraining mix; in these cases,
49
+ we suggest collecting a dataset of demonstrations on the desired setup and fine-tuning OpenVLA models instead.
50
+
51
+ ## Getting Started
52
+
53
+ Out of the box, OpenVLA 7B can control multiple robots for the domains represented in the pretraining mixture. For example,
54
+ here is how to load `openvla-7b` for zero-shot instruction following in the [BridgeV2 environments](https://rail-berkeley.github.io/bridgedata/) with a Widow-X robot:
55
+
56
+ ```python
57
+ # Install minimal dependencies (`torch`, `transformers`, `timm`, `tokenizers`, ...)
58
+ # > pip install -r https://raw.githubusercontent.com/openvla/openvla/main/requirements-min.txt
59
+ from transformers import AutoModelForVision2Seq, AutoProcessor
60
+ from PIL import Image
61
+
62
+ import torch
63
+
64
+ # Load Processor & VLA
65
+ processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
66
+ vla = AutoModelForVision2Seq.from_pretrained(
67
+ "openvla/openvla-7b",
68
+ attn_implementation="flash_attention_2", # [Optional] Requires `flash_attn`
69
+ torch_dtype=torch.bfloat16,
70
+ low_cpu_mem_usage=True,
71
+ trust_remote_code=True
72
+ ).to("cuda:0")
73
+
74
+ # Grab image input & format prompt
75
+ image: Image.Image = get_from_camera(...)
76
+ prompt = "In: What action should the robot take to {<INSTRUCTION>}?\nOut:"
77
+
78
+ # Predict Action (7-DoF; un-normalize for BridgeV2)
79
+ inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
80
+ action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
81
+
82
+ # Execute...
83
+ robot.act(action, ...)
84
+ ```
85
+
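
Note that `unnorm_key="bridge_orig"` selects which per-dataset normalization statistics (stored as `norm_stats` in the model config) are used to map the predicted action back into robot units; pick the key matching the robot/dataset you are deploying on.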
86
+ For more examples, including scripts for fine-tuning OpenVLA models on your own robot demonstration datasets, see [our training repository](https://github.com/openvla/openvla).
87
+
88
+ ## Citation
89
+
90
+ **BibTeX:**
91
+
92
+ ```bibtex
93
+ @article{kim24openvla,
94
+ title={OpenVLA: An Open-Source Vision-Language-Action Model},
95
+ author={{Moo Jin} Kim and Karl Pertsch and Siddharth Karamcheti and Ted Xiao and Ashwin Balakrishna and Suraj Nair and Rafael Rafailov and Ethan Foster and Grace Lam and Pannag Sanketi and Quan Vuong and Thomas Kollar and Benjamin Burchfiel and Russ Tedrake and Dorsa Sadigh and Sergey Levine and Percy Liang and Chelsea Finn},
96
+ journal = {arXiv preprint arXiv:2406.09246},
97
+ year={2024}
98
+ }
99
+ ```
openvla-7b/added_tokens.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "<PAD>": 32000
3
+ }
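
(This maps the special `<PAD>` token to ID 32000, consistent with the `pad_token_id` defaults in `configuration_prismatic.py` and `generation_config.json` below.)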
openvla-7b/configuration_prismatic.py ADDED
@@ -0,0 +1,140 @@
1
+ """
2
+ configuration_prismatic.py
3
+
4
+ HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`.
5
+ Default configuration specifies `siglip-224px+7b`.
6
+ """
7
+
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ from transformers import PretrainedConfig
11
+ from transformers.models.auto import CONFIG_MAPPING
12
+
13
+ # === Utilities for Mapping Prismatic names to HF names ===
14
+ # fmt: off
15
+ VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = {
16
+ "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224],
17
+
18
+ "clip-vit-l-336px": [336],
19
+ "siglip-vit-so400m-384px": [384],
20
+
21
+ "dinoclip-vit-l-336px": [336, 336],
22
+ "dinosiglip-vit-so-224px": [224, 224],
23
+ "dinosiglip-vit-so-384px": [384, 384],
24
+ }
25
+ VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = {
26
+ "clip-vit-l": ["vit_large_patch14_clip_224.openai"],
27
+ "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"],
28
+
29
+ "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"],
30
+ "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"],
31
+
32
+ "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"],
33
+ "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"],
34
+
35
+ "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"],
36
+ "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"],
37
+ "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"],
38
+ }
39
+ TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = {
40
+ "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"],
41
+ "dinov2-vit-l": [None], "in1k-vit-l": [None],
42
+ "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None],
43
+ "dinoclip-vit-l-336px": [None, "quick_gelu"],
44
+ "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None]
45
+ }
46
+
47
+ LLM_BACKBONE_TO_HF_PATH = {
48
+ "llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf",
49
+ "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
50
+
51
+ "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5",
52
+
53
+ "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1",
54
+ "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1",
55
+
56
+ "phi-2-3b": "microsoft/phi-2",
57
+ }
58
+ LLM_BACKBONE_TO_HF_METACLASS = {
59
+ "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama",
60
+ "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama",
61
+
62
+ "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral",
63
+
64
+ "phi-2-3b": "phi",
65
+ }
66
+
67
+ VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys())
68
+ VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH)
69
+ # fmt: on
70
+
71
+
72
+ class PrismaticConfig(PretrainedConfig):
73
+ model_type: str = "prismatic"
74
+ is_composition: bool = False
75
+
76
+ def __init__(
77
+ self,
78
+ vision_backbone_id: str = "siglip-vit-so400m",
79
+ llm_backbone_id: str = "vicuna-v15-7b",
80
+ arch_specifier: str = "no-align+gelu-mlp",
81
+ use_fused_vision_backbone: Optional[bool] = None,
82
+ image_resize_strategy: str = "letterbox",
83
+ text_config: Optional[Dict[str, Any]] = None,
84
+ llm_max_length: int = 2048,
85
+ pad_token_id: int = 32000,
86
+ pad_to_multiple_of: int = 64,
87
+ output_projector_states: bool = False,
88
+ **kwargs: str,
89
+ ) -> None:
90
+ if vision_backbone_id not in VALID_VISION_BACKBONES:
91
+ raise ValueError(f"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }")
92
+
93
+ if llm_backbone_id not in VALID_LLM_BACKBONES:
94
+ raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }")
95
+
96
+ # Set Prismatic Configuration Fields
97
+ self.vision_backbone_id = vision_backbone_id
98
+ self.llm_backbone_id = llm_backbone_id
99
+ self.arch_specifier = arch_specifier
100
+ self.output_projector_states = output_projector_states
101
+
102
+ # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing
103
+ self.use_fused_vision_backbone = (
104
+ use_fused_vision_backbone
105
+ if use_fused_vision_backbone is not None
106
+ else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"])
107
+ )
108
+
109
+ self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id]
110
+ self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id]
111
+ self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id]
112
+ self.image_resize_strategy = image_resize_strategy
113
+
114
+ self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id]
115
+ self.llm_max_length = llm_max_length
116
+ self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of
117
+
118
+ # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming!
119
+ self.text_config = (
120
+ CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config)
121
+ if text_config is not None
122
+ else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]()
123
+ )
124
+
125
+ # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well...
126
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
127
+
128
+
129
+ class OpenVLAConfig(PrismaticConfig):
130
+ model_type: str = "openvla"
131
+
132
+ def __init__(
133
+ self,
134
+ norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None,
135
+ n_action_bins: int = 256,
136
+ **kwargs: str,
137
+ ) -> None:
138
+ self.norm_stats, self.n_action_bins = norm_stats, n_action_bins
139
+
140
+ super().__init__(**kwargs)
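
As a quick sketch of the contract above (assuming `transformers` is installed and the classes in this file are importable):

```python
# Valid backbone IDs populate the derived timm/HF fields; invalid IDs raise ValueError.
config = OpenVLAConfig(
    vision_backbone_id="dinosiglip-vit-so-224px",
    llm_backbone_id="llama2-7b-pure",
)
print(config.use_fused_vision_backbone)  # True -- inferred from the "dinosiglip" prefix
print(config.timm_model_ids)             # two timm IDs, one per fused backbone
print(config.hf_llm_id)                  # meta-llama/Llama-2-7b-hf
```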
openvla-7b/generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 32000,
6
+ "transformers_version": "4.40.1"
7
+ }
openvla-7b/model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10d8636256018712c5e5c823d12e22b5797f99bb721bd123bf6bf2379892be85
3
+ size 6948961960
openvla-7b/model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2050b14f21d48904d269f48d5a980fecea87cd7b36641d9b0f015e72d1fe216a
3
+ size 6971232040
openvla-7b/model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea65305a1577f36f721965bf84c8caec0a948ce7ce84d754701637376c531fef
3
+ size 1162406824
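
The three `.safetensors` entries above are Git LFS pointer files, not the weights themselves: each records the `sha256` digest and byte size of the real shard. A downloaded shard can be checked against its pointer with standard-library tools, e.g. this hypothetical verifier:

```python
import hashlib
import os

def verify_shard(path: str, expected_sha256: str, expected_size: int) -> bool:
    """Return True iff the file matches the LFS pointer's digest and size."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Hash in 1 MiB chunks to avoid loading multi-GB shards into memory.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_sha256 and os.path.getsize(path) == expected_size
```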
openvla-7b/model.safetensors.index.json ADDED
@@ -0,0 +1,989 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 15082474368
4
+ },
5
+ "weight_map": {
6
+ "language_model.lm_head.weight": "model-00003-of-00003.safetensors",
7
+ "language_model.model.embed_tokens.weight": "model-00001-of-00003.safetensors",
8
+ "language_model.model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
9
+ "language_model.model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
10
+ "language_model.model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
11
+ "language_model.model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
12
+ "language_model.model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
13
+ "language_model.model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
14
+ "language_model.model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
15
+ "language_model.model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
16
+ "language_model.model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
17
+ "language_model.model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
18
+ "language_model.model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
19
+ "language_model.model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
20
+ "language_model.model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
21
+ "language_model.model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
22
+ "language_model.model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
23
+ "language_model.model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
24
+ "language_model.model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
25
+ "language_model.model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
26
+ "language_model.model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
27
+ "language_model.model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
28
+ "language_model.model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
29
+ "language_model.model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
30
+ "language_model.model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
31
+ "language_model.model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
32
+ "language_model.model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
33
+ "language_model.model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
34
+ "language_model.model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
35
+ "language_model.model.layers.11.input_layernorm.weight": "model-00001-of-00003.safetensors",
36
+ "language_model.model.layers.11.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
37
+ "language_model.model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
38
+ "language_model.model.layers.11.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
39
+ "language_model.model.layers.11.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
40
+ "language_model.model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
41
+ "language_model.model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
42
+ "language_model.model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
43
+ "language_model.model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
44
+ "language_model.model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
45
+ "language_model.model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
46
+ "language_model.model.layers.12.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
47
+ "language_model.model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
48
+ "language_model.model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
49
+ "language_model.model.layers.12.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
50
+ "language_model.model.layers.12.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
51
+ "language_model.model.layers.12.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
52
+ "language_model.model.layers.12.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
53
+ "language_model.model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
54
+ "language_model.model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
55
+ "language_model.model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
56
+ "language_model.model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
57
+ "language_model.model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
58
+ "language_model.model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
59
+ "language_model.model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
60
+ "language_model.model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
61
+ "language_model.model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
62
+ "language_model.model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
63
+ "language_model.model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
64
+ "language_model.model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
65
+ "language_model.model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
66
+ "language_model.model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
67
+ "language_model.model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
68
+ "language_model.model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
69
+ "language_model.model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
70
+ "language_model.model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
71
+ "language_model.model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
72
+ "language_model.model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
73
+ "language_model.model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
74
+ "language_model.model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
75
+ "language_model.model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
76
+ "language_model.model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
77
+ "language_model.model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
78
+ "language_model.model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
79
+ "language_model.model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
80
+ "language_model.model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
81
+ "language_model.model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
82
+ "language_model.model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
83
+ "language_model.model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
84
+ "language_model.model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
85
+ "language_model.model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
86
+ "language_model.model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
87
+ "language_model.model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
88
+ "language_model.model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
89
+ "language_model.model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
90
+ "language_model.model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
91
+ "language_model.model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
92
+ "language_model.model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
93
+ "language_model.model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
94
+ "language_model.model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
95
+ "language_model.model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
96
+ "language_model.model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
97
+ "language_model.model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
98
+ "language_model.model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
99
+ "language_model.model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
100
+ "language_model.model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
101
+ "language_model.model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
102
+ "language_model.model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
103
+ "language_model.model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
104
+ "language_model.model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
105
+ "language_model.model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
106
+ "language_model.model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
107
+ "language_model.model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
108
+ "language_model.model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
109
+ "language_model.model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
110
+ "language_model.model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
111
+ "language_model.model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
112
+ "language_model.model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
113
+ "language_model.model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
114
+ "language_model.model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
115
+ "language_model.model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
116
+ "language_model.model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
117
+ "language_model.model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
118
+ "language_model.model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
119
+ "language_model.model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
120
+ "language_model.model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
121
+ "language_model.model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
122
+ "language_model.model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
123
+ "language_model.model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
124
+ "language_model.model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
125
+ "language_model.model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
126
+ "language_model.model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
127
+ "language_model.model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
128
+ "language_model.model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
129
+ "language_model.model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
130
+ "language_model.model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
131
+ "language_model.model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
132
+ "language_model.model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
133
+ "language_model.model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
134
+ "language_model.model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
135
+ "language_model.model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
136
+ "language_model.model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
137
+ "language_model.model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
138
+ "language_model.model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
139
+ "language_model.model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
140
+ "language_model.model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
141
+ "language_model.model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
142
+ "language_model.model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
143
+ "language_model.model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
144
+ "language_model.model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
145
+ "language_model.model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
146
+ "language_model.model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
147
+ "language_model.model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
148
+ "language_model.model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
149
+ "language_model.model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
150
+ "language_model.model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
151
+ "language_model.model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
152
+ "language_model.model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
153
+ "language_model.model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
154
+ "language_model.model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
155
+ "language_model.model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
156
+ "language_model.model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
157
+ "language_model.model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
158
+ "language_model.model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
159
+ "language_model.model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
160
+ "language_model.model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
161
+ "language_model.model.layers.24.input_layernorm.weight": "model-00002-of-00003.safetensors",
162
+ "language_model.model.layers.24.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
163
+ "language_model.model.layers.24.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
164
+ "language_model.model.layers.24.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
165
+ "language_model.model.layers.24.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
166
+ "language_model.model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
167
+ "language_model.model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
168
+ "language_model.model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
169
+ "language_model.model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
170
+ "language_model.model.layers.25.input_layernorm.weight": "model-00002-of-00003.safetensors",
171
+ "language_model.model.layers.25.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
172
+ "language_model.model.layers.25.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
173
+ "language_model.model.layers.25.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
174
+ "language_model.model.layers.25.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
175
+ "language_model.model.layers.25.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
176
+ "language_model.model.layers.25.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
177
+ "language_model.model.layers.25.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
178
+ "language_model.model.layers.25.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
179
+ "language_model.model.layers.26.input_layernorm.weight": "model-00002-of-00003.safetensors",
180
+ "language_model.model.layers.26.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
181
+ "language_model.model.layers.26.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
182
+ "language_model.model.layers.26.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
183
+ "language_model.model.layers.26.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
184
+ "language_model.model.layers.26.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
185
+ "language_model.model.layers.26.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
186
+ "language_model.model.layers.26.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
187
+ "language_model.model.layers.26.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
188
+ "language_model.model.layers.27.input_layernorm.weight": "model-00002-of-00003.safetensors",
189
+ "language_model.model.layers.27.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
190
+ "language_model.model.layers.27.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
191
+ "language_model.model.layers.27.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
192
+ "language_model.model.layers.27.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
193
+ "language_model.model.layers.27.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
194
+ "language_model.model.layers.27.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
195
+ "language_model.model.layers.27.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
196
+ "language_model.model.layers.27.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
197
+ "language_model.model.layers.28.input_layernorm.weight": "model-00002-of-00003.safetensors",
198
+ "language_model.model.layers.28.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
199
+ "language_model.model.layers.28.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
200
+ "language_model.model.layers.28.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
201
+ "language_model.model.layers.28.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
202
+ "language_model.model.layers.28.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
203
+ "language_model.model.layers.28.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
204
+ "language_model.model.layers.28.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
205
+ "language_model.model.layers.28.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
206
+ "language_model.model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
207
+ "language_model.model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
208
+ "language_model.model.layers.29.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
209
+ "language_model.model.layers.29.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
210
+ "language_model.model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
211
+ "language_model.model.layers.29.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
212
+ "language_model.model.layers.29.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
213
+ "language_model.model.layers.29.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
214
+ "language_model.model.layers.29.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
215
+ "language_model.model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
216
+ "language_model.model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
217
+ "language_model.model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
218
+ "language_model.model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
219
+ "language_model.model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
220
+ "language_model.model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
221
+ "language_model.model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
222
+ "language_model.model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
223
+ "language_model.model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
224
+ "language_model.model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
225
+ "language_model.model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
226
+ "language_model.model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
227
+ "language_model.model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
228
+ "language_model.model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
229
+ "language_model.model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
230
+ "language_model.model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
231
+ "language_model.model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
232
+ "language_model.model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
233
+ "language_model.model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
234
+ "language_model.model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
235
+ "language_model.model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
236
+ "language_model.model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
237
+ "language_model.model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
238
+ "language_model.model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
239
+ "language_model.model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
240
+ "language_model.model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
241
+ "language_model.model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
242
+ "language_model.model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
243
+ "language_model.model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
244
+ "language_model.model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
245
+ "language_model.model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
246
+ "language_model.model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
247
+ "language_model.model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
248
+ "language_model.model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
249
+ "language_model.model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
250
+ "language_model.model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
251
+ "language_model.model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
252
+ "language_model.model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
253
+ "language_model.model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
254
+ "language_model.model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
255
+ "language_model.model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
256
+ "language_model.model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
257
+ "language_model.model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
258
+ "language_model.model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
259
+ "language_model.model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
260
+ "language_model.model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
261
+ "language_model.model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
262
+ "language_model.model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
263
+ "language_model.model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
264
+ "language_model.model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
265
+ "language_model.model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
266
+ "language_model.model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
267
+ "language_model.model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
268
+ "language_model.model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
269
+ "language_model.model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
270
+ "language_model.model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
271
+ "language_model.model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
272
+ "language_model.model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
273
+ "language_model.model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
274
+ "language_model.model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
275
+ "language_model.model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
276
+ "language_model.model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
277
+ "language_model.model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
278
+ "language_model.model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
279
+ "language_model.model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
280
+ "language_model.model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
281
+ "language_model.model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
282
+ "language_model.model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
283
+ "language_model.model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
284
+ "language_model.model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
285
+ "language_model.model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
286
+ "language_model.model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
287
+ "language_model.model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
288
+ "language_model.model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
289
+ "language_model.model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
290
+ "language_model.model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
291
+ "language_model.model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
292
+ "language_model.model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
293
+ "language_model.model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
294
+ "language_model.model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
295
+ "language_model.model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
296
+ "language_model.model.norm.weight": "model-00003-of-00003.safetensors",
297
+ "projector.fc1.bias": "model-00001-of-00003.safetensors",
298
+ "projector.fc1.weight": "model-00001-of-00003.safetensors",
299
+ "projector.fc2.bias": "model-00001-of-00003.safetensors",
300
+ "projector.fc2.weight": "model-00001-of-00003.safetensors",
301
+ "projector.fc3.bias": "model-00001-of-00003.safetensors",
302
+ "projector.fc3.weight": "model-00001-of-00003.safetensors",
303
+ "vision_backbone.featurizer.blocks.0.attn.proj.bias": "model-00001-of-00003.safetensors",
304
+ "vision_backbone.featurizer.blocks.0.attn.proj.weight": "model-00001-of-00003.safetensors",
305
+ "vision_backbone.featurizer.blocks.0.attn.qkv.bias": "model-00001-of-00003.safetensors",
306
+ "vision_backbone.featurizer.blocks.0.attn.qkv.weight": "model-00001-of-00003.safetensors",
307
+ "vision_backbone.featurizer.blocks.0.ls1.scale_factor": "model-00001-of-00003.safetensors",
308
+ "vision_backbone.featurizer.blocks.0.ls2.scale_factor": "model-00001-of-00003.safetensors",
309
+ "vision_backbone.featurizer.blocks.0.mlp.fc1.bias": "model-00001-of-00003.safetensors",
310
+ "vision_backbone.featurizer.blocks.0.mlp.fc1.weight": "model-00001-of-00003.safetensors",
311
+ "vision_backbone.featurizer.blocks.0.mlp.fc2.bias": "model-00001-of-00003.safetensors",
312
+ "vision_backbone.featurizer.blocks.0.mlp.fc2.weight": "model-00001-of-00003.safetensors",
313
+ "vision_backbone.featurizer.blocks.0.norm1.bias": "model-00001-of-00003.safetensors",
314
+ "vision_backbone.featurizer.blocks.0.norm1.weight": "model-00001-of-00003.safetensors",
315
+ "vision_backbone.featurizer.blocks.0.norm2.bias": "model-00001-of-00003.safetensors",
316
+ "vision_backbone.featurizer.blocks.0.norm2.weight": "model-00001-of-00003.safetensors",
317
+ "vision_backbone.featurizer.blocks.1.attn.proj.bias": "model-00001-of-00003.safetensors",
318
+ "vision_backbone.featurizer.blocks.1.attn.proj.weight": "model-00001-of-00003.safetensors",
319
+ "vision_backbone.featurizer.blocks.1.attn.qkv.bias": "model-00001-of-00003.safetensors",
320
+ "vision_backbone.featurizer.blocks.1.attn.qkv.weight": "model-00001-of-00003.safetensors",
321
+ "vision_backbone.featurizer.blocks.1.ls1.scale_factor": "model-00001-of-00003.safetensors",
322
+ "vision_backbone.featurizer.blocks.1.ls2.scale_factor": "model-00001-of-00003.safetensors",
323
+ "vision_backbone.featurizer.blocks.1.mlp.fc1.bias": "model-00001-of-00003.safetensors",
324
+ "vision_backbone.featurizer.blocks.1.mlp.fc1.weight": "model-00001-of-00003.safetensors",
325
+ "vision_backbone.featurizer.blocks.1.mlp.fc2.bias": "model-00001-of-00003.safetensors",
326
+ "vision_backbone.featurizer.blocks.1.mlp.fc2.weight": "model-00001-of-00003.safetensors",
327
+ "vision_backbone.featurizer.blocks.1.norm1.bias": "model-00001-of-00003.safetensors",
328
+ "vision_backbone.featurizer.blocks.1.norm1.weight": "model-00001-of-00003.safetensors",
329
+ "vision_backbone.featurizer.blocks.1.norm2.bias": "model-00001-of-00003.safetensors",
330
+ "vision_backbone.featurizer.blocks.1.norm2.weight": "model-00001-of-00003.safetensors",
331
+ "vision_backbone.featurizer.blocks.10.attn.proj.bias": "model-00001-of-00003.safetensors",
332
+ "vision_backbone.featurizer.blocks.10.attn.proj.weight": "model-00001-of-00003.safetensors",
333
+ "vision_backbone.featurizer.blocks.10.attn.qkv.bias": "model-00001-of-00003.safetensors",
334
+ "vision_backbone.featurizer.blocks.10.attn.qkv.weight": "model-00001-of-00003.safetensors",
335
+ "vision_backbone.featurizer.blocks.10.ls1.scale_factor": "model-00001-of-00003.safetensors",
336
+ "vision_backbone.featurizer.blocks.10.ls2.scale_factor": "model-00001-of-00003.safetensors",
337
+ "vision_backbone.featurizer.blocks.10.mlp.fc1.bias": "model-00001-of-00003.safetensors",
338
+ "vision_backbone.featurizer.blocks.10.mlp.fc1.weight": "model-00001-of-00003.safetensors",
339
+ "vision_backbone.featurizer.blocks.10.mlp.fc2.bias": "model-00001-of-00003.safetensors",
340
+ "vision_backbone.featurizer.blocks.10.mlp.fc2.weight": "model-00001-of-00003.safetensors",
341
+ "vision_backbone.featurizer.blocks.10.norm1.bias": "model-00001-of-00003.safetensors",
342
+ "vision_backbone.featurizer.blocks.10.norm1.weight": "model-00001-of-00003.safetensors",
343
+ "vision_backbone.featurizer.blocks.10.norm2.bias": "model-00001-of-00003.safetensors",
344
+ "vision_backbone.featurizer.blocks.10.norm2.weight": "model-00001-of-00003.safetensors",
345
+ "vision_backbone.featurizer.blocks.11.attn.proj.bias": "model-00001-of-00003.safetensors",
346
+ "vision_backbone.featurizer.blocks.11.attn.proj.weight": "model-00001-of-00003.safetensors",
347
+ "vision_backbone.featurizer.blocks.11.attn.qkv.bias": "model-00001-of-00003.safetensors",
348
+ "vision_backbone.featurizer.blocks.11.attn.qkv.weight": "model-00001-of-00003.safetensors",
349
+ "vision_backbone.featurizer.blocks.11.ls1.scale_factor": "model-00001-of-00003.safetensors",
350
+ "vision_backbone.featurizer.blocks.11.ls2.scale_factor": "model-00001-of-00003.safetensors",
351
+ "vision_backbone.featurizer.blocks.11.mlp.fc1.bias": "model-00001-of-00003.safetensors",
352
+ "vision_backbone.featurizer.blocks.11.mlp.fc1.weight": "model-00001-of-00003.safetensors",
353
+ "vision_backbone.featurizer.blocks.11.mlp.fc2.bias": "model-00001-of-00003.safetensors",
354
+ "vision_backbone.featurizer.blocks.11.mlp.fc2.weight": "model-00001-of-00003.safetensors",
355
+ "vision_backbone.featurizer.blocks.11.norm1.bias": "model-00001-of-00003.safetensors",
356
+ "vision_backbone.featurizer.blocks.11.norm1.weight": "model-00001-of-00003.safetensors",
357
+ "vision_backbone.featurizer.blocks.11.norm2.bias": "model-00001-of-00003.safetensors",
358
+ "vision_backbone.featurizer.blocks.11.norm2.weight": "model-00001-of-00003.safetensors",
359
+ "vision_backbone.featurizer.blocks.12.attn.proj.bias": "model-00001-of-00003.safetensors",
360
+ "vision_backbone.featurizer.blocks.12.attn.proj.weight": "model-00001-of-00003.safetensors",
361
+ "vision_backbone.featurizer.blocks.12.attn.qkv.bias": "model-00001-of-00003.safetensors",
362
+ "vision_backbone.featurizer.blocks.12.attn.qkv.weight": "model-00001-of-00003.safetensors",
363
+ "vision_backbone.featurizer.blocks.12.ls1.scale_factor": "model-00001-of-00003.safetensors",
364
+ "vision_backbone.featurizer.blocks.12.ls2.scale_factor": "model-00001-of-00003.safetensors",
365
+ "vision_backbone.featurizer.blocks.12.mlp.fc1.bias": "model-00001-of-00003.safetensors",
366
+ "vision_backbone.featurizer.blocks.12.mlp.fc1.weight": "model-00001-of-00003.safetensors",
367
+ "vision_backbone.featurizer.blocks.12.mlp.fc2.bias": "model-00001-of-00003.safetensors",
368
+ "vision_backbone.featurizer.blocks.12.mlp.fc2.weight": "model-00001-of-00003.safetensors",
369
+ "vision_backbone.featurizer.blocks.12.norm1.bias": "model-00001-of-00003.safetensors",
370
+ "vision_backbone.featurizer.blocks.12.norm1.weight": "model-00001-of-00003.safetensors",
371
+ "vision_backbone.featurizer.blocks.12.norm2.bias": "model-00001-of-00003.safetensors",
372
+ "vision_backbone.featurizer.blocks.12.norm2.weight": "model-00001-of-00003.safetensors",
373
+ "vision_backbone.featurizer.blocks.13.attn.proj.bias": "model-00001-of-00003.safetensors",
374
+ "vision_backbone.featurizer.blocks.13.attn.proj.weight": "model-00001-of-00003.safetensors",
375
+ "vision_backbone.featurizer.blocks.13.attn.qkv.bias": "model-00001-of-00003.safetensors",
376
+ "vision_backbone.featurizer.blocks.13.attn.qkv.weight": "model-00001-of-00003.safetensors",
377
+ "vision_backbone.featurizer.blocks.13.ls1.scale_factor": "model-00001-of-00003.safetensors",
378
+ "vision_backbone.featurizer.blocks.13.ls2.scale_factor": "model-00001-of-00003.safetensors",
379
+ "vision_backbone.featurizer.blocks.13.mlp.fc1.bias": "model-00001-of-00003.safetensors",
380
+ "vision_backbone.featurizer.blocks.13.mlp.fc1.weight": "model-00001-of-00003.safetensors",
381
+ "vision_backbone.featurizer.blocks.13.mlp.fc2.bias": "model-00001-of-00003.safetensors",
382
+ "vision_backbone.featurizer.blocks.13.mlp.fc2.weight": "model-00001-of-00003.safetensors",
383
+ "vision_backbone.featurizer.blocks.13.norm1.bias": "model-00001-of-00003.safetensors",
384
+ "vision_backbone.featurizer.blocks.13.norm1.weight": "model-00001-of-00003.safetensors",
385
+ "vision_backbone.featurizer.blocks.13.norm2.bias": "model-00001-of-00003.safetensors",
386
+ "vision_backbone.featurizer.blocks.13.norm2.weight": "model-00001-of-00003.safetensors",
387
+ "vision_backbone.featurizer.blocks.14.attn.proj.bias": "model-00001-of-00003.safetensors",
388
+ "vision_backbone.featurizer.blocks.14.attn.proj.weight": "model-00001-of-00003.safetensors",
389
+ "vision_backbone.featurizer.blocks.14.attn.qkv.bias": "model-00001-of-00003.safetensors",
390
+ "vision_backbone.featurizer.blocks.14.attn.qkv.weight": "model-00001-of-00003.safetensors",
391
+ "vision_backbone.featurizer.blocks.14.ls1.scale_factor": "model-00001-of-00003.safetensors",
392
+ "vision_backbone.featurizer.blocks.14.ls2.scale_factor": "model-00001-of-00003.safetensors",
393
+ "vision_backbone.featurizer.blocks.14.mlp.fc1.bias": "model-00001-of-00003.safetensors",
394
+ "vision_backbone.featurizer.blocks.14.mlp.fc1.weight": "model-00001-of-00003.safetensors",
395
+ "vision_backbone.featurizer.blocks.14.mlp.fc2.bias": "model-00001-of-00003.safetensors",
396
+ "vision_backbone.featurizer.blocks.14.mlp.fc2.weight": "model-00001-of-00003.safetensors",
397
+ "vision_backbone.featurizer.blocks.14.norm1.bias": "model-00001-of-00003.safetensors",
398
+ "vision_backbone.featurizer.blocks.14.norm1.weight": "model-00001-of-00003.safetensors",
399
+ "vision_backbone.featurizer.blocks.14.norm2.bias": "model-00001-of-00003.safetensors",
400
+ "vision_backbone.featurizer.blocks.14.norm2.weight": "model-00001-of-00003.safetensors",
401
+ "vision_backbone.featurizer.blocks.15.attn.proj.bias": "model-00001-of-00003.safetensors",
402
+ "vision_backbone.featurizer.blocks.15.attn.proj.weight": "model-00001-of-00003.safetensors",
403
+ "vision_backbone.featurizer.blocks.15.attn.qkv.bias": "model-00001-of-00003.safetensors",
404
+ "vision_backbone.featurizer.blocks.15.attn.qkv.weight": "model-00001-of-00003.safetensors",
405
+ "vision_backbone.featurizer.blocks.15.ls1.scale_factor": "model-00001-of-00003.safetensors",
406
+ "vision_backbone.featurizer.blocks.15.ls2.scale_factor": "model-00001-of-00003.safetensors",
407
+ "vision_backbone.featurizer.blocks.15.mlp.fc1.bias": "model-00001-of-00003.safetensors",
408
+ "vision_backbone.featurizer.blocks.15.mlp.fc1.weight": "model-00001-of-00003.safetensors",
409
+ "vision_backbone.featurizer.blocks.15.mlp.fc2.bias": "model-00001-of-00003.safetensors",
410
+ "vision_backbone.featurizer.blocks.15.mlp.fc2.weight": "model-00001-of-00003.safetensors",
411
+ "vision_backbone.featurizer.blocks.15.norm1.bias": "model-00001-of-00003.safetensors",
412
+ "vision_backbone.featurizer.blocks.15.norm1.weight": "model-00001-of-00003.safetensors",
413
+ "vision_backbone.featurizer.blocks.15.norm2.bias": "model-00001-of-00003.safetensors",
414
+ "vision_backbone.featurizer.blocks.15.norm2.weight": "model-00001-of-00003.safetensors",
415
+ "vision_backbone.featurizer.blocks.16.attn.proj.bias": "model-00001-of-00003.safetensors",
416
+ "vision_backbone.featurizer.blocks.16.attn.proj.weight": "model-00001-of-00003.safetensors",
417
+ "vision_backbone.featurizer.blocks.16.attn.qkv.bias": "model-00001-of-00003.safetensors",
418
+ "vision_backbone.featurizer.blocks.16.attn.qkv.weight": "model-00001-of-00003.safetensors",
419
+ "vision_backbone.featurizer.blocks.16.ls1.scale_factor": "model-00001-of-00003.safetensors",
420
+ "vision_backbone.featurizer.blocks.16.ls2.scale_factor": "model-00001-of-00003.safetensors",
421
+ "vision_backbone.featurizer.blocks.16.mlp.fc1.bias": "model-00001-of-00003.safetensors",
422
+ "vision_backbone.featurizer.blocks.16.mlp.fc1.weight": "model-00001-of-00003.safetensors",
423
+ "vision_backbone.featurizer.blocks.16.mlp.fc2.bias": "model-00001-of-00003.safetensors",
424
+ "vision_backbone.featurizer.blocks.16.mlp.fc2.weight": "model-00001-of-00003.safetensors",
425
+ "vision_backbone.featurizer.blocks.16.norm1.bias": "model-00001-of-00003.safetensors",
426
+ "vision_backbone.featurizer.blocks.16.norm1.weight": "model-00001-of-00003.safetensors",
427
+ "vision_backbone.featurizer.blocks.16.norm2.bias": "model-00001-of-00003.safetensors",
428
+ "vision_backbone.featurizer.blocks.16.norm2.weight": "model-00001-of-00003.safetensors",
429
+ "vision_backbone.featurizer.blocks.17.attn.proj.bias": "model-00001-of-00003.safetensors",
430
+ "vision_backbone.featurizer.blocks.17.attn.proj.weight": "model-00001-of-00003.safetensors",
431
+ "vision_backbone.featurizer.blocks.17.attn.qkv.bias": "model-00001-of-00003.safetensors",
432
+ "vision_backbone.featurizer.blocks.17.attn.qkv.weight": "model-00001-of-00003.safetensors",
433
+ "vision_backbone.featurizer.blocks.17.ls1.scale_factor": "model-00001-of-00003.safetensors",
434
+ "vision_backbone.featurizer.blocks.17.ls2.scale_factor": "model-00001-of-00003.safetensors",
435
+ "vision_backbone.featurizer.blocks.17.mlp.fc1.bias": "model-00001-of-00003.safetensors",
436
+ "vision_backbone.featurizer.blocks.17.mlp.fc1.weight": "model-00001-of-00003.safetensors",
437
+ "vision_backbone.featurizer.blocks.17.mlp.fc2.bias": "model-00001-of-00003.safetensors",
438
+ "vision_backbone.featurizer.blocks.17.mlp.fc2.weight": "model-00001-of-00003.safetensors",
439
+ "vision_backbone.featurizer.blocks.17.norm1.bias": "model-00001-of-00003.safetensors",
440
+ "vision_backbone.featurizer.blocks.17.norm1.weight": "model-00001-of-00003.safetensors",
441
+ "vision_backbone.featurizer.blocks.17.norm2.bias": "model-00001-of-00003.safetensors",
442
+ "vision_backbone.featurizer.blocks.17.norm2.weight": "model-00001-of-00003.safetensors",
443
+ "vision_backbone.featurizer.blocks.18.attn.proj.bias": "model-00001-of-00003.safetensors",
444
+ "vision_backbone.featurizer.blocks.18.attn.proj.weight": "model-00001-of-00003.safetensors",
445
+ "vision_backbone.featurizer.blocks.18.attn.qkv.bias": "model-00001-of-00003.safetensors",
446
+ "vision_backbone.featurizer.blocks.18.attn.qkv.weight": "model-00001-of-00003.safetensors",
447
+ "vision_backbone.featurizer.blocks.18.ls1.scale_factor": "model-00001-of-00003.safetensors",
448
+ "vision_backbone.featurizer.blocks.18.ls2.scale_factor": "model-00001-of-00003.safetensors",
449
+ "vision_backbone.featurizer.blocks.18.mlp.fc1.bias": "model-00001-of-00003.safetensors",
450
+ "vision_backbone.featurizer.blocks.18.mlp.fc1.weight": "model-00001-of-00003.safetensors",
451
+ "vision_backbone.featurizer.blocks.18.mlp.fc2.bias": "model-00001-of-00003.safetensors",
452
+ "vision_backbone.featurizer.blocks.18.mlp.fc2.weight": "model-00001-of-00003.safetensors",
453
+ "vision_backbone.featurizer.blocks.18.norm1.bias": "model-00001-of-00003.safetensors",
454
+ "vision_backbone.featurizer.blocks.18.norm1.weight": "model-00001-of-00003.safetensors",
455
+ "vision_backbone.featurizer.blocks.18.norm2.bias": "model-00001-of-00003.safetensors",
456
+ "vision_backbone.featurizer.blocks.18.norm2.weight": "model-00001-of-00003.safetensors",
457
+ "vision_backbone.featurizer.blocks.19.attn.proj.bias": "model-00001-of-00003.safetensors",
458
+ "vision_backbone.featurizer.blocks.19.attn.proj.weight": "model-00001-of-00003.safetensors",
459
+ "vision_backbone.featurizer.blocks.19.attn.qkv.bias": "model-00001-of-00003.safetensors",
460
+ "vision_backbone.featurizer.blocks.19.attn.qkv.weight": "model-00001-of-00003.safetensors",
461
+ "vision_backbone.featurizer.blocks.19.ls1.scale_factor": "model-00001-of-00003.safetensors",
462
+ "vision_backbone.featurizer.blocks.19.ls2.scale_factor": "model-00001-of-00003.safetensors",
463
+ "vision_backbone.featurizer.blocks.19.mlp.fc1.bias": "model-00001-of-00003.safetensors",
464
+ "vision_backbone.featurizer.blocks.19.mlp.fc1.weight": "model-00001-of-00003.safetensors",
465
+ "vision_backbone.featurizer.blocks.19.mlp.fc2.bias": "model-00001-of-00003.safetensors",
466
+ "vision_backbone.featurizer.blocks.19.mlp.fc2.weight": "model-00001-of-00003.safetensors",
467
+ "vision_backbone.featurizer.blocks.19.norm1.bias": "model-00001-of-00003.safetensors",
468
+ "vision_backbone.featurizer.blocks.19.norm1.weight": "model-00001-of-00003.safetensors",
469
+ "vision_backbone.featurizer.blocks.19.norm2.bias": "model-00001-of-00003.safetensors",
470
+ "vision_backbone.featurizer.blocks.19.norm2.weight": "model-00001-of-00003.safetensors",
471
+ "vision_backbone.featurizer.blocks.2.attn.proj.bias": "model-00001-of-00003.safetensors",
472
+ "vision_backbone.featurizer.blocks.2.attn.proj.weight": "model-00001-of-00003.safetensors",
473
+ "vision_backbone.featurizer.blocks.2.attn.qkv.bias": "model-00001-of-00003.safetensors",
474
+ "vision_backbone.featurizer.blocks.2.attn.qkv.weight": "model-00001-of-00003.safetensors",
475
+ "vision_backbone.featurizer.blocks.2.ls1.scale_factor": "model-00001-of-00003.safetensors",
476
+ "vision_backbone.featurizer.blocks.2.ls2.scale_factor": "model-00001-of-00003.safetensors",
477
+ "vision_backbone.featurizer.blocks.2.mlp.fc1.bias": "model-00001-of-00003.safetensors",
478
+ "vision_backbone.featurizer.blocks.2.mlp.fc1.weight": "model-00001-of-00003.safetensors",
479
+ "vision_backbone.featurizer.blocks.2.mlp.fc2.bias": "model-00001-of-00003.safetensors",
480
+ "vision_backbone.featurizer.blocks.2.mlp.fc2.weight": "model-00001-of-00003.safetensors",
481
+ "vision_backbone.featurizer.blocks.2.norm1.bias": "model-00001-of-00003.safetensors",
482
+ "vision_backbone.featurizer.blocks.2.norm1.weight": "model-00001-of-00003.safetensors",
483
+ "vision_backbone.featurizer.blocks.2.norm2.bias": "model-00001-of-00003.safetensors",
484
+ "vision_backbone.featurizer.blocks.2.norm2.weight": "model-00001-of-00003.safetensors",
485
+ "vision_backbone.featurizer.blocks.20.attn.proj.bias": "model-00001-of-00003.safetensors",
486
+ "vision_backbone.featurizer.blocks.20.attn.proj.weight": "model-00001-of-00003.safetensors",
487
+ "vision_backbone.featurizer.blocks.20.attn.qkv.bias": "model-00001-of-00003.safetensors",
488
+ "vision_backbone.featurizer.blocks.20.attn.qkv.weight": "model-00001-of-00003.safetensors",
489
+ "vision_backbone.featurizer.blocks.20.ls1.scale_factor": "model-00001-of-00003.safetensors",
490
+ "vision_backbone.featurizer.blocks.20.ls2.scale_factor": "model-00001-of-00003.safetensors",
491
+ "vision_backbone.featurizer.blocks.20.mlp.fc1.bias": "model-00001-of-00003.safetensors",
492
+ "vision_backbone.featurizer.blocks.20.mlp.fc1.weight": "model-00001-of-00003.safetensors",
493
+ "vision_backbone.featurizer.blocks.20.mlp.fc2.bias": "model-00001-of-00003.safetensors",
494
+ "vision_backbone.featurizer.blocks.20.mlp.fc2.weight": "model-00001-of-00003.safetensors",
495
+ "vision_backbone.featurizer.blocks.20.norm1.bias": "model-00001-of-00003.safetensors",
496
+ "vision_backbone.featurizer.blocks.20.norm1.weight": "model-00001-of-00003.safetensors",
497
+ "vision_backbone.featurizer.blocks.20.norm2.bias": "model-00001-of-00003.safetensors",
498
+ "vision_backbone.featurizer.blocks.20.norm2.weight": "model-00001-of-00003.safetensors",
499
+ "vision_backbone.featurizer.blocks.21.attn.proj.bias": "model-00001-of-00003.safetensors",
500
+ "vision_backbone.featurizer.blocks.21.attn.proj.weight": "model-00001-of-00003.safetensors",
501
+ "vision_backbone.featurizer.blocks.21.attn.qkv.bias": "model-00001-of-00003.safetensors",
502
+ "vision_backbone.featurizer.blocks.21.attn.qkv.weight": "model-00001-of-00003.safetensors",
503
+ "vision_backbone.featurizer.blocks.21.ls1.scale_factor": "model-00001-of-00003.safetensors",
504
+ "vision_backbone.featurizer.blocks.21.ls2.scale_factor": "model-00001-of-00003.safetensors",
505
+ "vision_backbone.featurizer.blocks.21.mlp.fc1.bias": "model-00001-of-00003.safetensors",
506
+ "vision_backbone.featurizer.blocks.21.mlp.fc1.weight": "model-00001-of-00003.safetensors",
507
+ "vision_backbone.featurizer.blocks.21.mlp.fc2.bias": "model-00001-of-00003.safetensors",
508
+ "vision_backbone.featurizer.blocks.21.mlp.fc2.weight": "model-00001-of-00003.safetensors",
509
+ "vision_backbone.featurizer.blocks.21.norm1.bias": "model-00001-of-00003.safetensors",
510
+ "vision_backbone.featurizer.blocks.21.norm1.weight": "model-00001-of-00003.safetensors",
511
+ "vision_backbone.featurizer.blocks.21.norm2.bias": "model-00001-of-00003.safetensors",
512
+ "vision_backbone.featurizer.blocks.21.norm2.weight": "model-00001-of-00003.safetensors",
513
+ "vision_backbone.featurizer.blocks.22.attn.proj.bias": "model-00001-of-00003.safetensors",
514
+ "vision_backbone.featurizer.blocks.22.attn.proj.weight": "model-00001-of-00003.safetensors",
515
+ "vision_backbone.featurizer.blocks.22.attn.qkv.bias": "model-00001-of-00003.safetensors",
516
+ "vision_backbone.featurizer.blocks.22.attn.qkv.weight": "model-00001-of-00003.safetensors",
517
+ "vision_backbone.featurizer.blocks.22.ls1.scale_factor": "model-00001-of-00003.safetensors",
518
+ "vision_backbone.featurizer.blocks.22.ls2.scale_factor": "model-00001-of-00003.safetensors",
519
+ "vision_backbone.featurizer.blocks.22.mlp.fc1.bias": "model-00001-of-00003.safetensors",
520
+ "vision_backbone.featurizer.blocks.22.mlp.fc1.weight": "model-00001-of-00003.safetensors",
521
+ "vision_backbone.featurizer.blocks.22.mlp.fc2.bias": "model-00001-of-00003.safetensors",
522
+ "vision_backbone.featurizer.blocks.22.mlp.fc2.weight": "model-00001-of-00003.safetensors",
523
+ "vision_backbone.featurizer.blocks.22.norm1.bias": "model-00001-of-00003.safetensors",
524
+ "vision_backbone.featurizer.blocks.22.norm1.weight": "model-00001-of-00003.safetensors",
525
+ "vision_backbone.featurizer.blocks.22.norm2.bias": "model-00001-of-00003.safetensors",
526
+ "vision_backbone.featurizer.blocks.22.norm2.weight": "model-00001-of-00003.safetensors",
527
+ "vision_backbone.featurizer.blocks.23.attn.proj.bias": "model-00001-of-00003.safetensors",
528
+ "vision_backbone.featurizer.blocks.23.attn.proj.weight": "model-00001-of-00003.safetensors",
529
+ "vision_backbone.featurizer.blocks.23.attn.qkv.bias": "model-00001-of-00003.safetensors",
530
+ "vision_backbone.featurizer.blocks.23.attn.qkv.weight": "model-00001-of-00003.safetensors",
531
+ "vision_backbone.featurizer.blocks.23.ls1.scale_factor": "model-00001-of-00003.safetensors",
532
+ "vision_backbone.featurizer.blocks.23.ls2.scale_factor": "model-00001-of-00003.safetensors",
533
+ "vision_backbone.featurizer.blocks.23.mlp.fc1.bias": "model-00001-of-00003.safetensors",
534
+ "vision_backbone.featurizer.blocks.23.mlp.fc1.weight": "model-00001-of-00003.safetensors",
535
+ "vision_backbone.featurizer.blocks.23.mlp.fc2.bias": "model-00001-of-00003.safetensors",
536
+ "vision_backbone.featurizer.blocks.23.mlp.fc2.weight": "model-00001-of-00003.safetensors",
537
+ "vision_backbone.featurizer.blocks.23.norm1.bias": "model-00001-of-00003.safetensors",
538
+ "vision_backbone.featurizer.blocks.23.norm1.weight": "model-00001-of-00003.safetensors",
539
+ "vision_backbone.featurizer.blocks.23.norm2.bias": "model-00001-of-00003.safetensors",
540
+ "vision_backbone.featurizer.blocks.23.norm2.weight": "model-00001-of-00003.safetensors",
541
+ "vision_backbone.featurizer.blocks.3.attn.proj.bias": "model-00001-of-00003.safetensors",
542
+ "vision_backbone.featurizer.blocks.3.attn.proj.weight": "model-00001-of-00003.safetensors",
543
+ "vision_backbone.featurizer.blocks.3.attn.qkv.bias": "model-00001-of-00003.safetensors",
544
+ "vision_backbone.featurizer.blocks.3.attn.qkv.weight": "model-00001-of-00003.safetensors",
545
+ "vision_backbone.featurizer.blocks.3.ls1.scale_factor": "model-00001-of-00003.safetensors",
546
+ "vision_backbone.featurizer.blocks.3.ls2.scale_factor": "model-00001-of-00003.safetensors",
547
+ "vision_backbone.featurizer.blocks.3.mlp.fc1.bias": "model-00001-of-00003.safetensors",
548
+ "vision_backbone.featurizer.blocks.3.mlp.fc1.weight": "model-00001-of-00003.safetensors",
549
+ "vision_backbone.featurizer.blocks.3.mlp.fc2.bias": "model-00001-of-00003.safetensors",
550
+ "vision_backbone.featurizer.blocks.3.mlp.fc2.weight": "model-00001-of-00003.safetensors",
551
+ "vision_backbone.featurizer.blocks.3.norm1.bias": "model-00001-of-00003.safetensors",
552
+ "vision_backbone.featurizer.blocks.3.norm1.weight": "model-00001-of-00003.safetensors",
553
+ "vision_backbone.featurizer.blocks.3.norm2.bias": "model-00001-of-00003.safetensors",
554
+ "vision_backbone.featurizer.blocks.3.norm2.weight": "model-00001-of-00003.safetensors",
555
+ "vision_backbone.featurizer.blocks.4.attn.proj.bias": "model-00001-of-00003.safetensors",
556
+ "vision_backbone.featurizer.blocks.4.attn.proj.weight": "model-00001-of-00003.safetensors",
557
+ "vision_backbone.featurizer.blocks.4.attn.qkv.bias": "model-00001-of-00003.safetensors",
558
+ "vision_backbone.featurizer.blocks.4.attn.qkv.weight": "model-00001-of-00003.safetensors",
559
+ "vision_backbone.featurizer.blocks.4.ls1.scale_factor": "model-00001-of-00003.safetensors",
560
+ "vision_backbone.featurizer.blocks.4.ls2.scale_factor": "model-00001-of-00003.safetensors",
561
+ "vision_backbone.featurizer.blocks.4.mlp.fc1.bias": "model-00001-of-00003.safetensors",
562
+ "vision_backbone.featurizer.blocks.4.mlp.fc1.weight": "model-00001-of-00003.safetensors",
563
+ "vision_backbone.featurizer.blocks.4.mlp.fc2.bias": "model-00001-of-00003.safetensors",
564
+ "vision_backbone.featurizer.blocks.4.mlp.fc2.weight": "model-00001-of-00003.safetensors",
565
+ "vision_backbone.featurizer.blocks.4.norm1.bias": "model-00001-of-00003.safetensors",
566
+ "vision_backbone.featurizer.blocks.4.norm1.weight": "model-00001-of-00003.safetensors",
567
+ "vision_backbone.featurizer.blocks.4.norm2.bias": "model-00001-of-00003.safetensors",
568
+ "vision_backbone.featurizer.blocks.4.norm2.weight": "model-00001-of-00003.safetensors",
569
+ "vision_backbone.featurizer.blocks.5.attn.proj.bias": "model-00001-of-00003.safetensors",
570
+ "vision_backbone.featurizer.blocks.5.attn.proj.weight": "model-00001-of-00003.safetensors",
571
+ "vision_backbone.featurizer.blocks.5.attn.qkv.bias": "model-00001-of-00003.safetensors",
572
+ "vision_backbone.featurizer.blocks.5.attn.qkv.weight": "model-00001-of-00003.safetensors",
573
+ "vision_backbone.featurizer.blocks.5.ls1.scale_factor": "model-00001-of-00003.safetensors",
574
+ "vision_backbone.featurizer.blocks.5.ls2.scale_factor": "model-00001-of-00003.safetensors",
575
+ "vision_backbone.featurizer.blocks.5.mlp.fc1.bias": "model-00001-of-00003.safetensors",
576
+ "vision_backbone.featurizer.blocks.5.mlp.fc1.weight": "model-00001-of-00003.safetensors",
577
+ "vision_backbone.featurizer.blocks.5.mlp.fc2.bias": "model-00001-of-00003.safetensors",
578
+ "vision_backbone.featurizer.blocks.5.mlp.fc2.weight": "model-00001-of-00003.safetensors",
579
+ "vision_backbone.featurizer.blocks.5.norm1.bias": "model-00001-of-00003.safetensors",
580
+ "vision_backbone.featurizer.blocks.5.norm1.weight": "model-00001-of-00003.safetensors",
581
+ "vision_backbone.featurizer.blocks.5.norm2.bias": "model-00001-of-00003.safetensors",
582
+ "vision_backbone.featurizer.blocks.5.norm2.weight": "model-00001-of-00003.safetensors",
583
+ "vision_backbone.featurizer.blocks.6.attn.proj.bias": "model-00001-of-00003.safetensors",
584
+ "vision_backbone.featurizer.blocks.6.attn.proj.weight": "model-00001-of-00003.safetensors",
585
+ "vision_backbone.featurizer.blocks.6.attn.qkv.bias": "model-00001-of-00003.safetensors",
586
+ "vision_backbone.featurizer.blocks.6.attn.qkv.weight": "model-00001-of-00003.safetensors",
587
+ "vision_backbone.featurizer.blocks.6.ls1.scale_factor": "model-00001-of-00003.safetensors",
588
+ "vision_backbone.featurizer.blocks.6.ls2.scale_factor": "model-00001-of-00003.safetensors",
589
+ "vision_backbone.featurizer.blocks.6.mlp.fc1.bias": "model-00001-of-00003.safetensors",
590
+ "vision_backbone.featurizer.blocks.6.mlp.fc1.weight": "model-00001-of-00003.safetensors",
591
+ "vision_backbone.featurizer.blocks.6.mlp.fc2.bias": "model-00001-of-00003.safetensors",
592
+ "vision_backbone.featurizer.blocks.6.mlp.fc2.weight": "model-00001-of-00003.safetensors",
593
+ "vision_backbone.featurizer.blocks.6.norm1.bias": "model-00001-of-00003.safetensors",
594
+ "vision_backbone.featurizer.blocks.6.norm1.weight": "model-00001-of-00003.safetensors",
595
+ "vision_backbone.featurizer.blocks.6.norm2.bias": "model-00001-of-00003.safetensors",
596
+ "vision_backbone.featurizer.blocks.6.norm2.weight": "model-00001-of-00003.safetensors",
597
+ "vision_backbone.featurizer.blocks.7.attn.proj.bias": "model-00001-of-00003.safetensors",
598
+ "vision_backbone.featurizer.blocks.7.attn.proj.weight": "model-00001-of-00003.safetensors",
599
+ "vision_backbone.featurizer.blocks.7.attn.qkv.bias": "model-00001-of-00003.safetensors",
600
+ "vision_backbone.featurizer.blocks.7.attn.qkv.weight": "model-00001-of-00003.safetensors",
601
+ "vision_backbone.featurizer.blocks.7.ls1.scale_factor": "model-00001-of-00003.safetensors",
602
+ "vision_backbone.featurizer.blocks.7.ls2.scale_factor": "model-00001-of-00003.safetensors",
603
+ "vision_backbone.featurizer.blocks.7.mlp.fc1.bias": "model-00001-of-00003.safetensors",
604
+ "vision_backbone.featurizer.blocks.7.mlp.fc1.weight": "model-00001-of-00003.safetensors",
605
+ "vision_backbone.featurizer.blocks.7.mlp.fc2.bias": "model-00001-of-00003.safetensors",
606
+ "vision_backbone.featurizer.blocks.7.mlp.fc2.weight": "model-00001-of-00003.safetensors",
607
+ "vision_backbone.featurizer.blocks.7.norm1.bias": "model-00001-of-00003.safetensors",
608
+ "vision_backbone.featurizer.blocks.7.norm1.weight": "model-00001-of-00003.safetensors",
609
+ "vision_backbone.featurizer.blocks.7.norm2.bias": "model-00001-of-00003.safetensors",
610
+ "vision_backbone.featurizer.blocks.7.norm2.weight": "model-00001-of-00003.safetensors",
611
+ "vision_backbone.featurizer.blocks.8.attn.proj.bias": "model-00001-of-00003.safetensors",
612
+ "vision_backbone.featurizer.blocks.8.attn.proj.weight": "model-00001-of-00003.safetensors",
613
+ "vision_backbone.featurizer.blocks.8.attn.qkv.bias": "model-00001-of-00003.safetensors",
614
+ "vision_backbone.featurizer.blocks.8.attn.qkv.weight": "model-00001-of-00003.safetensors",
615
+ "vision_backbone.featurizer.blocks.8.ls1.scale_factor": "model-00001-of-00003.safetensors",
616
+ "vision_backbone.featurizer.blocks.8.ls2.scale_factor": "model-00001-of-00003.safetensors",
617
+ "vision_backbone.featurizer.blocks.8.mlp.fc1.bias": "model-00001-of-00003.safetensors",
618
+ "vision_backbone.featurizer.blocks.8.mlp.fc1.weight": "model-00001-of-00003.safetensors",
619
+ "vision_backbone.featurizer.blocks.8.mlp.fc2.bias": "model-00001-of-00003.safetensors",
620
+ "vision_backbone.featurizer.blocks.8.mlp.fc2.weight": "model-00001-of-00003.safetensors",
621
+ "vision_backbone.featurizer.blocks.8.norm1.bias": "model-00001-of-00003.safetensors",
622
+ "vision_backbone.featurizer.blocks.8.norm1.weight": "model-00001-of-00003.safetensors",
623
+ "vision_backbone.featurizer.blocks.8.norm2.bias": "model-00001-of-00003.safetensors",
624
+ "vision_backbone.featurizer.blocks.8.norm2.weight": "model-00001-of-00003.safetensors",
625
+ "vision_backbone.featurizer.blocks.9.attn.proj.bias": "model-00001-of-00003.safetensors",
626
+ "vision_backbone.featurizer.blocks.9.attn.proj.weight": "model-00001-of-00003.safetensors",
627
+ "vision_backbone.featurizer.blocks.9.attn.qkv.bias": "model-00001-of-00003.safetensors",
628
+ "vision_backbone.featurizer.blocks.9.attn.qkv.weight": "model-00001-of-00003.safetensors",
629
+ "vision_backbone.featurizer.blocks.9.ls1.scale_factor": "model-00001-of-00003.safetensors",
630
+ "vision_backbone.featurizer.blocks.9.ls2.scale_factor": "model-00001-of-00003.safetensors",
631
+ "vision_backbone.featurizer.blocks.9.mlp.fc1.bias": "model-00001-of-00003.safetensors",
632
+ "vision_backbone.featurizer.blocks.9.mlp.fc1.weight": "model-00001-of-00003.safetensors",
633
+ "vision_backbone.featurizer.blocks.9.mlp.fc2.bias": "model-00001-of-00003.safetensors",
634
+ "vision_backbone.featurizer.blocks.9.mlp.fc2.weight": "model-00001-of-00003.safetensors",
635
+ "vision_backbone.featurizer.blocks.9.norm1.bias": "model-00001-of-00003.safetensors",
636
+ "vision_backbone.featurizer.blocks.9.norm1.weight": "model-00001-of-00003.safetensors",
637
+ "vision_backbone.featurizer.blocks.9.norm2.bias": "model-00001-of-00003.safetensors",
638
+ "vision_backbone.featurizer.blocks.9.norm2.weight": "model-00001-of-00003.safetensors",
639
+ "vision_backbone.featurizer.cls_token": "model-00001-of-00003.safetensors",
640
+ "vision_backbone.featurizer.norm.bias": "model-00001-of-00003.safetensors",
641
+ "vision_backbone.featurizer.norm.weight": "model-00001-of-00003.safetensors",
642
+ "vision_backbone.featurizer.patch_embed.proj.bias": "model-00001-of-00003.safetensors",
643
+ "vision_backbone.featurizer.patch_embed.proj.weight": "model-00001-of-00003.safetensors",
644
+ "vision_backbone.featurizer.pos_embed": "model-00001-of-00003.safetensors",
645
+ "vision_backbone.featurizer.reg_token": "model-00001-of-00003.safetensors",
646
+ "vision_backbone.fused_featurizer.attn_pool.kv.bias": "model-00001-of-00003.safetensors",
647
+ "vision_backbone.fused_featurizer.attn_pool.kv.weight": "model-00001-of-00003.safetensors",
648
+ "vision_backbone.fused_featurizer.attn_pool.latent": "model-00001-of-00003.safetensors",
649
+ "vision_backbone.fused_featurizer.attn_pool.mlp.fc1.bias": "model-00001-of-00003.safetensors",
650
+ "vision_backbone.fused_featurizer.attn_pool.mlp.fc1.weight": "model-00001-of-00003.safetensors",
651
+ "vision_backbone.fused_featurizer.attn_pool.mlp.fc2.bias": "model-00001-of-00003.safetensors",
652
+ "vision_backbone.fused_featurizer.attn_pool.mlp.fc2.weight": "model-00001-of-00003.safetensors",
653
+ "vision_backbone.fused_featurizer.attn_pool.norm.bias": "model-00001-of-00003.safetensors",
654
+ "vision_backbone.fused_featurizer.attn_pool.norm.weight": "model-00001-of-00003.safetensors",
655
+ "vision_backbone.fused_featurizer.attn_pool.proj.bias": "model-00001-of-00003.safetensors",
656
+ "vision_backbone.fused_featurizer.attn_pool.proj.weight": "model-00001-of-00003.safetensors",
657
+ "vision_backbone.fused_featurizer.attn_pool.q.bias": "model-00001-of-00003.safetensors",
658
+ "vision_backbone.fused_featurizer.attn_pool.q.weight": "model-00001-of-00003.safetensors",
659
+ "vision_backbone.fused_featurizer.blocks.0.attn.proj.bias": "model-00001-of-00003.safetensors",
660
+ "vision_backbone.fused_featurizer.blocks.0.attn.proj.weight": "model-00001-of-00003.safetensors",
661
+ "vision_backbone.fused_featurizer.blocks.0.attn.qkv.bias": "model-00001-of-00003.safetensors",
662
+ "vision_backbone.fused_featurizer.blocks.0.attn.qkv.weight": "model-00001-of-00003.safetensors",
663
+ "vision_backbone.fused_featurizer.blocks.0.mlp.fc1.bias": "model-00001-of-00003.safetensors",
664
+ "vision_backbone.fused_featurizer.blocks.0.mlp.fc1.weight": "model-00001-of-00003.safetensors",
665
+ "vision_backbone.fused_featurizer.blocks.0.mlp.fc2.bias": "model-00001-of-00003.safetensors",
666
+ "vision_backbone.fused_featurizer.blocks.0.mlp.fc2.weight": "model-00001-of-00003.safetensors",
667
+ "vision_backbone.fused_featurizer.blocks.0.norm1.bias": "model-00001-of-00003.safetensors",
668
+ "vision_backbone.fused_featurizer.blocks.0.norm1.weight": "model-00001-of-00003.safetensors",
669
+ "vision_backbone.fused_featurizer.blocks.0.norm2.bias": "model-00001-of-00003.safetensors",
670
+ "vision_backbone.fused_featurizer.blocks.0.norm2.weight": "model-00001-of-00003.safetensors",
671
+ "vision_backbone.fused_featurizer.blocks.1.attn.proj.bias": "model-00001-of-00003.safetensors",
672
+ "vision_backbone.fused_featurizer.blocks.1.attn.proj.weight": "model-00001-of-00003.safetensors",
673
+ "vision_backbone.fused_featurizer.blocks.1.attn.qkv.bias": "model-00001-of-00003.safetensors",
674
+ "vision_backbone.fused_featurizer.blocks.1.attn.qkv.weight": "model-00001-of-00003.safetensors",
675
+ "vision_backbone.fused_featurizer.blocks.1.mlp.fc1.bias": "model-00001-of-00003.safetensors",
676
+ "vision_backbone.fused_featurizer.blocks.1.mlp.fc1.weight": "model-00001-of-00003.safetensors",
677
+ "vision_backbone.fused_featurizer.blocks.1.mlp.fc2.bias": "model-00001-of-00003.safetensors",
678
+ "vision_backbone.fused_featurizer.blocks.1.mlp.fc2.weight": "model-00001-of-00003.safetensors",
679
+ "vision_backbone.fused_featurizer.blocks.1.norm1.bias": "model-00001-of-00003.safetensors",
680
+ "vision_backbone.fused_featurizer.blocks.1.norm1.weight": "model-00001-of-00003.safetensors",
681
+ "vision_backbone.fused_featurizer.blocks.1.norm2.bias": "model-00001-of-00003.safetensors",
682
+ "vision_backbone.fused_featurizer.blocks.1.norm2.weight": "model-00001-of-00003.safetensors",
683
+ "vision_backbone.fused_featurizer.blocks.10.attn.proj.bias": "model-00001-of-00003.safetensors",
684
+ "vision_backbone.fused_featurizer.blocks.10.attn.proj.weight": "model-00001-of-00003.safetensors",
685
+ "vision_backbone.fused_featurizer.blocks.10.attn.qkv.bias": "model-00001-of-00003.safetensors",
686
+ "vision_backbone.fused_featurizer.blocks.10.attn.qkv.weight": "model-00001-of-00003.safetensors",
687
+ "vision_backbone.fused_featurizer.blocks.10.mlp.fc1.bias": "model-00001-of-00003.safetensors",
688
+ "vision_backbone.fused_featurizer.blocks.10.mlp.fc1.weight": "model-00001-of-00003.safetensors",
689
+ "vision_backbone.fused_featurizer.blocks.10.mlp.fc2.bias": "model-00001-of-00003.safetensors",
690
+ "vision_backbone.fused_featurizer.blocks.10.mlp.fc2.weight": "model-00001-of-00003.safetensors",
691
+ "vision_backbone.fused_featurizer.blocks.10.norm1.bias": "model-00001-of-00003.safetensors",
692
+ "vision_backbone.fused_featurizer.blocks.10.norm1.weight": "model-00001-of-00003.safetensors",
693
+ "vision_backbone.fused_featurizer.blocks.10.norm2.bias": "model-00001-of-00003.safetensors",
694
+ "vision_backbone.fused_featurizer.blocks.10.norm2.weight": "model-00001-of-00003.safetensors",
695
+ "vision_backbone.fused_featurizer.blocks.11.attn.proj.bias": "model-00001-of-00003.safetensors",
696
+ "vision_backbone.fused_featurizer.blocks.11.attn.proj.weight": "model-00001-of-00003.safetensors",
697
+ "vision_backbone.fused_featurizer.blocks.11.attn.qkv.bias": "model-00001-of-00003.safetensors",
698
+ "vision_backbone.fused_featurizer.blocks.11.attn.qkv.weight": "model-00001-of-00003.safetensors",
699
+ "vision_backbone.fused_featurizer.blocks.11.mlp.fc1.bias": "model-00001-of-00003.safetensors",
700
+ "vision_backbone.fused_featurizer.blocks.11.mlp.fc1.weight": "model-00001-of-00003.safetensors",
701
+ "vision_backbone.fused_featurizer.blocks.11.mlp.fc2.bias": "model-00001-of-00003.safetensors",
702
+ "vision_backbone.fused_featurizer.blocks.11.mlp.fc2.weight": "model-00001-of-00003.safetensors",
703
+ "vision_backbone.fused_featurizer.blocks.11.norm1.bias": "model-00001-of-00003.safetensors",
704
+ "vision_backbone.fused_featurizer.blocks.11.norm1.weight": "model-00001-of-00003.safetensors",
705
+ "vision_backbone.fused_featurizer.blocks.11.norm2.bias": "model-00001-of-00003.safetensors",
706
+ "vision_backbone.fused_featurizer.blocks.11.norm2.weight": "model-00001-of-00003.safetensors",
707
+ "vision_backbone.fused_featurizer.blocks.12.attn.proj.bias": "model-00001-of-00003.safetensors",
708
+ "vision_backbone.fused_featurizer.blocks.12.attn.proj.weight": "model-00001-of-00003.safetensors",
709
+ "vision_backbone.fused_featurizer.blocks.12.attn.qkv.bias": "model-00001-of-00003.safetensors",
710
+ "vision_backbone.fused_featurizer.blocks.12.attn.qkv.weight": "model-00001-of-00003.safetensors",
711
+ "vision_backbone.fused_featurizer.blocks.12.mlp.fc1.bias": "model-00001-of-00003.safetensors",
712
+ "vision_backbone.fused_featurizer.blocks.12.mlp.fc1.weight": "model-00001-of-00003.safetensors",
713
+ "vision_backbone.fused_featurizer.blocks.12.mlp.fc2.bias": "model-00001-of-00003.safetensors",
714
+ "vision_backbone.fused_featurizer.blocks.12.mlp.fc2.weight": "model-00001-of-00003.safetensors",
715
+ "vision_backbone.fused_featurizer.blocks.12.norm1.bias": "model-00001-of-00003.safetensors",
716
+ "vision_backbone.fused_featurizer.blocks.12.norm1.weight": "model-00001-of-00003.safetensors",
717
+ "vision_backbone.fused_featurizer.blocks.12.norm2.bias": "model-00001-of-00003.safetensors",
718
+ "vision_backbone.fused_featurizer.blocks.12.norm2.weight": "model-00001-of-00003.safetensors",
719
+ "vision_backbone.fused_featurizer.blocks.13.attn.proj.bias": "model-00001-of-00003.safetensors",
720
+ "vision_backbone.fused_featurizer.blocks.13.attn.proj.weight": "model-00001-of-00003.safetensors",
721
+ "vision_backbone.fused_featurizer.blocks.13.attn.qkv.bias": "model-00001-of-00003.safetensors",
722
+ "vision_backbone.fused_featurizer.blocks.13.attn.qkv.weight": "model-00001-of-00003.safetensors",
723
+ "vision_backbone.fused_featurizer.blocks.13.mlp.fc1.bias": "model-00001-of-00003.safetensors",
724
+ "vision_backbone.fused_featurizer.blocks.13.mlp.fc1.weight": "model-00001-of-00003.safetensors",
725
+ "vision_backbone.fused_featurizer.blocks.13.mlp.fc2.bias": "model-00001-of-00003.safetensors",
726
+ "vision_backbone.fused_featurizer.blocks.13.mlp.fc2.weight": "model-00001-of-00003.safetensors",
727
+ "vision_backbone.fused_featurizer.blocks.13.norm1.bias": "model-00001-of-00003.safetensors",
728
+ "vision_backbone.fused_featurizer.blocks.13.norm1.weight": "model-00001-of-00003.safetensors",
729
+ "vision_backbone.fused_featurizer.blocks.13.norm2.bias": "model-00001-of-00003.safetensors",
730
+ "vision_backbone.fused_featurizer.blocks.13.norm2.weight": "model-00001-of-00003.safetensors",
731
+ "vision_backbone.fused_featurizer.blocks.14.attn.proj.bias": "model-00001-of-00003.safetensors",
732
+ "vision_backbone.fused_featurizer.blocks.14.attn.proj.weight": "model-00001-of-00003.safetensors",
733
+ "vision_backbone.fused_featurizer.blocks.14.attn.qkv.bias": "model-00001-of-00003.safetensors",
734
+ "vision_backbone.fused_featurizer.blocks.14.attn.qkv.weight": "model-00001-of-00003.safetensors",
735
+ "vision_backbone.fused_featurizer.blocks.14.mlp.fc1.bias": "model-00001-of-00003.safetensors",
736
+ "vision_backbone.fused_featurizer.blocks.14.mlp.fc1.weight": "model-00001-of-00003.safetensors",
737
+ "vision_backbone.fused_featurizer.blocks.14.mlp.fc2.bias": "model-00001-of-00003.safetensors",
738
+ "vision_backbone.fused_featurizer.blocks.14.mlp.fc2.weight": "model-00001-of-00003.safetensors",
739
+ "vision_backbone.fused_featurizer.blocks.14.norm1.bias": "model-00001-of-00003.safetensors",
740
+ "vision_backbone.fused_featurizer.blocks.14.norm1.weight": "model-00001-of-00003.safetensors",
741
+ "vision_backbone.fused_featurizer.blocks.14.norm2.bias": "model-00001-of-00003.safetensors",
742
+ "vision_backbone.fused_featurizer.blocks.14.norm2.weight": "model-00001-of-00003.safetensors",
743
+ "vision_backbone.fused_featurizer.blocks.15.attn.proj.bias": "model-00001-of-00003.safetensors",
744
+ "vision_backbone.fused_featurizer.blocks.15.attn.proj.weight": "model-00001-of-00003.safetensors",
745
+ "vision_backbone.fused_featurizer.blocks.15.attn.qkv.bias": "model-00001-of-00003.safetensors",
746
+ "vision_backbone.fused_featurizer.blocks.15.attn.qkv.weight": "model-00001-of-00003.safetensors",
747
+ "vision_backbone.fused_featurizer.blocks.15.mlp.fc1.bias": "model-00001-of-00003.safetensors",
748
+ "vision_backbone.fused_featurizer.blocks.15.mlp.fc1.weight": "model-00001-of-00003.safetensors",
749
+ "vision_backbone.fused_featurizer.blocks.15.mlp.fc2.bias": "model-00001-of-00003.safetensors",
750
+ "vision_backbone.fused_featurizer.blocks.15.mlp.fc2.weight": "model-00001-of-00003.safetensors",
751
+ "vision_backbone.fused_featurizer.blocks.15.norm1.bias": "model-00001-of-00003.safetensors",
752
+ "vision_backbone.fused_featurizer.blocks.15.norm1.weight": "model-00001-of-00003.safetensors",
753
+ "vision_backbone.fused_featurizer.blocks.15.norm2.bias": "model-00001-of-00003.safetensors",
754
+ "vision_backbone.fused_featurizer.blocks.15.norm2.weight": "model-00001-of-00003.safetensors",
755
+ "vision_backbone.fused_featurizer.blocks.16.attn.proj.bias": "model-00001-of-00003.safetensors",
756
+ "vision_backbone.fused_featurizer.blocks.16.attn.proj.weight": "model-00001-of-00003.safetensors",
757
+ "vision_backbone.fused_featurizer.blocks.16.attn.qkv.bias": "model-00001-of-00003.safetensors",
758
+ "vision_backbone.fused_featurizer.blocks.16.attn.qkv.weight": "model-00001-of-00003.safetensors",
759
+ "vision_backbone.fused_featurizer.blocks.16.mlp.fc1.bias": "model-00001-of-00003.safetensors",
760
+ "vision_backbone.fused_featurizer.blocks.16.mlp.fc1.weight": "model-00001-of-00003.safetensors",
761
+ "vision_backbone.fused_featurizer.blocks.16.mlp.fc2.bias": "model-00001-of-00003.safetensors",
762
+ "vision_backbone.fused_featurizer.blocks.16.mlp.fc2.weight": "model-00001-of-00003.safetensors",
763
+ "vision_backbone.fused_featurizer.blocks.16.norm1.bias": "model-00001-of-00003.safetensors",
764
+ "vision_backbone.fused_featurizer.blocks.16.norm1.weight": "model-00001-of-00003.safetensors",
765
+ "vision_backbone.fused_featurizer.blocks.16.norm2.bias": "model-00001-of-00003.safetensors",
766
+ "vision_backbone.fused_featurizer.blocks.16.norm2.weight": "model-00001-of-00003.safetensors",
767
+ "vision_backbone.fused_featurizer.blocks.17.attn.proj.bias": "model-00001-of-00003.safetensors",
768
+ "vision_backbone.fused_featurizer.blocks.17.attn.proj.weight": "model-00001-of-00003.safetensors",
769
+ "vision_backbone.fused_featurizer.blocks.17.attn.qkv.bias": "model-00001-of-00003.safetensors",
770
+ "vision_backbone.fused_featurizer.blocks.17.attn.qkv.weight": "model-00001-of-00003.safetensors",
771
+ "vision_backbone.fused_featurizer.blocks.17.mlp.fc1.bias": "model-00001-of-00003.safetensors",
772
+ "vision_backbone.fused_featurizer.blocks.17.mlp.fc1.weight": "model-00001-of-00003.safetensors",
773
+ "vision_backbone.fused_featurizer.blocks.17.mlp.fc2.bias": "model-00001-of-00003.safetensors",
774
+ "vision_backbone.fused_featurizer.blocks.17.mlp.fc2.weight": "model-00001-of-00003.safetensors",
775
+ "vision_backbone.fused_featurizer.blocks.17.norm1.bias": "model-00001-of-00003.safetensors",
776
+ "vision_backbone.fused_featurizer.blocks.17.norm1.weight": "model-00001-of-00003.safetensors",
777
+ "vision_backbone.fused_featurizer.blocks.17.norm2.bias": "model-00001-of-00003.safetensors",
778
+ "vision_backbone.fused_featurizer.blocks.17.norm2.weight": "model-00001-of-00003.safetensors",
779
+ "vision_backbone.fused_featurizer.blocks.18.attn.proj.bias": "model-00001-of-00003.safetensors",
780
+ "vision_backbone.fused_featurizer.blocks.18.attn.proj.weight": "model-00001-of-00003.safetensors",
781
+ "vision_backbone.fused_featurizer.blocks.18.attn.qkv.bias": "model-00001-of-00003.safetensors",
782
+ "vision_backbone.fused_featurizer.blocks.18.attn.qkv.weight": "model-00001-of-00003.safetensors",
783
+ "vision_backbone.fused_featurizer.blocks.18.mlp.fc1.bias": "model-00001-of-00003.safetensors",
784
+ "vision_backbone.fused_featurizer.blocks.18.mlp.fc1.weight": "model-00001-of-00003.safetensors",
785
+ "vision_backbone.fused_featurizer.blocks.18.mlp.fc2.bias": "model-00001-of-00003.safetensors",
786
+ "vision_backbone.fused_featurizer.blocks.18.mlp.fc2.weight": "model-00001-of-00003.safetensors",
787
+ "vision_backbone.fused_featurizer.blocks.18.norm1.bias": "model-00001-of-00003.safetensors",
788
+ "vision_backbone.fused_featurizer.blocks.18.norm1.weight": "model-00001-of-00003.safetensors",
789
+ "vision_backbone.fused_featurizer.blocks.18.norm2.bias": "model-00001-of-00003.safetensors",
790
+ "vision_backbone.fused_featurizer.blocks.18.norm2.weight": "model-00001-of-00003.safetensors",
791
+ "vision_backbone.fused_featurizer.blocks.19.attn.proj.bias": "model-00001-of-00003.safetensors",
792
+ "vision_backbone.fused_featurizer.blocks.19.attn.proj.weight": "model-00001-of-00003.safetensors",
793
+ "vision_backbone.fused_featurizer.blocks.19.attn.qkv.bias": "model-00001-of-00003.safetensors",
794
+ "vision_backbone.fused_featurizer.blocks.19.attn.qkv.weight": "model-00001-of-00003.safetensors",
795
+ "vision_backbone.fused_featurizer.blocks.19.mlp.fc1.bias": "model-00001-of-00003.safetensors",
796
+ "vision_backbone.fused_featurizer.blocks.19.mlp.fc1.weight": "model-00001-of-00003.safetensors",
797
+ "vision_backbone.fused_featurizer.blocks.19.mlp.fc2.bias": "model-00001-of-00003.safetensors",
798
+ "vision_backbone.fused_featurizer.blocks.19.mlp.fc2.weight": "model-00001-of-00003.safetensors",
799
+ "vision_backbone.fused_featurizer.blocks.19.norm1.bias": "model-00001-of-00003.safetensors",
800
+ "vision_backbone.fused_featurizer.blocks.19.norm1.weight": "model-00001-of-00003.safetensors",
801
+ "vision_backbone.fused_featurizer.blocks.19.norm2.bias": "model-00001-of-00003.safetensors",
802
+ "vision_backbone.fused_featurizer.blocks.19.norm2.weight": "model-00001-of-00003.safetensors",
803
+ "vision_backbone.fused_featurizer.blocks.2.attn.proj.bias": "model-00001-of-00003.safetensors",
804
+ "vision_backbone.fused_featurizer.blocks.2.attn.proj.weight": "model-00001-of-00003.safetensors",
805
+ "vision_backbone.fused_featurizer.blocks.2.attn.qkv.bias": "model-00001-of-00003.safetensors",
806
+ "vision_backbone.fused_featurizer.blocks.2.attn.qkv.weight": "model-00001-of-00003.safetensors",
807
+ "vision_backbone.fused_featurizer.blocks.2.mlp.fc1.bias": "model-00001-of-00003.safetensors",
808
+ "vision_backbone.fused_featurizer.blocks.2.mlp.fc1.weight": "model-00001-of-00003.safetensors",
809
+ "vision_backbone.fused_featurizer.blocks.2.mlp.fc2.bias": "model-00001-of-00003.safetensors",
810
+ "vision_backbone.fused_featurizer.blocks.2.mlp.fc2.weight": "model-00001-of-00003.safetensors",
811
+ "vision_backbone.fused_featurizer.blocks.2.norm1.bias": "model-00001-of-00003.safetensors",
812
+ "vision_backbone.fused_featurizer.blocks.2.norm1.weight": "model-00001-of-00003.safetensors",
813
+ "vision_backbone.fused_featurizer.blocks.2.norm2.bias": "model-00001-of-00003.safetensors",
814
+ "vision_backbone.fused_featurizer.blocks.2.norm2.weight": "model-00001-of-00003.safetensors",
815
+ "vision_backbone.fused_featurizer.blocks.20.attn.proj.bias": "model-00001-of-00003.safetensors",
816
+ "vision_backbone.fused_featurizer.blocks.20.attn.proj.weight": "model-00001-of-00003.safetensors",
817
+ "vision_backbone.fused_featurizer.blocks.20.attn.qkv.bias": "model-00001-of-00003.safetensors",
818
+ "vision_backbone.fused_featurizer.blocks.20.attn.qkv.weight": "model-00001-of-00003.safetensors",
819
+ "vision_backbone.fused_featurizer.blocks.20.mlp.fc1.bias": "model-00001-of-00003.safetensors",
820
+ "vision_backbone.fused_featurizer.blocks.20.mlp.fc1.weight": "model-00001-of-00003.safetensors",
821
+ "vision_backbone.fused_featurizer.blocks.20.mlp.fc2.bias": "model-00001-of-00003.safetensors",
822
+ "vision_backbone.fused_featurizer.blocks.20.mlp.fc2.weight": "model-00001-of-00003.safetensors",
823
+ "vision_backbone.fused_featurizer.blocks.20.norm1.bias": "model-00001-of-00003.safetensors",
824
+ "vision_backbone.fused_featurizer.blocks.20.norm1.weight": "model-00001-of-00003.safetensors",
825
+ "vision_backbone.fused_featurizer.blocks.20.norm2.bias": "model-00001-of-00003.safetensors",
826
+ "vision_backbone.fused_featurizer.blocks.20.norm2.weight": "model-00001-of-00003.safetensors",
827
+ "vision_backbone.fused_featurizer.blocks.21.attn.proj.bias": "model-00001-of-00003.safetensors",
828
+ "vision_backbone.fused_featurizer.blocks.21.attn.proj.weight": "model-00001-of-00003.safetensors",
829
+ "vision_backbone.fused_featurizer.blocks.21.attn.qkv.bias": "model-00001-of-00003.safetensors",
830
+ "vision_backbone.fused_featurizer.blocks.21.attn.qkv.weight": "model-00001-of-00003.safetensors",
831
+ "vision_backbone.fused_featurizer.blocks.21.mlp.fc1.bias": "model-00001-of-00003.safetensors",
832
+ "vision_backbone.fused_featurizer.blocks.21.mlp.fc1.weight": "model-00001-of-00003.safetensors",
833
+ "vision_backbone.fused_featurizer.blocks.21.mlp.fc2.bias": "model-00001-of-00003.safetensors",
834
+ "vision_backbone.fused_featurizer.blocks.21.mlp.fc2.weight": "model-00001-of-00003.safetensors",
835
+ "vision_backbone.fused_featurizer.blocks.21.norm1.bias": "model-00001-of-00003.safetensors",
836
+ "vision_backbone.fused_featurizer.blocks.21.norm1.weight": "model-00001-of-00003.safetensors",
837
+ "vision_backbone.fused_featurizer.blocks.21.norm2.bias": "model-00001-of-00003.safetensors",
838
+ "vision_backbone.fused_featurizer.blocks.21.norm2.weight": "model-00001-of-00003.safetensors",
839
+ "vision_backbone.fused_featurizer.blocks.22.attn.proj.bias": "model-00001-of-00003.safetensors",
840
+ "vision_backbone.fused_featurizer.blocks.22.attn.proj.weight": "model-00001-of-00003.safetensors",
841
+ "vision_backbone.fused_featurizer.blocks.22.attn.qkv.bias": "model-00001-of-00003.safetensors",
842
+ "vision_backbone.fused_featurizer.blocks.22.attn.qkv.weight": "model-00001-of-00003.safetensors",
843
+ "vision_backbone.fused_featurizer.blocks.22.mlp.fc1.bias": "model-00001-of-00003.safetensors",
844
+ "vision_backbone.fused_featurizer.blocks.22.mlp.fc1.weight": "model-00001-of-00003.safetensors",
845
+ "vision_backbone.fused_featurizer.blocks.22.mlp.fc2.bias": "model-00001-of-00003.safetensors",
846
+ "vision_backbone.fused_featurizer.blocks.22.mlp.fc2.weight": "model-00001-of-00003.safetensors",
847
+ "vision_backbone.fused_featurizer.blocks.22.norm1.bias": "model-00001-of-00003.safetensors",
848
+ "vision_backbone.fused_featurizer.blocks.22.norm1.weight": "model-00001-of-00003.safetensors",
849
+ "vision_backbone.fused_featurizer.blocks.22.norm2.bias": "model-00001-of-00003.safetensors",
850
+ "vision_backbone.fused_featurizer.blocks.22.norm2.weight": "model-00001-of-00003.safetensors",
851
+ "vision_backbone.fused_featurizer.blocks.23.attn.proj.bias": "model-00001-of-00003.safetensors",
852
+ "vision_backbone.fused_featurizer.blocks.23.attn.proj.weight": "model-00001-of-00003.safetensors",
853
+ "vision_backbone.fused_featurizer.blocks.23.attn.qkv.bias": "model-00001-of-00003.safetensors",
854
+ "vision_backbone.fused_featurizer.blocks.23.attn.qkv.weight": "model-00001-of-00003.safetensors",
855
+ "vision_backbone.fused_featurizer.blocks.23.mlp.fc1.bias": "model-00001-of-00003.safetensors",
856
+ "vision_backbone.fused_featurizer.blocks.23.mlp.fc1.weight": "model-00001-of-00003.safetensors",
857
+ "vision_backbone.fused_featurizer.blocks.23.mlp.fc2.bias": "model-00001-of-00003.safetensors",
858
+ "vision_backbone.fused_featurizer.blocks.23.mlp.fc2.weight": "model-00001-of-00003.safetensors",
859
+ "vision_backbone.fused_featurizer.blocks.23.norm1.bias": "model-00001-of-00003.safetensors",
860
+ "vision_backbone.fused_featurizer.blocks.23.norm1.weight": "model-00001-of-00003.safetensors",
861
+ "vision_backbone.fused_featurizer.blocks.23.norm2.bias": "model-00001-of-00003.safetensors",
862
+ "vision_backbone.fused_featurizer.blocks.23.norm2.weight": "model-00001-of-00003.safetensors",
863
+ "vision_backbone.fused_featurizer.blocks.24.attn.proj.bias": "model-00001-of-00003.safetensors",
864
+ "vision_backbone.fused_featurizer.blocks.24.attn.proj.weight": "model-00001-of-00003.safetensors",
865
+ "vision_backbone.fused_featurizer.blocks.24.attn.qkv.bias": "model-00001-of-00003.safetensors",
866
+ "vision_backbone.fused_featurizer.blocks.24.attn.qkv.weight": "model-00001-of-00003.safetensors",
867
+ "vision_backbone.fused_featurizer.blocks.24.mlp.fc1.bias": "model-00001-of-00003.safetensors",
868
+ "vision_backbone.fused_featurizer.blocks.24.mlp.fc1.weight": "model-00001-of-00003.safetensors",
869
+ "vision_backbone.fused_featurizer.blocks.24.mlp.fc2.bias": "model-00001-of-00003.safetensors",
870
+ "vision_backbone.fused_featurizer.blocks.24.mlp.fc2.weight": "model-00001-of-00003.safetensors",
871
+ "vision_backbone.fused_featurizer.blocks.24.norm1.bias": "model-00001-of-00003.safetensors",
872
+ "vision_backbone.fused_featurizer.blocks.24.norm1.weight": "model-00001-of-00003.safetensors",
873
+ "vision_backbone.fused_featurizer.blocks.24.norm2.bias": "model-00001-of-00003.safetensors",
874
+ "vision_backbone.fused_featurizer.blocks.24.norm2.weight": "model-00001-of-00003.safetensors",
875
+ "vision_backbone.fused_featurizer.blocks.25.attn.proj.bias": "model-00001-of-00003.safetensors",
876
+ "vision_backbone.fused_featurizer.blocks.25.attn.proj.weight": "model-00001-of-00003.safetensors",
877
+ "vision_backbone.fused_featurizer.blocks.25.attn.qkv.bias": "model-00001-of-00003.safetensors",
878
+ "vision_backbone.fused_featurizer.blocks.25.attn.qkv.weight": "model-00001-of-00003.safetensors",
879
+ "vision_backbone.fused_featurizer.blocks.25.mlp.fc1.bias": "model-00001-of-00003.safetensors",
880
+ "vision_backbone.fused_featurizer.blocks.25.mlp.fc1.weight": "model-00001-of-00003.safetensors",
881
+ "vision_backbone.fused_featurizer.blocks.25.mlp.fc2.bias": "model-00001-of-00003.safetensors",
882
+ "vision_backbone.fused_featurizer.blocks.25.mlp.fc2.weight": "model-00001-of-00003.safetensors",
883
+ "vision_backbone.fused_featurizer.blocks.25.norm1.bias": "model-00001-of-00003.safetensors",
884
+ "vision_backbone.fused_featurizer.blocks.25.norm1.weight": "model-00001-of-00003.safetensors",
885
+ "vision_backbone.fused_featurizer.blocks.25.norm2.bias": "model-00001-of-00003.safetensors",
886
+ "vision_backbone.fused_featurizer.blocks.25.norm2.weight": "model-00001-of-00003.safetensors",
887
+ "vision_backbone.fused_featurizer.blocks.26.attn.proj.bias": "model-00001-of-00003.safetensors",
888
+ "vision_backbone.fused_featurizer.blocks.26.attn.proj.weight": "model-00001-of-00003.safetensors",
889
+ "vision_backbone.fused_featurizer.blocks.26.attn.qkv.bias": "model-00001-of-00003.safetensors",
890
+ "vision_backbone.fused_featurizer.blocks.26.attn.qkv.weight": "model-00001-of-00003.safetensors",
891
+ "vision_backbone.fused_featurizer.blocks.26.mlp.fc1.bias": "model-00001-of-00003.safetensors",
892
+ "vision_backbone.fused_featurizer.blocks.26.mlp.fc1.weight": "model-00001-of-00003.safetensors",
893
+ "vision_backbone.fused_featurizer.blocks.26.mlp.fc2.bias": "model-00001-of-00003.safetensors",
894
+ "vision_backbone.fused_featurizer.blocks.26.mlp.fc2.weight": "model-00001-of-00003.safetensors",
895
+ "vision_backbone.fused_featurizer.blocks.26.norm1.bias": "model-00001-of-00003.safetensors",
896
+ "vision_backbone.fused_featurizer.blocks.26.norm1.weight": "model-00001-of-00003.safetensors",
897
+ "vision_backbone.fused_featurizer.blocks.26.norm2.bias": "model-00001-of-00003.safetensors",
898
+ "vision_backbone.fused_featurizer.blocks.26.norm2.weight": "model-00001-of-00003.safetensors",
899
+ "vision_backbone.fused_featurizer.blocks.3.attn.proj.bias": "model-00001-of-00003.safetensors",
900
+ "vision_backbone.fused_featurizer.blocks.3.attn.proj.weight": "model-00001-of-00003.safetensors",
901
+ "vision_backbone.fused_featurizer.blocks.3.attn.qkv.bias": "model-00001-of-00003.safetensors",
902
+ "vision_backbone.fused_featurizer.blocks.3.attn.qkv.weight": "model-00001-of-00003.safetensors",
903
+ "vision_backbone.fused_featurizer.blocks.3.mlp.fc1.bias": "model-00001-of-00003.safetensors",
904
+ "vision_backbone.fused_featurizer.blocks.3.mlp.fc1.weight": "model-00001-of-00003.safetensors",
905
+ "vision_backbone.fused_featurizer.blocks.3.mlp.fc2.bias": "model-00001-of-00003.safetensors",
906
+ "vision_backbone.fused_featurizer.blocks.3.mlp.fc2.weight": "model-00001-of-00003.safetensors",
907
+ "vision_backbone.fused_featurizer.blocks.3.norm1.bias": "model-00001-of-00003.safetensors",
908
+ "vision_backbone.fused_featurizer.blocks.3.norm1.weight": "model-00001-of-00003.safetensors",
909
+ "vision_backbone.fused_featurizer.blocks.3.norm2.bias": "model-00001-of-00003.safetensors",
910
+ "vision_backbone.fused_featurizer.blocks.3.norm2.weight": "model-00001-of-00003.safetensors",
911
+ "vision_backbone.fused_featurizer.blocks.4.attn.proj.bias": "model-00001-of-00003.safetensors",
912
+ "vision_backbone.fused_featurizer.blocks.4.attn.proj.weight": "model-00001-of-00003.safetensors",
913
+ "vision_backbone.fused_featurizer.blocks.4.attn.qkv.bias": "model-00001-of-00003.safetensors",
914
+ "vision_backbone.fused_featurizer.blocks.4.attn.qkv.weight": "model-00001-of-00003.safetensors",
915
+ "vision_backbone.fused_featurizer.blocks.4.mlp.fc1.bias": "model-00001-of-00003.safetensors",
916
+ "vision_backbone.fused_featurizer.blocks.4.mlp.fc1.weight": "model-00001-of-00003.safetensors",
917
+ "vision_backbone.fused_featurizer.blocks.4.mlp.fc2.bias": "model-00001-of-00003.safetensors",
918
+ "vision_backbone.fused_featurizer.blocks.4.mlp.fc2.weight": "model-00001-of-00003.safetensors",
919
+ "vision_backbone.fused_featurizer.blocks.4.norm1.bias": "model-00001-of-00003.safetensors",
920
+ "vision_backbone.fused_featurizer.blocks.4.norm1.weight": "model-00001-of-00003.safetensors",
921
+ "vision_backbone.fused_featurizer.blocks.4.norm2.bias": "model-00001-of-00003.safetensors",
922
+ "vision_backbone.fused_featurizer.blocks.4.norm2.weight": "model-00001-of-00003.safetensors",
923
+ "vision_backbone.fused_featurizer.blocks.5.attn.proj.bias": "model-00001-of-00003.safetensors",
924
+ "vision_backbone.fused_featurizer.blocks.5.attn.proj.weight": "model-00001-of-00003.safetensors",
925
+ "vision_backbone.fused_featurizer.blocks.5.attn.qkv.bias": "model-00001-of-00003.safetensors",
926
+ "vision_backbone.fused_featurizer.blocks.5.attn.qkv.weight": "model-00001-of-00003.safetensors",
927
+ "vision_backbone.fused_featurizer.blocks.5.mlp.fc1.bias": "model-00001-of-00003.safetensors",
928
+ "vision_backbone.fused_featurizer.blocks.5.mlp.fc1.weight": "model-00001-of-00003.safetensors",
929
+ "vision_backbone.fused_featurizer.blocks.5.mlp.fc2.bias": "model-00001-of-00003.safetensors",
930
+ "vision_backbone.fused_featurizer.blocks.5.mlp.fc2.weight": "model-00001-of-00003.safetensors",
931
+ "vision_backbone.fused_featurizer.blocks.5.norm1.bias": "model-00001-of-00003.safetensors",
932
+ "vision_backbone.fused_featurizer.blocks.5.norm1.weight": "model-00001-of-00003.safetensors",
933
+ "vision_backbone.fused_featurizer.blocks.5.norm2.bias": "model-00001-of-00003.safetensors",
934
+ "vision_backbone.fused_featurizer.blocks.5.norm2.weight": "model-00001-of-00003.safetensors",
935
+ "vision_backbone.fused_featurizer.blocks.6.attn.proj.bias": "model-00001-of-00003.safetensors",
936
+ "vision_backbone.fused_featurizer.blocks.6.attn.proj.weight": "model-00001-of-00003.safetensors",
937
+ "vision_backbone.fused_featurizer.blocks.6.attn.qkv.bias": "model-00001-of-00003.safetensors",
938
+ "vision_backbone.fused_featurizer.blocks.6.attn.qkv.weight": "model-00001-of-00003.safetensors",
939
+ "vision_backbone.fused_featurizer.blocks.6.mlp.fc1.bias": "model-00001-of-00003.safetensors",
940
+ "vision_backbone.fused_featurizer.blocks.6.mlp.fc1.weight": "model-00001-of-00003.safetensors",
941
+ "vision_backbone.fused_featurizer.blocks.6.mlp.fc2.bias": "model-00001-of-00003.safetensors",
942
+ "vision_backbone.fused_featurizer.blocks.6.mlp.fc2.weight": "model-00001-of-00003.safetensors",
943
+ "vision_backbone.fused_featurizer.blocks.6.norm1.bias": "model-00001-of-00003.safetensors",
944
+ "vision_backbone.fused_featurizer.blocks.6.norm1.weight": "model-00001-of-00003.safetensors",
945
+ "vision_backbone.fused_featurizer.blocks.6.norm2.bias": "model-00001-of-00003.safetensors",
946
+ "vision_backbone.fused_featurizer.blocks.6.norm2.weight": "model-00001-of-00003.safetensors",
947
+ "vision_backbone.fused_featurizer.blocks.7.attn.proj.bias": "model-00001-of-00003.safetensors",
948
+ "vision_backbone.fused_featurizer.blocks.7.attn.proj.weight": "model-00001-of-00003.safetensors",
949
+ "vision_backbone.fused_featurizer.blocks.7.attn.qkv.bias": "model-00001-of-00003.safetensors",
950
+ "vision_backbone.fused_featurizer.blocks.7.attn.qkv.weight": "model-00001-of-00003.safetensors",
951
+ "vision_backbone.fused_featurizer.blocks.7.mlp.fc1.bias": "model-00001-of-00003.safetensors",
952
+ "vision_backbone.fused_featurizer.blocks.7.mlp.fc1.weight": "model-00001-of-00003.safetensors",
953
+ "vision_backbone.fused_featurizer.blocks.7.mlp.fc2.bias": "model-00001-of-00003.safetensors",
954
+ "vision_backbone.fused_featurizer.blocks.7.mlp.fc2.weight": "model-00001-of-00003.safetensors",
955
+ "vision_backbone.fused_featurizer.blocks.7.norm1.bias": "model-00001-of-00003.safetensors",
956
+ "vision_backbone.fused_featurizer.blocks.7.norm1.weight": "model-00001-of-00003.safetensors",
957
+ "vision_backbone.fused_featurizer.blocks.7.norm2.bias": "model-00001-of-00003.safetensors",
958
+ "vision_backbone.fused_featurizer.blocks.7.norm2.weight": "model-00001-of-00003.safetensors",
959
+ "vision_backbone.fused_featurizer.blocks.8.attn.proj.bias": "model-00001-of-00003.safetensors",
960
+ "vision_backbone.fused_featurizer.blocks.8.attn.proj.weight": "model-00001-of-00003.safetensors",
961
+ "vision_backbone.fused_featurizer.blocks.8.attn.qkv.bias": "model-00001-of-00003.safetensors",
962
+ "vision_backbone.fused_featurizer.blocks.8.attn.qkv.weight": "model-00001-of-00003.safetensors",
963
+ "vision_backbone.fused_featurizer.blocks.8.mlp.fc1.bias": "model-00001-of-00003.safetensors",
964
+ "vision_backbone.fused_featurizer.blocks.8.mlp.fc1.weight": "model-00001-of-00003.safetensors",
965
+ "vision_backbone.fused_featurizer.blocks.8.mlp.fc2.bias": "model-00001-of-00003.safetensors",
966
+ "vision_backbone.fused_featurizer.blocks.8.mlp.fc2.weight": "model-00001-of-00003.safetensors",
967
+ "vision_backbone.fused_featurizer.blocks.8.norm1.bias": "model-00001-of-00003.safetensors",
968
+ "vision_backbone.fused_featurizer.blocks.8.norm1.weight": "model-00001-of-00003.safetensors",
969
+ "vision_backbone.fused_featurizer.blocks.8.norm2.bias": "model-00001-of-00003.safetensors",
970
+ "vision_backbone.fused_featurizer.blocks.8.norm2.weight": "model-00001-of-00003.safetensors",
971
+ "vision_backbone.fused_featurizer.blocks.9.attn.proj.bias": "model-00001-of-00003.safetensors",
972
+ "vision_backbone.fused_featurizer.blocks.9.attn.proj.weight": "model-00001-of-00003.safetensors",
973
+ "vision_backbone.fused_featurizer.blocks.9.attn.qkv.bias": "model-00001-of-00003.safetensors",
974
+ "vision_backbone.fused_featurizer.blocks.9.attn.qkv.weight": "model-00001-of-00003.safetensors",
975
+ "vision_backbone.fused_featurizer.blocks.9.mlp.fc1.bias": "model-00001-of-00003.safetensors",
976
+ "vision_backbone.fused_featurizer.blocks.9.mlp.fc1.weight": "model-00001-of-00003.safetensors",
977
+ "vision_backbone.fused_featurizer.blocks.9.mlp.fc2.bias": "model-00001-of-00003.safetensors",
978
+ "vision_backbone.fused_featurizer.blocks.9.mlp.fc2.weight": "model-00001-of-00003.safetensors",
979
+ "vision_backbone.fused_featurizer.blocks.9.norm1.bias": "model-00001-of-00003.safetensors",
980
+ "vision_backbone.fused_featurizer.blocks.9.norm1.weight": "model-00001-of-00003.safetensors",
981
+ "vision_backbone.fused_featurizer.blocks.9.norm2.bias": "model-00001-of-00003.safetensors",
982
+ "vision_backbone.fused_featurizer.blocks.9.norm2.weight": "model-00001-of-00003.safetensors",
983
+ "vision_backbone.fused_featurizer.norm.bias": "model-00001-of-00003.safetensors",
984
+ "vision_backbone.fused_featurizer.norm.weight": "model-00001-of-00003.safetensors",
985
+ "vision_backbone.fused_featurizer.patch_embed.proj.bias": "model-00001-of-00003.safetensors",
986
+ "vision_backbone.fused_featurizer.patch_embed.proj.weight": "model-00001-of-00003.safetensors",
987
+ "vision_backbone.fused_featurizer.pos_embed": "model-00001-of-00003.safetensors"
988
+ }
989
+ }
openvla-7b/modeling_prismatic.py ADDED
@@ -0,0 +1,1086 @@
1
+ """
2
+ modeling_prismatic.py
3
+
4
+ Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.
5
+ Inherits from the default `transformers.PreTrainedModel`. Meant to be standalone and self-contained,
6
+ but exactly replicates the logic in `prismatic.models.vlms.prismatic.py`.
7
+ """
8
+
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from functools import partial
12
+ from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
13
+
14
+ import numpy as np
15
+ import timm
16
+ import tokenizers
17
+ import torch
18
+ import torch.nn as nn
19
+ import transformers
20
+ from timm.models.vision_transformer import LayerScale
21
+ from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import ModelOutput
23
+
24
+ from prismatic.training.train_utils import (
25
+ get_current_action_mask,
26
+ get_next_actions_mask,
27
+ )
28
+ from prismatic.vla.constants import (
29
+ ACTION_DIM,
30
+ ACTION_PROPRIO_NORMALIZATION_TYPE,
31
+ ACTION_TOKEN_BEGIN_IDX,
32
+ IGNORE_INDEX,
33
+ NUM_ACTIONS_CHUNK,
34
+ STOP_INDEX,
35
+ NormalizationType,
36
+ )
37
+
38
+ from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
39
+
40
+ # Set up logger
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ # === Utility Functions for Monkey-Patching ===
45
+ def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
46
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
47
+ result = fn(*args, **kwargs)
48
+ return result[0] if isinstance(result, tuple) else result
49
+
50
+ return wrapper
51
+
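+ # E.g., timm's `get_intermediate_layers(x, n={k})` returns a 1-tuple of feature tensors; wrapping it
+ # with `unpack_tuple` lets the monkey-patched `featurizer.forward` below return the tensor directly.
+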
52
+
53
+ # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
54
+ # =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
55
+ # =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
56
+ def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
57
+ return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
58
+
59
+
60
+ def ls_apply_patch(ls_module: LayerScale):
61
+ ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
62
+ ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
63
+ del ls_module.gamma
64
+
65
+
66
+ # === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
67
+ class PrismaticVisionBackbone(nn.Module):
68
+ """
69
+ Vision backbone for Prismatic models that handles image feature extraction.
70
+
71
+ Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.
72
+ For fused backbones, features from both models are concatenated along the feature dimension.
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ use_fused_vision_backbone: bool,
78
+ image_sizes: List[int],
79
+ timm_model_ids: List[str],
80
+ timm_override_act_layers: List[Optional[str]],
81
+ ) -> None:
82
+ """
83
+ Initialize the vision backbone.
84
+
85
+ Args:
86
+ use_fused_vision_backbone: Whether to use two backbones and fuse their features
87
+ image_sizes: List of image sizes for each backbone
88
+ timm_model_ids: List of TIMM model IDs to use for each backbone
89
+ timm_override_act_layers: List of activation layer overrides for each backbone
90
+ """
91
+ super().__init__()
92
+ self.use_fused_vision_backbone = use_fused_vision_backbone
93
+ self.num_images_in_input = 1 # Default value, can be overridden later
94
+
95
+ # Validate number of (fused) vision backbones
96
+ if len(timm_model_ids) > 2:
97
+ raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")
98
+
99
+ # Create primary featurizer
100
+ self.featurizer = self._create_featurizer(
101
+ model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]
102
+ )
103
+ self.embed_dim = self.featurizer.embed_dim
104
+
105
+ # Create secondary featurizer if using fused backbone
106
+ if self.use_fused_vision_backbone:
107
+ self.fused_featurizer = self._create_featurizer(
108
+ model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]
109
+ )
110
+ self.embed_dim += self.fused_featurizer.embed_dim
111
+
112
+ # Patch LayerScale modules for HF compatibility
113
+ self._patch_layer_scales()
114
+
115
+ def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
116
+ """
117
+ Create a TIMM-based featurizer model with appropriate configurations.
118
+
119
+ Args:
120
+ model_id: The TIMM model ID to load
121
+ img_size: Input image size for the model
122
+ act_layer: Override for the activation layer type
123
+
124
+ Returns:
125
+ A configured featurizer model
126
+ """
127
+ featurizer = timm.create_model(
128
+ model_id,
129
+ pretrained=False,
130
+ num_classes=0,
131
+ img_size=img_size,
132
+ act_layer=act_layer,
133
+ )
134
+
135
+ # Monkey-patch the forward function to extract the second-to-last layer features
136
+ num_blocks = len(featurizer.blocks)
137
+ featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))
138
+
139
+ return featurizer
140
+
141
+ def _patch_layer_scales(self) -> None:
142
+ """
143
+ Patch all LayerScale modules to be compatible with HF's parameter naming.
144
+
145
+ HF Transformers overwrites parameters with names containing 'gamma',
146
+ so we need to rename and modify the forward method.
147
+ """
148
+ # Patch primary featurizer
149
+ for module in self.featurizer.modules():
150
+ if isinstance(module, LayerScale):
151
+ ls_apply_patch(module)
152
+
153
+ # Patch secondary featurizer if it exists
154
+ if self.use_fused_vision_backbone:
155
+ for module in self.fused_featurizer.modules():
156
+ if isinstance(module, LayerScale):
157
+ ls_apply_patch(module)
158
+
159
+ def get_num_patches(self) -> int:
160
+ """
161
+ Returns the number of vision patches output by the vision backbone.
162
+
163
+ Returns:
164
+ Number of patches per image
165
+ """
166
+ return self.featurizer.patch_embed.num_patches
167
+
168
+ def get_num_images_in_input(self) -> int:
169
+ """
170
+ Returns the number of input images for the vision backbone.
171
+
172
+ Returns:
173
+ Number of images expected in the input
174
+ """
175
+ return self.num_images_in_input
176
+
177
+ def set_num_images_in_input(self, num_images_in_input: int) -> None:
178
+ """
179
+ Sets the number of input images for the vision backbone.
180
+
181
+ Args:
182
+ num_images_in_input: Number of images to expect in the input
183
+ """
184
+ self.num_images_in_input = num_images_in_input
185
+
186
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
187
+ """
188
+ Implements the forward pass for the vision backbone.
189
+
190
+ If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features
191
+ (otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).
192
+
193
+ Args:
194
+ pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W).
195
+ """
196
+ if self.num_images_in_input == 1:
197
+ if not self.use_fused_vision_backbone:
198
+ return self.featurizer(pixel_values)
199
+
200
+ # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
201
+ img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
202
+ patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)
203
+
204
+ return torch.cat([patches, patches_fused], dim=2)
205
+
206
+ else:
207
+ assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!"
208
+
209
+ # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2)
210
+ images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)
211
+
212
+ # Process each image and collect patches
213
+ all_patches = []
214
+ for img in images:
215
+ # Split each image further into two stacks of channels (each with 3 channels)
216
+ img_regular, img_fused = torch.split(img, [3, 3], dim=1)
217
+
218
+ # Get patches from both SigLIP and DINOv2 vision transformers
219
+ patches = self.featurizer(img_regular)
220
+ patches_fused = self.fused_featurizer(img_fused)
221
+
222
+ # Concatenate SigLIP and DINOv2 patches along the hidden dimension
223
+ combined_patches = torch.cat([patches, patches_fused], dim=2)
224
+ all_patches.append(combined_patches)
225
+
226
+ # Concatenate all patches along the patch dimension
227
+ return torch.cat(all_patches, dim=1)
228
+
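+ # Shape sketch (224x224 inputs with 14x14 patches, i.e., 256 patches per image, is an assumption --
+ # the actual sizes come from `image_sizes`): with a fused backbone and num_images_in_input == 2,
+ # `pixel_values` is (B, 12, 224, 224) and the output is (B, 2 * 256, siglip_dim + dinov2_dim).
+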
229
+
230
+ # === Prismatic Projector (nn.Module) Definitions ===
231
+ class PrismaticProjector(nn.Module):
232
+ def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
233
+ super().__init__()
234
+ self.use_fused_vision_backbone = use_fused_vision_backbone
235
+ self.vision_dim, self.llm_dim = vision_dim, llm_dim
236
+
237
+ # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
238
+ if not self.use_fused_vision_backbone:
239
+ self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
240
+ self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
241
+ self.act_fn1 = nn.GELU()
242
+ else:
243
+ initial_projection_dim = 4 * vision_dim
244
+ self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
245
+ self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
246
+ self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
247
+ self.act_fn1 = nn.GELU()
248
+ self.act_fn2 = nn.GELU()
249
+
250
+ def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
251
+ if not self.use_fused_vision_backbone:
252
+ projected_features = self.fc1(img_patches)
253
+ projected_features = self.act_fn1(projected_features)
254
+ projected_features = self.fc2(projected_features)
255
+ else:
256
+ projected_features = self.fc1(img_patches)
257
+ projected_features = self.act_fn1(projected_features)
258
+ projected_features = self.fc2(projected_features)
259
+ projected_features = self.act_fn2(projected_features)
260
+ projected_features = self.fc3(projected_features)
261
+
262
+ return projected_features
263
+
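+ # Illustrative dims (assuming the default SigLIP-So400m + DINOv2-L fusion; not read from any config):
+ # vision_dim = 1152 + 1024 = 2176 and llm_dim = 4096 give fc1: 2176 -> 8704, fc2: 8704 -> 4096, and
+ # fc3: 4096 -> 4096 -- a 3-layer GELU MLP versus the 2-layer projector in the single-backbone case.
+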
264
+
265
+ # === Main HF Class Definitions ===
266
+ @dataclass
267
+ class PrismaticCausalLMOutputWithPast(ModelOutput):
268
+ """Base class for Prismatic casual (visually-conditioned) language model outputs; also exposes visual features."""
269
+
270
+ loss: Optional[torch.FloatTensor] = None
271
+ logits: torch.FloatTensor = None
272
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
273
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
274
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
275
+
276
+ # Additions for VLMs
277
+ projector_features: Optional[torch.FloatTensor] = None
278
+
279
+
280
+ class PrismaticPreTrainedModel(PreTrainedModel):
281
+ config_class: PretrainedConfig = PrismaticConfig
282
+ base_model_prefix: str = "model"
283
+ supports_gradient_checkpointing: bool = True
284
+
285
+ _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
286
+ _skip_keys_device_placement: str = "past_key_values"
287
+ _supports_flash_attn_2: bool = True
288
+
289
+ def _init_weights(self, module: nn.Module) -> None:
290
+ # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
291
+ # => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
292
+ # https://github.com/TRI-ML/prismatic-vlms
293
+ std = (
294
+ self.config.initializer_range
295
+ if hasattr(self.config, "initializer_range")
296
+ else self.config.text_config.initializer_range
297
+ )
298
+
299
+ if hasattr(module, "class_embedding"):
300
+ module.class_embedding.data.normal_(mean=0.0, std=std)
301
+
302
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
303
+ module.weight.data.normal_(mean=0.0, std=std)
304
+ if module.bias is not None:
305
+ module.bias.data.zero_()
306
+ elif isinstance(module, nn.Embedding):
307
+ module.weight.data.normal_(mean=0.0, std=std)
308
+ if module.padding_idx is not None:
309
+ module.weight.data[module.padding_idx].zero_()
310
+
311
+ @property
312
+ def _supports_sdpa(self) -> bool:
313
+ """Check LLM supports SDPA Attention"""
314
+ return self.language_model._supports_sdpa
315
+
316
+
317
+ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
318
+ def __init__(self, config: PrismaticConfig) -> None:
319
+ super().__init__(config)
320
+
321
+ # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
322
+ if config.use_fused_vision_backbone is None:
323
+ raise ValueError("Missing config field `use_fused_vision_backbone`")
324
+
325
+ if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
326
+ raise NotImplementedError(
327
+ "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
328
+ "if you urgently need support for latest TIMM versions."
329
+ )
330
+
331
+ if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
332
+ logger.warning(
333
+ f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
334
+ f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
335
+ f"there might be inference-time regressions due to dependency changes. If in doubt, please"
336
+ f"use the above versions."
337
+ )
338
+
339
+ # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
340
+ self.vision_backbone = PrismaticVisionBackbone(
341
+ config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
342
+ )
343
+
344
+ # Create Multimodal Projector
345
+ self.projector = PrismaticProjector(
346
+ config.use_fused_vision_backbone,
347
+ vision_dim=self.vision_backbone.embed_dim,
348
+ llm_dim=config.text_config.hidden_size,
349
+ )
350
+
351
+ # Instantiate LLM Backbone
352
+ self.language_model = AutoModelForCausalLM.from_config(
353
+ config.text_config, attn_implementation=config._attn_implementation
354
+ )
355
+ self.vocab_size = config.text_config.vocab_size
356
+ self.pad_token_id = config.pad_token_id
357
+ self.llm_dim = config.text_config.hidden_size
358
+
359
+ # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
360
+ self.post_init()
361
+
362
+ # === `PreTrainedModel` Boilerplate ===
363
+ def get_input_embeddings(self) -> nn.Module:
364
+ return self.language_model.get_input_embeddings()
365
+
366
+ def set_input_embeddings(self, value: nn.Module) -> None:
367
+ self.language_model.set_input_embeddings(value)
368
+
369
+ def get_output_embeddings(self) -> nn.Module:
370
+ return self.language_model.get_output_embeddings()
371
+
372
+ def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
373
+ self.language_model.set_output_embeddings(new_embeddings)
374
+
375
+ def get_decoder(self) -> nn.Module:
376
+ return self.language_model.get_decoder()
377
+
378
+ def set_decoder(self, decoder: nn.Module) -> None:
379
+ self.language_model.set_decoder(decoder)
380
+
381
+ def tie_weights(self) -> None:
382
+ self.language_model.tie_weights() # Note: `Llama-2` and `Mistral` don't tie weights (no-op)
383
+
384
+ def resize_token_embeddings(
385
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
386
+ ) -> nn.Embedding:
387
+ updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
388
+
389
+ # Update config/instance variables
390
+ self.config.text_config.vocab_size = updated_embeddings.num_embeddings
391
+ self.vocab_size = updated_embeddings.num_embeddings
392
+
393
+ return updated_embeddings
394
+
395
+ def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):
396
+ """
397
+ Replace embeddings in input_embeddings at positions where all_actions_mask is True
398
+ with embeddings from noisy_action_features, using vectorized operations.
399
+
400
+ Args:
401
+ input_embeddings: Tensor of shape (B, S, D)
402
+ all_actions_mask: Boolean tensor of shape (B, S)
403
+ noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample
404
+
405
+ Returns:
406
+ Modified input_embeddings tensor
407
+ """
408
+ # Clone input to avoid modifying the original tensor
409
+ new_input_embeddings = input_embeddings.clone()
410
+
411
+ # Create a tensor with the same shape of input_embeddings to hold the noisy action features
412
+ repositioned_noisy_action_features = torch.zeros_like(input_embeddings)
413
+
414
+ # Create batch indices for splicing
415
+ batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)
416
+ batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])
417
+
418
+ # Get indices where mask is True for each sample
419
+ masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])
420
+
421
+ # Move the noisy action features into their correct positions
422
+ repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features
423
+
424
+ # Combine original input embeddings and noisy action embeddings using the mask
425
+ new_input_embeddings = torch.where(
426
+ all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings
427
+ )
428
+
429
+ return new_input_embeddings
430
+
431
+ def _process_action_masks(self, labels):
432
+ """Helper to get action masks from labels"""
433
+ current_action_mask = get_current_action_mask(labels)
434
+ next_actions_mask = get_next_actions_mask(labels)
435
+ all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len)
436
+ return all_actions_mask
437
+
438
+ def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False):
439
+ """Process vision features with optional FiLM conditioning"""
440
+ if use_film:
441
+ # FiLM: Infuse language inputs into visual features
442
+ patch_features = self.vision_backbone(pixel_values, language_embeddings) # (bsz, 256 * num_images, D)
443
+ else:
444
+ patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D)
445
+
446
+ # Project patch embeddings into language embedding space
447
+ return self.projector(patch_features)
448
+
449
+ def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):
450
+ """Process proprioceptive features and append to vision features"""
451
+ if proprio_projector is not None and proprio is not None:
452
+ # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim)
453
+ # proprio: (bsz, proprio_dim) or (proprio_dim,)
454
+ proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1) # (bsz, proprio_dim)
455
+ proprio_features = proprio_projector(proprio) # (bsz, llm_dim)
456
+ proprio_features = proprio_features.unsqueeze(dim=1) # (bsz, 1, llm_dim)
457
+ # For simplicity, just append proprio token to the end of projected vision patch tokens
458
+ return torch.cat((projected_patch_embeddings, proprio_features), dim=1)
459
+ return projected_patch_embeddings
460
+
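+ # Bookkeeping note: the projected proprio state occupies exactly one extra token appended after the
+ # vision patches; `predict_action` below accounts for this by incrementing NUM_PATCHES by 1 (and by
+ # one more for the diffusion timestep embedding, when present).
+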
461
+ def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
462
+ """Build multimodal embeddings and attention mask"""
463
+ # Update attention mask
464
+ projected_patch_attention_mask = None
465
+ if attention_mask is not None:
466
+ projected_patch_attention_mask = torch.full(
467
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
468
+ fill_value=True,
469
+ dtype=attention_mask.dtype,
470
+ device=attention_mask.device,
471
+ )
472
+
473
+ # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
474
+ multimodal_embeddings = torch.cat(
475
+ [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
476
+ )
477
+
478
+ multimodal_attention_mask = None
479
+ if attention_mask is not None:
480
+ multimodal_attention_mask = torch.cat(
481
+ [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
482
+ )
483
+
484
+ return multimodal_embeddings, multimodal_attention_mask
485
+
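+ # Resulting layout: [<BOS> | patch embeddings (+ any appended proprio / timestep tokens) | remaining
+ # text tokens], with a block of all-True attention entries inserted for the patch positions.
+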
486
+ def _build_multimodal_labels(self, labels, projected_patch_embeddings):
487
+ """Build multimodal labels with IGNORE_INDEX for patch embeddings"""
488
+ if labels is not None:
489
+ projected_patch_labels = torch.full(
490
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
491
+ fill_value=IGNORE_INDEX,
492
+ dtype=labels.dtype,
493
+ device=labels.device,
494
+ )
495
+ return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
496
+ return None
497
+
498
+ # === Core Prismatic VLM `forward()` Logic ===
499
+ def forward(
500
+ self,
501
+ input_ids: Optional[torch.LongTensor] = None,
502
+ attention_mask: Optional[torch.Tensor] = None,
503
+ pixel_values: Optional[torch.FloatTensor] = None,
504
+ labels: Optional[torch.LongTensor] = None,
505
+ inputs_embeds: Optional[torch.FloatTensor] = None,
506
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
507
+ use_cache: Optional[bool] = None,
508
+ output_attentions: Optional[bool] = None,
509
+ output_hidden_states: Optional[bool] = None,
510
+ output_projector_features: Optional[bool] = None,
511
+ return_dict: Optional[bool] = None,
512
+ proprio=None,
513
+ proprio_projector=None,
514
+ noisy_actions=None,
515
+ noisy_action_projector=None,
516
+ diffusion_timestep_embeddings=None,
517
+ use_film: bool = False,
518
+ ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
519
+ """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
520
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
521
+ output_hidden_states = (
522
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
523
+ )
524
+ output_projector_features = output_projector_features if output_projector_features is not None else False
525
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
526
+
527
+ # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
528
+ use_cache = use_cache and not self.training
529
+
530
+ # Instantiate Placeholder for Projector Features
531
+ projected_patch_embeddings = None
532
+
533
+ # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
534
+ if input_ids.shape[1] == 1:
535
+ assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
536
+ assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
537
+ assert labels is None, "Unexpected key `labels` provided during cached generation!"
538
+
539
+ language_model_output = self.language_model(
540
+ input_ids=input_ids,
541
+ attention_mask=None,
542
+ position_ids=None,
543
+ past_key_values=past_key_values,
544
+ inputs_embeds=None,
545
+ labels=None,
546
+ use_cache=use_cache,
547
+ output_attentions=output_attentions,
548
+ output_hidden_states=output_hidden_states,
549
+ return_dict=return_dict,
550
+ )
551
+
552
+ # === Handle Unimodal Forward ===
553
+ elif pixel_values is None:
554
+ assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
555
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
556
+
557
+ language_model_output = self.language_model(
558
+ input_ids=input_ids,
559
+ attention_mask=attention_mask,
560
+ position_ids=None,
561
+ past_key_values=None,
562
+ inputs_embeds=None,
563
+ labels=labels,
564
+ use_cache=use_cache,
565
+ output_attentions=output_attentions,
566
+ output_hidden_states=output_hidden_states,
567
+ return_dict=return_dict,
568
+ )
569
+
570
+ # === Handle Multimodal Forward ===
571
+ elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
572
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"
573
+
574
+ # Get input embeddings (from language model embeddings)
575
+ input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)
576
+
577
+ # Extract action masks
578
+ all_actions_mask = self._process_action_masks(labels)
579
+
580
+ # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
581
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
582
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
583
+ ) # (B, lang_seq_len, llm_dim)
584
+
585
+ # Get visual features
586
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
587
+
588
+ # Add proprioceptive state if provided
589
+ projected_patch_embeddings = self._process_proprio_features(
590
+ projected_patch_embeddings, proprio, proprio_projector
591
+ )
592
+
593
+ # [Diffusion] Add diffusion timestep embedding if provided
594
+ if diffusion_timestep_embeddings is not None:
595
+ # For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens
596
+ projected_patch_embeddings = torch.cat(
597
+ (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
598
+ )
599
+
600
+ # Process action embeddings
601
+ if noisy_actions is not None:
602
+ # Get mask corresponding to all action tokens
603
+ all_actions_mask = self._process_action_masks(labels)
604
+
605
+ # Reshape noisy actions into individual action tokens
606
+ # noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1)
607
+ B = noisy_actions.shape[0]
608
+ noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1)
609
+
610
+ # Project noisy action tokens into language model embedding space
611
+ noisy_action_features = noisy_action_projector(noisy_actions) # (B, chunk_len * action_dim, llm_dim)
612
+
613
+ # Replace embeddings of the action tokens with noisy action embeddings
614
+ input_embeddings = self._replace_input_embeddings(
615
+ input_embeddings, all_actions_mask, noisy_action_features
616
+ )
617
+ else:
618
+ # Replace the embeddings of the action tokens with zeros
619
+ # (Later on, the positional embeddings will be added to them)
620
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
621
+ input_embeddings = input_embeddings * ~all_actions_mask
622
+
623
+ # Build multimodal embeddings & attention mask
624
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
625
+ input_embeddings, projected_patch_embeddings, attention_mask
626
+ )
627
+
628
+ # Build labels for multimodal sequence if needed
629
+ multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
630
+
631
+ # Dispatch to language model
632
+ language_model_output = self.language_model(
633
+ input_ids=None,
634
+ attention_mask=multimodal_attention_mask,
635
+ position_ids=None,
636
+ past_key_values=None,
637
+ inputs_embeds=multimodal_embeddings,
638
+ labels=multimodal_labels,
639
+ use_cache=use_cache,
640
+ output_attentions=output_attentions,
641
+ output_hidden_states=output_hidden_states,
642
+ return_dict=return_dict,
643
+ )
645
+
646
+ # === Otherwise =>> Assume Invalid! ===
647
+ elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
648
+ raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
649
+
650
+ else:
651
+ raise ValueError(
652
+ "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
653
+ f"=> `input_ids` = {input_ids is not None}\n"
654
+ f"=> `attention_mask` = {attention_mask is not None}\n"
655
+ f"=> `pixel_values` = {pixel_values is not None}\n"
656
+ f"=> `labels` = {labels is not None}\n"
657
+ f"=> `input_embeds` = {inputs_embeds is not None}\n"
658
+ f"=> `past_key_values` = {past_key_values is not None}\n"
659
+ f"=> `use_cache` = {use_cache}"
660
+ )
661
+
662
+ # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
663
+ if not return_dict:
664
+ if output_projector_features and (projected_patch_embeddings is not None):
665
+ return *language_model_output, projected_patch_embeddings
666
+
667
+ return language_model_output
668
+
669
+ return PrismaticCausalLMOutputWithPast(
670
+ loss=language_model_output.loss,
671
+ logits=language_model_output.logits,
672
+ past_key_values=language_model_output.past_key_values,
673
+ hidden_states=language_model_output.hidden_states,
674
+ attentions=language_model_output.attentions,
675
+ projector_features=projected_patch_embeddings,
676
+ )
677
+
678
+ # === GenerationMixin Methods ===
679
+ def prepare_inputs_for_generation(
680
+ self,
681
+ input_ids: Optional[torch.Tensor] = None,
682
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
683
+ inputs_embeds: Optional[torch.FloatTensor] = None,
684
+ pixel_values: Optional[torch.FloatTensor] = None,
685
+ attention_mask: Optional[torch.Tensor] = None,
686
+ **kwargs: str,
687
+ ) -> Dict[str, torch.Tensor]:
688
+ """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
689
+ if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
690
+ (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
691
+ ):
692
+ raise ValueError("Generation with batch size > 1 is not currently supported!")
693
+
694
+ # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
695
+ if past_key_values is not None:
696
+ input_ids = input_ids[:, -1:]
697
+
698
+ # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
699
+ if inputs_embeds is not None and past_key_values is None:
700
+ model_inputs = {"inputs_embeds": inputs_embeds}
701
+ else:
702
+ model_inputs = {"input_ids": input_ids}
703
+
704
+ # Make sure `pixel_values` are preserved in `model_inputs`
705
+ model_inputs.update(
706
+ {
707
+ "attention_mask": attention_mask,
708
+ "pixel_values": pixel_values,
709
+ "past_key_values": past_key_values,
710
+ "use_cache": kwargs.get("use_cache"),
711
+ }
712
+ )
713
+
714
+ return model_inputs
715
+
716
+ # Defer to Language Model (all handle this differently, with different return types)
717
+ def _reorder_cache(self, *args, **kwargs) -> Any:
718
+ return self.language_model._reorder_cache(*args, **kwargs)
719
+
720
+
721
+ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
722
+ config_class: PretrainedConfig = OpenVLAConfig
723
+
724
+ def __init__(self, config: OpenVLAConfig) -> None:
725
+ super().__init__(config)
726
+ self.norm_stats = config.norm_stats
727
+
728
+ # Compute action bins
729
+ self.bins = np.linspace(-1, 1, config.n_action_bins)
730
+ self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
731
+
732
+ # Compute vocab size for de-tokenization -- revert added "multiple of"
733
+ self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
734
+
735
+ def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
736
+ """Prepares input for action prediction by adding necessary tokens"""
737
+ # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens
738
+ placeholder_action_token_ids = (
739
+ torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
740
+ )
741
+ input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)
742
+
743
+ # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
744
+ stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
745
+ input_ids = torch.cat([input_ids, stop_token_id], dim=-1)
746
+
747
+ # Extend the attention mask to fit the new shape of input
748
+ # Note: Only batch size == 1 supported right now
749
+ mask_extension = (
750
+ torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
751
+ .to(attention_mask.device)
752
+ .to(attention_mask.dtype)
753
+ )
754
+ attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
755
+
756
+ return input_ids, attention_mask
757
+
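+ # E.g., with ACTION_DIM = 7 and NUM_ACTIONS_CHUNK = 8 (illustrative values -- both constants come
+ # from `prismatic.vla.constants`), 56 placeholder action tokens plus 1 stop token are appended, and
+ # the attention mask is extended with 57 ones.
+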
758
+ def _prepare_labels_for_action_prediction(self, labels, input_ids):
759
+ """Creates labels tensor for action prediction if not provided"""
760
+ # Extend labels tensor with fake action labels
761
+ ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
762
+ labels_extension = (
763
+ torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
764
+ * ARBITRARY_ACTION_TOKEN_IDX
765
+ )
766
+ labels = torch.cat([labels, labels_extension], dim=-1)
767
+
768
+ # Replace last label token with stop token
769
+ labels[:, -1] = STOP_INDEX
770
+
771
+ return labels
772
+
773
+ def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
774
+ """Unnormalize actions using dataset statistics"""
775
+ action_norm_stats = self.get_action_stats(unnorm_key)
776
+
777
+ if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
778
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool))
779
+ action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"])
780
+ elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
781
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
782
+ action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
783
+ else:
784
+ raise ValueError("Unsupported action/proprio normalization type detected!")
785
+
786
+ actions = np.where(
787
+ mask,
788
+ 0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
789
+ normalized_actions,
790
+ )
791
+
792
+ return actions
793
+
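+ # Worked example (BOUNDS_Q99 with illustrative stats q01 = -0.1, q99 = 0.3): a normalized action of
+ # 0.5 maps to 0.5 * (0.5 + 1) * (0.3 - (-0.1) + 1e-8) + (-0.1) ~= 0.2.
+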
794
+ def _run_diffusion_prediction(
795
+ self,
796
+ input_embeddings,
797
+ all_actions_mask,
798
+ noise,
799
+ action_head,
800
+ projected_patch_embeddings,
801
+ labels,
802
+ attention_mask,
803
+ NUM_PATCHES,
804
+ NUM_PROMPT_TOKENS,
805
+ noisy_action_projector,
806
+ ):
807
+ """Run diffusion-based action prediction"""
808
+ # Clone embedding for reuse in each timestep
809
+ orig_projected_patch_embeddings = projected_patch_embeddings.clone()
810
+ curr_noisy_actions = noise
811
+
812
+ # Reverse diffusion: Iteratively denoise to generate action prediction
813
+ for t in action_head.noise_scheduler.timesteps:
814
+ # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
815
+ # embedding, and diffusion timestep embedding)
816
+ timesteps = torch.Tensor([t]).to(labels.device)
817
+ diffusion_timestep_embeddings = (
818
+ action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
819
+ ) # (B, llm_dim)
820
+ diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
821
+
822
+ # [Diffusion] Replace the embeddings of the action tokens with noisy actions
823
+ # (Later on, the positional embeddings will be added to them)
824
+
825
+ # For simplicity, append diffusion timestep embedding to the end of projected vision tokens
826
+ projected_patch_embeddings = torch.cat(
827
+ (orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
828
+ )
829
+
830
+ # Reshape and project noisy actions into language embedding space
831
+ B = curr_noisy_actions.shape[0]
832
+ orig_curr_noisy_actions_shape = curr_noisy_actions.shape
833
+ curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
834
+ noisy_action_features = noisy_action_projector(curr_noisy_actions)
835
+ curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
836
+
837
+ # Replace action token embeddings with noisy action embeddings
838
+ input_embeddings = self._replace_input_embeddings(
839
+ input_embeddings.clone(), all_actions_mask, noisy_action_features
840
+ )
841
+
842
+ # Build multimodal embeddings and attention mask
843
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
844
+ input_embeddings, projected_patch_embeddings, attention_mask
845
+ )
846
+
847
+ # Forward pass through language model
848
+ language_model_output = self.language_model(
849
+ input_ids=None,
850
+ attention_mask=multimodal_attention_mask,
851
+ position_ids=None,
852
+ past_key_values=None,
853
+ inputs_embeds=multimodal_embeddings,
854
+ labels=None,
855
+ use_cache=None,
856
+ output_attentions=False,
857
+ output_hidden_states=True,
858
+ return_dict=True,
859
+ )
860
+
861
+ # Extract hidden states for action portion of response
862
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
863
+ actions_hidden_states = last_hidden_states[
864
+ :,
865
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
866
+ :,
867
+ ] # (B, act_chunk_len, D)
868
+
869
+ # Predict noise and update noisy actions: x_t -> x_{t-1}
870
+ noise_pred = action_head.predict_noise(actions_hidden_states)
871
+ curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
872
+
873
+ curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
874
+
875
+ # Return final actions
876
+ return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
877
+
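+ # Each loop iteration above re-embeds the current noisy actions, runs a full (non-cached) forward
+ # pass, and applies one denoising step via `noise_scheduler.step`, so inference cost scales linearly
+ # with the number of scheduler timesteps.
+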
878
+ def _regression_or_discrete_prediction(
879
+ self,
880
+ input_embeddings,
881
+ all_actions_mask,
882
+ projected_patch_embeddings,
883
+ attention_mask,
884
+ labels,
885
+ NUM_PATCHES,
886
+ NUM_PROMPT_TOKENS,
887
+ action_head=None,
888
+ ):
889
+ """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
890
+ # Zero out action token embeddings
891
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
892
+ input_embeddings = input_embeddings * ~all_actions_mask
893
+
894
+ # Build multimodal embeddings and attention mask
895
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
896
+ input_embeddings, projected_patch_embeddings, attention_mask
897
+ )
898
+
899
+ # Forward pass through language model
900
+ language_model_output = self.language_model(
901
+ input_ids=None,
902
+ attention_mask=multimodal_attention_mask,
903
+ position_ids=None,
904
+ past_key_values=None,
905
+ inputs_embeds=multimodal_embeddings,
906
+ labels=None,
907
+ use_cache=None,
908
+ output_attentions=False,
909
+ output_hidden_states=True,
910
+ return_dict=True,
911
+ )
912
+
913
+ # Extract hidden states for action tokens
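+ # The multimodal sequence is [<BOS>, NUM_PATCHES patch tokens, NUM_PROMPT_TOKENS prompt tokens,
+ # action tokens, <STOP>]; because hidden state i predicts token i + 1, slicing from index
+ # NUM_PATCHES + NUM_PROMPT_TOKENS selects the positions aligned with the action tokens.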
914
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
915
+ actions_hidden_states = last_hidden_states[
916
+ :,
917
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
918
+ :,
919
+ ] # (B, act_chunk_len, D)
920
+
921
+ # Handle different prediction methods
922
+ if action_head is not None:
923
+ # L1 regression prediction
924
+ normalized_actions = action_head.predict_action(actions_hidden_states)
925
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
926
+ normalized_actions = normalized_actions.float().cpu().detach().numpy()
927
+ else:
928
+ # Discrete token-based prediction
929
+ predicted_action_token_ids = (
930
+ language_model_output.logits[
931
+ :,
932
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
933
+ ]
934
+ .argmax(dim=2)
935
+ .cpu()
936
+ .numpy()
937
+ )
938
+ discretized_actions = self.vocab_size - predicted_action_token_ids
939
+ discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
940
+ normalized_actions = self.bin_centers[discretized_actions]
941
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
942
+
943
+ return normalized_actions, actions_hidden_states
944
+
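+ # De-tokenization sketch for the discrete path (vocab numbers are assumptions): with n_action_bins =
+ # 256, `bin_centers` has 255 entries, and a predicted token id `t` maps to
+ # bin_centers[clip(self.vocab_size - t - 1, 0, 254)], recovering a normalized value in [-1, 1].
+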
945
+ def predict_action(
946
+ self,
947
+ input_ids: Optional[torch.LongTensor] = None,
948
+ unnorm_key: Optional[str] = None,
949
+ proprio=None,
950
+ proprio_projector=None,
951
+ action_head=None,
952
+ noisy_action_projector=None,
953
+ use_film: bool = False,
954
+ **kwargs: str,
955
+ ) -> np.ndarray:
956
+ """Predict actions from input sequence, with options for different prediction methods.
957
+
958
+ Args:
959
+ input_ids: Input token ids
960
+ unnorm_key: Key for unnormalization statistics
961
+ proprio: Proprioceptive features
962
+ proprio_projector: Projector for proprioceptive features
963
+ action_head: Optional head for L1 regression or diffusion-based prediction
964
+ noisy_action_projector: Projector for noisy actions in diffusion-based prediction
965
+ use_film: Whether to use FiLM conditioning
966
+ **kwargs: Additional arguments including pixel_values and attention_mask
967
+
968
+ Returns:
969
+ Tuple of (unnormalized_actions, action_hidden_states)
970
+ """
971
+ # If the special empty token ('') does not already appear after the colon (':') token in the prompt
972
+ # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
973
+ if not torch.all(input_ids[:, -1] == 29871):
974
+ input_ids = torch.cat(
975
+ (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
976
+ )
977
+
978
+ pixel_values = kwargs["pixel_values"]
979
+ attention_mask = kwargs["attention_mask"]
980
+
981
+ # Create fake labels tensor (needed for action mask)
982
+ labels = input_ids.clone()
983
+ labels[:] = IGNORE_INDEX
984
+
985
+ # Get number of tokens in prompt (excluding the start token)
986
+ NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract the <BOS> token (action and stop tokens are appended below)
987
+
988
+ # Prepare inputs by adding necessary tokens
989
+ input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
990
+
991
+ # Update labels tensor for action mask computation later
992
+ labels = self._prepare_labels_for_action_prediction(labels, input_ids)
993
+
994
+ # Get input embeddings and action masks
995
+ input_embeddings = self.get_input_embeddings()(input_ids)
996
+ all_actions_mask = self._process_action_masks(labels)
997
+
998
+ # Extract language embeddings
999
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
1000
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
1001
+ )
1002
+
1003
+ # Process vision features
1004
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
1005
+
1006
+ # Add proprioceptive features if provided
1007
+ use_proprio = proprio_projector is not None and proprio is not None
1008
+ if use_proprio:
1009
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1010
+ projected_patch_embeddings = self._process_proprio_features(
1011
+ projected_patch_embeddings, proprio, proprio_projector
1012
+ )
1013
+
1014
+ # Use diffusion if provided, otherwise use regression or discrete prediction
1015
+ use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
1016
+
1017
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
1018
+ NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
1019
+ if use_proprio:
1020
+ NUM_PATCHES += 1
1021
+ if use_diffusion:
1022
+ NUM_PATCHES += 1
1023
+
1024
+ if use_diffusion:
1025
+ # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion
1026
+ noise = torch.randn(
1027
+ size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
1028
+ )
1029
+
1030
+ # Run diffusion-based prediction
1031
+ normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
1032
+ input_embeddings,
1033
+ all_actions_mask,
1034
+ noise,
1035
+ action_head,
1036
+ projected_patch_embeddings,
1037
+ labels,
1038
+ attention_mask,
1039
+ NUM_PATCHES,
1040
+ NUM_PROMPT_TOKENS,
1041
+ noisy_action_projector,
1042
+ )
1043
+ else:
1044
+ # Run regression or discrete token-based prediction
1045
+ normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
1046
+ input_embeddings,
1047
+ all_actions_mask,
1048
+ projected_patch_embeddings,
1049
+ attention_mask,
1050
+ labels,
1051
+ NUM_PATCHES,
1052
+ NUM_PROMPT_TOKENS,
1053
+ action_head,
1054
+ )
1055
+
1056
+ # Unnormalize predicted actions
1057
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1058
+
1059
+ return actions, actions_hidden_states
1060
+
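+ # Usage sketch (illustrative only -- `processor`, the prompt format, and the `unnorm_key` value are
+ # assumptions, not defined in this file):
+ #   inputs = processor("In: What action should the robot take to pick up the cup?\nOut:", image)
+ #   actions, _ = vla.predict_action(**inputs.to(vla.device), unnorm_key="bridge_orig")
+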
1061
+ @staticmethod
1062
+ def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
1063
+ """Validate and resolve the unnormalization key for action statistics"""
1064
+ if unnorm_key is None:
1065
+ assert len(norm_stats) == 1, (
1066
+ f"Your model was trained on more than one dataset, "
1067
+ f"please pass a `unnorm_key` from the following options to choose the statistics "
1068
+ f"used for un-normalizing actions: {norm_stats.keys()}"
1069
+ )
1070
+ unnorm_key = next(iter(norm_stats.keys()))
1071
+
1072
+ assert unnorm_key in norm_stats, (
1073
+ f"The `unnorm_key` you chose is not in the set of available dataset statistics, "
1074
+ f"please choose from: {norm_stats.keys()}"
1075
+ )
1076
+ return unnorm_key
1077
+
1078
+ def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
1079
+ """Get the dimensionality of the policy's action space."""
1080
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1081
+ return len(self.norm_stats[unnorm_key]["action"]["min"])
1082
+
1083
+ def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
1084
+ """Get all the logged statistics for the given dataset."""
1085
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1086
+ return self.norm_stats[unnorm_key]["action"]
openvla-7b/modeling_prismatic_rl.py ADDED
@@ -0,0 +1,1344 @@
1
+ """
2
+ modeling_prismatic_rl.py
3
+ Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.
4
+ Inherits from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained,
5
+ but to exactly replicate the logic in `prismatic.models.vlms.prismatic.py`.
6
+ """
7
+
8
+ import logging
9
+ from dataclasses import dataclass
10
+ from functools import partial
11
+ from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
12
+
13
+ import numpy as np
14
+ import timm
15
+ import tokenizers
16
+ import torch
17
+ import torch.nn as nn
18
+ import transformers
19
+ from timm.models.vision_transformer import LayerScale
20
+ from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
21
+ from transformers.modeling_outputs import ModelOutput
22
+
23
+ from prismatic.training.train_utils import (
24
+ get_current_action_mask,
25
+ get_next_actions_mask,
26
+ )
27
+ from prismatic.vla.constants import (
28
+ ACTION_DIM,
29
+ ACTION_PROPRIO_NORMALIZATION_TYPE,
30
+ ACTION_TOKEN_BEGIN_IDX,
31
+ IGNORE_INDEX,
32
+ NUM_ACTIONS_CHUNK,
33
+ STOP_INDEX,
34
+ NormalizationType,
35
+ )
36
+
37
+ from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
38
+
39
+ TRAINING_MODE = "SFT"
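+ # Module-level mode switch consumed by PrismaticForConditionalGeneration.forward():
+ # "SFT" routes through _forward_sft (supervised fine-tuning); "RL" routes through _forward_rl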
40
+
41
+ # Set up logger
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
+ # === Utility Functions for Monkey-Patching ===
46
+ def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
47
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
48
+ result = fn(*args, **kwargs)
49
+ return result[0] if isinstance(result, tuple) else result
50
+
51
+ return wrapper
52
+
53
+
54
+ # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
55
+ # =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
56
+ # =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
57
+ def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
58
+ return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
59
+
60
+
61
+ def ls_apply_patch(ls_module: LayerScale):
62
+ ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
63
+ ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
64
+ del ls_module.gamma
65
+
66
+
67
+ # === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
68
+ class PrismaticVisionBackbone(nn.Module):
69
+ """
70
+ Vision backbone for Prismatic models that handles image feature extraction.
71
+ Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.
72
+ For fused backbones, features from both models are concatenated along the feature dimension.
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ use_fused_vision_backbone: bool,
78
+ image_sizes: List[int],
79
+ timm_model_ids: List[str],
80
+ timm_override_act_layers: List[Optional[str]],
81
+ ) -> None:
82
+ """
83
+ Initialize the vision backbone.
84
+ Args:
85
+ use_fused_vision_backbone: Whether to use two backbones and fuse their features
86
+ image_sizes: List of image sizes for each backbone
87
+ timm_model_ids: List of TIMM model IDs to use for each backbone
88
+ timm_override_act_layers: List of activation layer overrides for each backbone
89
+ """
90
+ super().__init__()
91
+ self.use_fused_vision_backbone = use_fused_vision_backbone
92
+ self.num_images_in_input = 1 # Default value, can be overridden later
93
+
94
+ # Validate number of (fused) vision backbones
95
+ if len(timm_model_ids) > 2:
96
+ raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")
97
+
98
+ # Create primary featurizer
99
+ self.featurizer = self._create_featurizer(
100
+ model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]
101
+ )
102
+ self.embed_dim = self.featurizer.embed_dim
103
+
104
+ # Create secondary featurizer if using fused backbone
105
+ if self.use_fused_vision_backbone:
106
+ self.fused_featurizer = self._create_featurizer(
107
+ model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]
108
+ )
109
+ self.embed_dim += self.fused_featurizer.embed_dim
110
+
111
+ # Patch LayerScale modules for HF compatibility
112
+ self._patch_layer_scales()
113
+
114
+ def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
115
+ """
116
+ Create a TIMM-based featurizer model with appropriate configurations.
117
+ Args:
118
+ model_id: The TIMM model ID to load
119
+ img_size: Input image size for the model
120
+ act_layer: Override for the activation layer type
121
+ Returns:
122
+ A configured featurizer model
123
+ """
124
+ featurizer = timm.create_model(
125
+ model_id,
126
+ pretrained=False,
127
+ num_classes=0,
128
+ img_size=img_size,
129
+ act_layer=act_layer,
130
+ )
131
+
132
+ # Monkey-patch the forward function to extract the second-to-last layer features
133
+ num_blocks = len(featurizer.blocks)
134
+ featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))
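+ # (For reference, assuming e.g. a 27-block backbone: this selects the output of block
+ # index 25, i.e. the second-to-last block, per the Prismatic feature-extraction convention.)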
135
+
136
+ return featurizer
137
+
138
+ def _patch_layer_scales(self) -> None:
139
+ """
140
+ Patch all LayerScale modules to be compatible with HF's parameter naming.
141
+ HF Transformers overwrites parameters with names containing 'gamma',
142
+ so we need to rename and modify the forward method.
143
+ """
144
+ # Patch primary featurizer
145
+ for module in self.featurizer.modules():
146
+ if isinstance(module, LayerScale):
147
+ ls_apply_patch(module)
148
+
149
+ # Patch secondary featurizer if it exists
150
+ if self.use_fused_vision_backbone:
151
+ for module in self.fused_featurizer.modules():
152
+ if isinstance(module, LayerScale):
153
+ ls_apply_patch(module)
154
+
155
+ def get_num_patches(self) -> int:
156
+ """
157
+ Returns the number of vision patches output by the vision backbone.
158
+ Returns:
159
+ Number of patches per image
160
+ """
161
+ return self.featurizer.patch_embed.num_patches
162
+
163
+ def get_num_images_in_input(self) -> int:
164
+ """
165
+ Returns the number of input images for the vision backbone.
166
+ Returns:
167
+ Number of images expected in the input
168
+ """
169
+ return self.num_images_in_input
170
+
171
+ def set_num_images_in_input(self, num_images_in_input: int) -> None:
172
+ """
173
+ Sets the number of input images for the vision backbone.
174
+ Args:
175
+ num_images_in_input: Number of images to expect in the input
176
+ """
177
+ self.num_images_in_input = num_images_in_input
178
+
179
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
180
+ """
181
+ Implements the forward pass for the vision backbone.
182
+ If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features
183
+ (otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).
184
+ Args:
185
+ pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W).
186
+ """
187
+ if self.num_images_in_input == 1:
188
+ if not self.use_fused_vision_backbone:
189
+ return self.featurizer(pixel_values)
190
+
191
+ # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
192
+ img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
193
+ patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)
194
+
195
+ return torch.cat([patches, patches_fused], dim=2)
196
+
197
+ else:
198
+ assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!"
199
+
200
+ # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2)
201
+ images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)
202
+
203
+ # Process each image and collect patches
204
+ all_patches = []
205
+ for img in images:
206
+ # Split each image further into two stacks of channels (each with 3 channels)
207
+ img_regular, img_fused = torch.split(img, [3, 3], dim=1)
208
+
209
+ # Get patches from both SigLIP and DINOv2 vision transformers
210
+ patches = self.featurizer(img_regular)
211
+ patches_fused = self.fused_featurizer(img_fused)
212
+
213
+ # Concatenate SigLIP and DINOv2 patches along the hidden dimension
214
+ combined_patches = torch.cat([patches, patches_fused], dim=2)
215
+ all_patches.append(combined_patches)
216
+
217
+ # Concatenate all patches along the patch dimension
218
+ return torch.cat(all_patches, dim=1)
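+ # Illustrative shape walk-through (assuming fused SigLIP + DINOv2 backbones at 224x224 with
+ # 16x16 patch grids, i.e. 256 patches per image, and num_images_in_input == 2):
+ # pixel_values: (B, 12, 224, 224) -> 2 images x (3 SigLIP + 3 DINOv2) channels
+ # per image: (B, 256, D_siglip) and (B, 256, D_dino) -> cat on dim=2 -> (B, 256, D)
+ # across images: cat on dim=1 -> (B, 512, D)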
219
+
220
+
221
+ # === Prismatic Projector (nn.Module) Definitions ===
222
+ class PrismaticProjector(nn.Module):
223
+ def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
224
+ super().__init__()
225
+ self.use_fused_vision_backbone = use_fused_vision_backbone
226
+ self.vision_dim, self.llm_dim = vision_dim, llm_dim
227
+
228
+ # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
229
+ if not self.use_fused_vision_backbone:
230
+ self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
231
+ self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
232
+ self.act_fn1 = nn.GELU()
233
+ else:
234
+ initial_projection_dim = 4 * vision_dim
235
+ self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
236
+ self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
237
+ self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
238
+ self.act_fn1 = nn.GELU()
239
+ self.act_fn2 = nn.GELU()
240
+
241
+ def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
242
+ if not self.use_fused_vision_backbone:
243
+ projected_features = self.fc1(img_patches)
244
+ projected_features = self.act_fn1(projected_features)
245
+ projected_features = self.fc2(projected_features)
246
+ else:
247
+ projected_features = self.fc1(img_patches)
248
+ projected_features = self.act_fn1(projected_features)
249
+ projected_features = self.fc2(projected_features)
250
+ projected_features = self.act_fn2(projected_features)
251
+ projected_features = self.fc3(projected_features)
252
+
253
+ return projected_features
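+ # e.g., fused-backbone path: (B, 512, vision_dim) -> fc1 -> (B, 512, 4 * vision_dim)
+ # -> GELU -> fc2 -> (B, 512, llm_dim) -> GELU -> fc3 -> (B, 512, llm_dim)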
254
+
255
+
256
+ # === Main HF Class Definitions ===
257
+ @dataclass
258
+ class PrismaticCausalLMOutputWithPast(ModelOutput):
259
+ """Base class for Prismatic casual (visually-conditioned) language model outputs; also exposes visual features."""
260
+
261
+ loss: Optional[torch.FloatTensor] = None
262
+ logits: torch.FloatTensor = None
263
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
264
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
265
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
266
+
267
+ # Additions for VLMs
268
+ projector_features: Optional[torch.FloatTensor] = None
269
+
270
+
271
+ class PrismaticPreTrainedModel(PreTrainedModel):
272
+ config_class: PretrainedConfig = PrismaticConfig
273
+ base_model_prefix: str = "model"
274
+ supports_gradient_checkpointing: bool = True
275
+
276
+ _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
277
+ _skip_keys_device_placement: str = "past_key_values"
278
+ _supports_flash_attn_2: bool = True
279
+
280
+ def _init_weights(self, module: nn.Module) -> None:
281
+ # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
282
+ # => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
283
+ # https://github.com/TRI-ML/prismatic-vlms
284
+ std = (
285
+ self.config.initializer_range
286
+ if hasattr(self.config, "initializer_range")
287
+ else self.config.text_config.initializer_range
288
+ )
289
+
290
+ if hasattr(module, "class_embedding"):
291
+ module.class_embedding.data.normal_(mean=0.0, std=std)
292
+
293
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
294
+ module.weight.data.normal_(mean=0.0, std=std)
295
+ if module.bias is not None:
296
+ module.bias.data.zero_()
297
+ elif isinstance(module, nn.Embedding):
298
+ module.weight.data.normal_(mean=0.0, std=std)
299
+ if module.padding_idx is not None:
300
+ module.weight.data[module.padding_idx].zero_()
301
+
302
+ @property
303
+ def _supports_sdpa(self) -> bool:
304
+ """Check LLM supports SDPA Attention"""
305
+ return self.language_model._supports_sdpa
306
+
307
+
308
+ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
309
+ def __init__(self, config: PrismaticConfig) -> None:
310
+ super().__init__(config)
311
+
312
+ # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
313
+ if config.use_fused_vision_backbone is None:
314
+ raise ValueError("Missing config field `use_fused_vision_backbone`")
315
+
316
+ if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
317
+ raise NotImplementedError(
318
+ "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
319
+ "if you urgently need support for latest TIMM versions."
320
+ )
321
+
322
+ if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
323
+ logger.warning(
324
+ f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
325
+ f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
326
+ f"there might be inference-time regressions due to dependency changes. If in doubt, please"
327
+ f"use the above versions."
328
+ )
329
+
330
+ # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
331
+ self.vision_backbone = PrismaticVisionBackbone(
332
+ config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
333
+ )
334
+
335
+ # Create Multimodal Projector
336
+ self.projector = PrismaticProjector(
337
+ config.use_fused_vision_backbone,
338
+ vision_dim=self.vision_backbone.embed_dim,
339
+ llm_dim=config.text_config.hidden_size,
340
+ )
341
+
342
+ # Instantiate LLM Backbone
343
+ self.language_model = AutoModelForCausalLM.from_config(
344
+ config.text_config, attn_implementation=config._attn_implementation
345
+ )
346
+ self.vocab_size = config.text_config.vocab_size
347
+ self.pad_token_id = config.pad_token_id
348
+ self.llm_dim = config.text_config.hidden_size
349
+
350
+ # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
351
+ self.post_init()
352
+
353
+ # === `PreTrainedModel` Boilerplate ===
354
+ def get_input_embeddings(self) -> nn.Module:
355
+ return self.language_model.get_input_embeddings()
356
+
357
+ def set_input_embeddings(self, value: nn.Module) -> None:
358
+ self.language_model.set_input_embeddings(value)
359
+
360
+ def get_output_embeddings(self) -> nn.Module:
361
+ return self.language_model.get_output_embeddings()
362
+
363
+ def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
364
+ self.language_model.set_output_embeddings(new_embeddings)
365
+
366
+ def get_decoder(self) -> nn.Module:
367
+ return self.language_model.get_decoder()
368
+
369
+ def set_decoder(self, decoder: nn.Module) -> None:
370
+ self.language_model.set_decoder(decoder)
371
+
372
+ def tie_weights(self) -> None:
373
+ self.language_model.tie_weights() # Note: `Llama-2` and `Mistral` don't tie weights (no-op)
374
+
375
+ def resize_token_embeddings(
376
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
377
+ ) -> nn.Embedding:
378
+ updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
379
+
380
+ # Update config/instance variables
381
+ self.config.text_config.vocab_size = updated_embeddings.num_embeddings
382
+ self.vocab_size = updated_embeddings.num_embeddings
383
+
384
+ return updated_embeddings
385
+
386
+ def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):
387
+ """
388
+ Replace embeddings in input_embeddings at positions where all_actions_mask is True
389
+ with embeddings from noisy_action_features, using vectorized operations.
390
+ Args:
391
+ input_embeddings: Tensor of shape (B, S, D)
392
+ all_actions_mask: Boolean tensor of shape (B, S)
393
+ noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample
394
+ Returns:
395
+ Modified input_embeddings tensor
396
+ """
397
+ # Clone input to avoid modifying the original tensor
398
+ new_input_embeddings = input_embeddings.clone()
399
+
400
+ # Create a tensor with the same shape of input_embeddings to hold the noisy action features
401
+ repositioned_noisy_action_features = torch.zeros_like(input_embeddings)
402
+
403
+ # Create batch indices for splicing
404
+ batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)
405
+ batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])
406
+
407
+ # Get indices where mask is True for each sample
408
+ masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])
409
+
410
+ # Move the noisy action features into their correct positions
411
+ repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features
412
+
413
+ # Combine original input embeddings and noisy action embeddings using the mask
414
+ new_input_embeddings = torch.where(
415
+ all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings
416
+ )
417
+
418
+ return new_input_embeddings
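+ # Illustrative example: with input_embeddings of shape (1, 5, D) and all_actions_mask
+ # [[False, False, True, True, False]], noisy_action_features (1, 2, D) are scattered into
+ # positions 2 and 3, while all other positions keep their original embeddings.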
419
+
420
+ def _process_action_masks(self, labels):
421
+ """Helper to get action masks from labels"""
422
+ current_action_mask = get_current_action_mask(labels)
423
+ next_actions_mask = get_next_actions_mask(labels)
424
+ all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len)
425
+ return all_actions_mask
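+ # Illustrative example (hypothetical ACTION_DIM=7, NUM_ACTIONS_CHUNK=2): for labels covering
+ # [prompt ..., a1..a7, a8..a14, STOP], `current_action_mask` selects a1..a7 (the current
+ # action), `next_actions_mask` selects a8..a14, so `all_actions_mask` covers all 14 tokens.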
426
+
427
+ def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False):
428
+ """Process vision features with optional FiLM conditioning"""
429
+ if use_film:
430
+ # FiLM: Infuse language inputs into visual features
431
+ patch_features = self.vision_backbone(pixel_values, language_embeddings) # (bsz, 256 * num_images, D)
432
+ else:
433
+ patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D)
434
+
435
+ # Project patch embeddings into language embedding space
436
+ return self.projector(patch_features)
437
+
438
+ def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):
439
+ """Process proprioceptive features and append to vision features"""
440
+ if proprio_projector is not None and proprio is not None:
441
+ # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim)
442
+ # proprio: (bsz, proprio_dim) or (proprio_dim,)
443
+ proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1) # (bsz, proprio_dim)
444
+ proprio_features = proprio_projector(proprio) # (bsz, llm_dim)
445
+ proprio_features = proprio_features.unsqueeze(dim=1) # (bsz, 1, llm_dim)
446
+ # For simplicity, just append proprio token to the end of projected vision patch tokens
447
+ return torch.cat((projected_patch_embeddings, proprio_features), dim=1)
448
+ return projected_patch_embeddings
449
+
450
+ def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
451
+ """Build multimodal embeddings and attention mask"""
452
+ # Update attention mask
453
+ projected_patch_attention_mask = None
454
+ if attention_mask is not None:
455
+ projected_patch_attention_mask = torch.full(
456
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
457
+ fill_value=True,
458
+ dtype=attention_mask.dtype,
459
+ device=attention_mask.device,
460
+ )
461
+
462
+ # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
463
+ multimodal_embeddings = torch.cat(
464
+ [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
465
+ )
466
+
467
+ multimodal_attention_mask = None
468
+ if attention_mask is not None:
469
+ multimodal_attention_mask = torch.cat(
470
+ [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
471
+ )
472
+
473
+ return multimodal_embeddings, multimodal_attention_mask
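+ # Resulting sequence layout: [<BOS>] [vision patches (+ optional proprio token)] [prompt]
+ # [action tokens] [STOP]; downstream code therefore indexes action positions as
+ # NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK.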
474
+
475
+ def _build_multimodal_labels(self, labels, projected_patch_embeddings):
476
+ """Build multimodal labels with IGNORE_INDEX for patch embeddings"""
477
+ if labels is not None:
478
+ projected_patch_labels = torch.full(
479
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
480
+ fill_value=IGNORE_INDEX,
481
+ dtype=labels.dtype,
482
+ device=labels.device,
483
+ )
484
+ return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
485
+ return None
486
+
487
+ # === GenerationMixin Methods ===
488
+ def prepare_inputs_for_generation(
489
+ self,
490
+ input_ids: Optional[torch.Tensor] = None,
491
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
492
+ inputs_embeds: Optional[torch.FloatTensor] = None,
493
+ pixel_values: Optional[torch.FloatTensor] = None,
494
+ attention_mask: Optional[torch.Tensor] = None,
495
+ **kwargs: str,
496
+ ) -> Dict[str, torch.Tensor]:
497
+ """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
498
+ if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
499
+ (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
500
+ ):
501
+ raise ValueError("Generation with batch size > 1 is not currently supported!")
502
+
503
+ # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
504
+ if past_key_values is not None:
505
+ input_ids = input_ids[:, -1:]
506
+
507
+ # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
510
+ else:
511
+ model_inputs = {"input_ids": input_ids}
512
+
513
+ # Make sure `pixel_values` are preserved in `model_inputs`
514
+ model_inputs.update(
515
+ {
516
+ "attention_mask": attention_mask,
517
+ "pixel_values": pixel_values,
518
+ "past_key_values": past_key_values,
519
+ "use_cache": kwargs.get("use_cache"),
520
+ }
521
+ )
522
+ return model_inputs
523
+
524
+ # Defer to Language Model (all handle this differently, with different return types)
525
+ def _reorder_cache(self, *args, **kwargs) -> Any:
526
+ return self.language_model._reorder_cache(*args, **kwargs)
527
+
528
+ def _prepare_input_for_action_prediction_verl(self, input_ids, attention_mask):
529
+ """Prepares input for action prediction by adding necessary tokens"""
530
+ # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens
531
+ placeholder_action_token_ids = (
532
+ torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
533
+ )
534
+ input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)
535
+
536
+ # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
537
+ stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
538
+ input_ids = torch.cat([input_ids, stop_token_id], dim=-1)
539
+
540
+ # Extend the attention mask to fit the new shape of input
541
+ # Note: Only batch size == 1 supported right now
542
+ mask_extension = (
543
+ torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
544
+ .to(attention_mask.device)
545
+ .to(attention_mask.dtype)
546
+ )
547
+ attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
548
+
549
+ return input_ids, attention_mask
550
+
551
+ def _prepare_labels_for_action_prediction_verl(self, labels, input_ids):
552
+ """Creates labels tensor for action prediction if not provided"""
553
+ # Extend labels tensor with fake action labels
554
+ ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
555
+ labels_extension = (
556
+ torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
557
+ * ARBITRARY_ACTION_TOKEN_IDX
558
+ )
559
+ labels = torch.cat([labels, labels_extension], dim=-1)
560
+
561
+ # Replace last label token with stop token
562
+ labels[:, -1] = STOP_INDEX
563
+
564
+ return labels
565
+
566
+ def _verl_discrete_compute_logits(
567
+ self,
568
+ input_embeddings,
569
+ all_actions_mask,
570
+ projected_patch_embeddings,
571
+ attention_mask,
572
+ labels,
573
+ NUM_PATCHES,
574
+ NUM_PROMPT_TOKENS,
575
+ action_head=None,
576
+ ):
+ """Compute logits over the discrete action-token positions (verl-style RL path)."""
578
+ # Zero out action token embeddings
579
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
580
+ input_embeddings = input_embeddings * ~all_actions_mask
581
+
582
+ # Build multimodal embeddings and attention mask
583
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
584
+ input_embeddings, projected_patch_embeddings, attention_mask
585
+ )
586
+
587
+ # Forward pass through language model
588
+ language_model_output = self.language_model(
589
+ input_ids=None,
590
+ attention_mask=multimodal_attention_mask,
591
+ position_ids=None,
592
+ past_key_values=None,
593
+ inputs_embeds=multimodal_embeddings,
594
+ labels=None,
595
+ use_cache=None,
596
+ output_attentions=False,
597
+ output_hidden_states=False,
598
+ return_dict=True,
599
+ )
600
+
601
+ compute_logits = language_model_output.logits[
602
+ :,
603
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
604
+ ]
605
+
606
+ return compute_logits
607
+
608
+ def _forward_sft(
609
+ self,
610
+ input_ids: Optional[torch.LongTensor] = None,
611
+ pixel_values: Optional[torch.FloatTensor] = None,
612
+ attention_mask: Optional[torch.Tensor] = None,
613
+ labels: Optional[torch.LongTensor] = None,
614
+ proprio=None,
615
+ proprio_projector=None,
616
+ ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
617
+ """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
618
+ # Instantiate Placeholder for Projector Features
619
+ projected_patch_embeddings = None
620
+
621
+ # === Handle Multimodal Forward ===
622
+ # Get input embeddings (from language model embeddings)
623
+ input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)
624
+
625
+ # Extract action masks
626
+ all_actions_mask = self._process_action_masks(labels)
627
+
628
+ # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
629
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
630
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
631
+ ) # (B, lang_seq_len, llm_dim)
632
+
633
+ # Get visual features
634
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings)
635
+
636
+ # Add proprioceptive state if provided
637
+ projected_patch_embeddings = self._process_proprio_features(
638
+ projected_patch_embeddings, proprio, proprio_projector
639
+ )
640
+
641
+ # Replace the embeddings of the action tokens with zeros
642
+ # (Later on, the positional embeddings will be added to them)
643
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
644
+ input_embeddings = input_embeddings * ~all_actions_mask
645
+
646
+ # Build multimodal embeddings & attention mask
647
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
648
+ input_embeddings, projected_patch_embeddings, attention_mask
649
+ )
650
+
651
+ # Build labels for multimodal sequence if needed
652
+ multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
653
+
654
+ # Dispatch to language model
655
+ language_model_output = self.language_model(
656
+ input_ids=None,
657
+ attention_mask=multimodal_attention_mask,
658
+ position_ids=None,
659
+ past_key_values=None,
660
+ inputs_embeds=multimodal_embeddings,
661
+ labels=multimodal_labels,
662
+ use_cache=None,
663
+ output_attentions=False,
664
+ output_hidden_states=False,
665
+ return_dict=True,
666
+ )
667
+
668
+ return PrismaticCausalLMOutputWithPast(
669
+ loss=language_model_output.loss,
670
+ logits=language_model_output.logits,
671
+ past_key_values=language_model_output.past_key_values,
672
+ hidden_states=language_model_output.hidden_states,
673
+ attentions=language_model_output.attentions,
674
+ projector_features=projected_patch_embeddings,
675
+ )
676
+
677
+ def _forward_rl(
678
+ self,
679
+ input_ids: Optional[torch.LongTensor] = None,
680
+ pixel_values=None,
681
+ attention_mask=None,
682
+ labels=None,
683
+ proprio=None,
684
+ proprio_projector=None,
685
+ **kwargs: str,
686
+ ):
+ """Run an RL-mode forward pass and return logits at the action-token positions.
+ Builds the multimodal input by appending placeholder action tokens and a stop token
+ to the prompt, then slices the language-model logits at those positions.
+ Args:
+ input_ids: Input token ids (prompt only; action placeholders are appended internally)
+ pixel_values: Pixel values for the input image(s)
+ attention_mask: Attention mask for `input_ids`
+ labels: Unused; a fake labels tensor is constructed internally for action masking
+ proprio: Proprioceptive features
+ proprio_projector: Projector for proprioceptive features
+ Returns:
+ Logits of shape (B, ACTION_DIM * NUM_ACTIONS_CHUNK, vocab_size)
+ """
700
+ # Create fake labels tensor (needed for action mask)
701
+ labels = input_ids.clone()
702
+ labels[:] = IGNORE_INDEX
703
+
704
+ # Get number of tokens in prompt (excluding the <BOS> start token)
+ NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1
706
+
707
+
708
+ # Prepare inputs by appending placeholder action tokens and a stop token
+ # (inlined from `_prepare_input_for_action_prediction_verl`)
712
+ placeholder_action_token_ids = (
713
+ torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
714
+ )
715
+ input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)
716
+
717
+ # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
718
+ stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
719
+ input_ids = torch.cat([input_ids, stop_token_id], dim=-1)
720
+
721
+ # Extend the attention mask to fit the new shape of input
722
+ # Note: Only batch size == 1 supported right now
723
+ mask_extension = (
724
+ torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
725
+ .to(attention_mask.device)
726
+ .to(attention_mask.dtype)
727
+ )
728
+ attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
729
+
730
+ ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
731
+ labels_extension = (
732
+ torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
733
+ * ARBITRARY_ACTION_TOKEN_IDX
734
+ )
735
+ labels = torch.cat([labels, labels_extension], dim=-1)
736
+
737
+ # Replace last label token with stop token
738
+ labels[:, -1] = STOP_INDEX
739
+
740
+
741
+ # Get input embeddings and action masks
+ input_embeddings = self.get_input_embeddings()(input_ids)
746
+
747
+ # Inline of `get_current_action_mask`: mark the first ACTION_DIM labeled action positions
+ valid_positions = labels != IGNORE_INDEX
+ cumsum = torch.cumsum(valid_positions, dim=1)
+ mask = (1 <= cumsum) & (cumsum <= ACTION_DIM)
+ action_tokens_only_mask = labels > ACTION_TOKEN_BEGIN_IDX
+ current_action_mask = action_tokens_only_mask * mask
+
+ # Inline of `get_next_actions_mask`: mark all labeled action positions after the current action
+ mask = cumsum > ACTION_DIM
+ next_actions_mask = action_tokens_only_mask * mask
+
+ all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len)
773
+
774
+ # Extract language embeddings
775
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
776
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
777
+ )
778
+
779
+ # Process vision features (no FiLM in this path) and project into the language embedding space
+ patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D)
+ projected_patch_embeddings = self.projector(patch_features)
785
+
786
+
787
+ # Add proprioceptive features if provided
788
+ use_proprio = proprio_projector is not None and proprio is not None
789
+ if use_proprio:
790
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
791
+ projected_patch_embeddings = self._process_proprio_features(
792
+ projected_patch_embeddings, proprio, proprio_projector
793
+ )
794
+
795
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
796
+ NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
797
+ if use_proprio:
798
+ NUM_PATCHES += 1
799
+
800
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
801
+ input_embeddings = input_embeddings * ~all_actions_mask
802
+ projected_patch_attention_mask = None
803
+ if attention_mask is not None:
804
+ projected_patch_attention_mask = torch.full(
805
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
806
+ fill_value=True,
807
+ dtype=attention_mask.dtype,
808
+ device=attention_mask.device,
809
+ )
810
+
811
+ # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
812
+ multimodal_embeddings = torch.cat(
813
+ [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
814
+ )
815
+
816
+ multimodal_attention_mask = None
817
+ if attention_mask is not None:
818
+ multimodal_attention_mask = torch.cat(
819
+ [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
820
+ )
821
+ # Forward pass through language model
822
+ language_model_output = self.language_model(
823
+ input_ids=None,
824
+ attention_mask=multimodal_attention_mask,
825
+ position_ids=None,
826
+ past_key_values=None,
827
+ inputs_embeds=multimodal_embeddings,
828
+ labels=None,
829
+ use_cache=None,
830
+ output_attentions=False,
831
+ output_hidden_states=False,
832
+ return_dict=True,
833
+ )
834
+
835
+
836
+ compute_logits = language_model_output.logits[
837
+ :,
838
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
839
+ ]
840
+
843
+ return compute_logits
844
+
845
+
846
+ def forward(
847
+ self,
848
+ input_ids: Optional[torch.LongTensor] = None,
849
+ attention_mask: Optional[torch.Tensor] = None,
850
+ pixel_values: Optional[torch.FloatTensor] = None,
851
+ labels: Optional[torch.LongTensor] = None,
852
+ inputs_embeds: Optional[torch.FloatTensor] = None,
853
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
854
+ use_cache: Optional[bool] = None,
855
+ output_attentions: Optional[bool] = None,
856
+ output_hidden_states: Optional[bool] = None,
857
+ output_projector_features: Optional[bool] = None,
858
+ return_dict: Optional[bool] = None,
859
+ proprio=None,
860
+ proprio_projector=None,
861
+ noisy_actions=None,
862
+ noisy_action_projector=None,
863
+ diffusion_timestep_embeddings=None,
864
+ use_film: bool = False,
865
+ ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
866
+ if TRAINING_MODE == "SFT":
867
+ return self._forward_sft(
868
+ input_ids=input_ids,
869
+ pixel_values=pixel_values,
870
+ attention_mask=attention_mask,
871
+ labels=labels,
872
+ proprio=proprio,
873
+ proprio_projector=proprio_projector,
874
+ )
875
+ elif TRAINING_MODE == "RL":
876
+ return self._forward_rl(
877
+ input_ids=input_ids,
878
+ pixel_values=pixel_values,
879
+ attention_mask=attention_mask,
880
+ labels=None,
881
+ proprio=proprio,
882
+ proprio_projector=proprio_projector,
883
+ )
884
+ else:
885
+ raise ValueError(f"Unsupported training mode: {TRAINING_MODE}. Supported modes are 'SFT' and 'RL'.")
886
+
887
+ class OpenVLAForActionPredictionForRL(PrismaticForConditionalGeneration):
888
+ config_class: PretrainedConfig = OpenVLAConfig
889
+
890
+ def __init__(self, config: OpenVLAConfig) -> None:
891
+ super().__init__(config)
892
+ self.norm_stats = config.norm_stats
893
+
894
+ # Compute action bins
895
+ self.bins = np.linspace(-1, 1, config.n_action_bins)
896
+ self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
897
+
898
+ # Compute vocab size for de-tokenization -- revert added "multiple of"
899
+ self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
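+ # Editorial note (assuming the default OpenVLA setup: n_action_bins=256, a 32,000-token
+ # Llama-2 vocab padded by pad_to_multiple_of=64): logits then have width 32,064,
+ # self.vocab_size is 32,000, and the 256 action tokens occupy ids [31744, 32000) --
+ # i.e. the slice logits[..., -256-64:-64] used in `_verl_discrete_prediction` below.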
900
+
901
+ def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
902
+ """Prepares input for action prediction by adding necessary tokens"""
903
+ # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens
904
+ placeholder_action_token_ids = (
905
+ torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
906
+ )
907
+ input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)
908
+
909
+ # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
910
+ stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
911
+ input_ids = torch.cat([input_ids, stop_token_id], dim=-1)
912
+
913
+ # Extend the attention mask to fit the new shape of input
914
+ # Note: Only batch size == 1 supported right now
915
+ mask_extension = (
916
+ torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
917
+ .to(attention_mask.device)
918
+ .to(attention_mask.dtype)
919
+ )
920
+ attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
921
+
922
+ return input_ids, attention_mask
923
+
924
+ def _prepare_labels_for_action_prediction(self, labels, input_ids):
925
+ """Creates labels tensor for action prediction if not provided"""
926
+ # Extend labels tensor with fake action labels
927
+ ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
928
+ labels_extension = (
929
+ torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
930
+ * ARBITRARY_ACTION_TOKEN_IDX
931
+ )
932
+ labels = torch.cat([labels, labels_extension], dim=-1)
933
+
934
+ # Replace last label token with stop token
935
+ labels[:, -1] = STOP_INDEX
936
+
937
+ return labels
938
+
939
+ def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
940
+ """Unnormalize actions using dataset statistics"""
941
+ action_norm_stats = self.get_action_stats(unnorm_key)
942
+
943
+ if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
944
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool))
945
+ action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"])
946
+ elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
947
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
948
+ action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
949
+ else:
950
+ raise ValueError("Unsupported action/proprio normalization type detected!")
951
+
952
+ actions = np.where(
953
+ mask,
954
+ 0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
955
+ normalized_actions,
956
+ )
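+ # Worked example (illustrative): under BOUNDS_Q99 with q01 = -0.05 and q99 = 0.05 for some
+ # dimension, a normalized action of 0.5 maps to 0.5 * (0.5 + 1) * (0.05 - (-0.05)) + (-0.05)
+ # = 0.025 (ignoring the 1e-8 stabilizer).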
957
+
958
+ return actions
+
959
+ def _regression_or_discrete_prediction(
960
+ self,
961
+ input_embeddings,
962
+ all_actions_mask,
963
+ projected_patch_embeddings,
964
+ attention_mask,
965
+ labels,
966
+ NUM_PATCHES,
967
+ NUM_PROMPT_TOKENS,
968
+ action_head=None,
969
+ ):
970
+ """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
971
+ # Zero out action token embeddings
972
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
973
+ input_embeddings = input_embeddings * ~all_actions_mask
974
+
975
+ # Build multimodal embeddings and attention mask
976
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
977
+ input_embeddings, projected_patch_embeddings, attention_mask
978
+ )
979
+
980
+ # Forward pass through language model
981
+ language_model_output = self.language_model(
982
+ input_ids=None,
983
+ attention_mask=multimodal_attention_mask,
984
+ position_ids=None,
985
+ past_key_values=None,
986
+ inputs_embeds=multimodal_embeddings,
987
+ labels=None,
988
+ use_cache=None,
989
+ output_attentions=False,
990
+ output_hidden_states=True,
991
+ return_dict=True,
992
+ )
993
+
994
+ # Extract hidden states for action tokens
995
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
996
+ actions_hidden_states = last_hidden_states[
997
+ :,
998
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
999
+ :,
1000
+ ] # (B, act_chunk_len, D)
1001
+
1002
+ # Handle different prediction methods
1003
+ if action_head is not None:
1004
+ # L1 regression prediction
1005
+ normalized_actions = action_head.predict_action(actions_hidden_states)
1006
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1007
+ normalized_actions = normalized_actions.float().cpu().detach().numpy()
1008
+ else:
1009
+ # Discrete token-based prediction
1010
+ predicted_action_token_ids = (
1011
+ language_model_output.logits[
1012
+ :,
1013
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1014
+ ]
1015
+ .argmax(dim=2)
1016
+ .cpu()
1017
+ .numpy()
1018
+ )
1019
+ discretized_actions = self.vocab_size - predicted_action_token_ids
1020
+ discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
1021
+ normalized_actions = self.bin_centers[discretized_actions]
1022
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1023
+
1024
+ return normalized_actions, actions_hidden_states
1025
+
1026
+ def _verl_discrete_prediction(
1027
+ self,
1028
+ input_embeddings,
1029
+ all_actions_mask,
1030
+ projected_patch_embeddings,
1031
+ attention_mask,
1032
+ labels,
1033
+ NUM_PATCHES,
1034
+ NUM_PROMPT_TOKENS,
1035
+ action_head=None,
1036
+ do_sample=True,
1037
+ temperature=1,
1038
+ ):
1039
+ """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
1040
+ # Zero out action token embeddings
1041
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
1042
+ input_embeddings = input_embeddings * ~all_actions_mask
1043
+
1044
+ # Build multimodal embeddings and attention mask
1045
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1046
+ input_embeddings, projected_patch_embeddings, attention_mask
1047
+ )
1048
+
1049
+ # Forward pass through language model
1050
+ language_model_output = self.language_model(
1051
+ input_ids=None,
1052
+ attention_mask=multimodal_attention_mask,
1053
+ position_ids=None,
1054
+ past_key_values=None,
1055
+ inputs_embeds=multimodal_embeddings,
1056
+ labels=None,
1057
+ use_cache=None,
1058
+ output_attentions=False,
1059
+ output_hidden_states=False,
1060
+ return_dict=True,
1061
+ )
1062
+ NUM_PROMPT_TOKENS = NUM_PROMPT_TOKENS + NUM_PATCHES
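+ # NOTE: in this verl path NUM_PROMPT_TOKENS is a per-sample tensor of prompt lengths
+ # (inputs may be padded), so action positions are gathered with per-row start indices below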
1063
+ batch_size = language_model_output.logits.shape[0]
1064
+ device = language_model_output.logits.device
1065
+
1066
+ start_indices = NUM_PROMPT_TOKENS.unsqueeze(1) # [batch_size, 1]
1067
+ position_offsets = torch.arange(ACTION_DIM * NUM_ACTIONS_CHUNK, device=device).unsqueeze(0) # [1, seq_length]
1068
+ seq_indices = start_indices + position_offsets # [batch_size, ACTION_DIM*NUM_ACTIONS_CHUNK]
1069
+ if not do_sample:
+ # Greedy decoding: restrict logits to the 256 action-token ids (the 256 ids just
+ # before the 64 padded embedding slots), then take the argmax
+ response_logits = language_model_output.logits[
+ torch.arange(batch_size, device=device).unsqueeze(-1),
+ seq_indices,
+ :
+ ]
+ start_index = self.vocab_size - 256
+ response_last256 = response_logits[..., -256 - 64 : -64] # Shape: [batch_size, seq_len, 256]
+ last256_argmax = response_last256.argmax(dim=-1) # Shape: [batch_size, seq_len]
+ response_ids = last256_argmax + start_index # Shape: [batch_size, seq_len]
1081
+
1082
+ predicted_action_token_ids = response_ids.cpu().numpy()
1083
+
1084
+ else:
1085
+ assert temperature > 0
+
+ # Temperature sampling: restrict logits to the 256 action-token ids, rescale by the
+ # temperature, and sample one token per action dimension
+ action_logits = language_model_output.logits[
+ torch.arange(batch_size, device=device).unsqueeze(-1),
+ seq_indices,
+ :
+ ]
+ action_logits_last256 = action_logits[..., -256 - 64 : -64]
+ scaled_logits = action_logits_last256 / temperature
+ probs = torch.softmax(scaled_logits, dim=-1)
+ assert probs.shape[-1] == 256
+ probs_flat = probs.reshape(-1, probs.shape[-1])
+ sampled_indices_flat = torch.multinomial(probs_flat, num_samples=1)
+ original_ids_flat = sampled_indices_flat + (self.vocab_size - 256)
+ response_ids = original_ids_flat.view(action_logits.shape[0], -1)
+
+ predicted_action_token_ids = response_ids.cpu().numpy()
1105
+
1106
+ discretized_actions = self.vocab_size - predicted_action_token_ids
1107
+ discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
1108
+ normalized_actions = self.bin_centers[discretized_actions]
1109
+ normalized_actions = normalized_actions.reshape(-1, ACTION_DIM)
+
+ return normalized_actions, response_ids
1114
+
1115
+ def predict_action(
1116
+ self,
1117
+ input_ids: Optional[torch.LongTensor] = None,
1118
+ unnorm_key: Optional[str] = None,
1119
+ proprio=None,
1120
+ proprio_projector=None,
1121
+ action_head=None,
1122
+ noisy_action_projector=None,
1123
+ use_film: bool = False,
1124
+ **kwargs: str,
1125
+ ) -> np.ndarray:
1126
+ """Predict actions from input sequence, with options for different prediction methods.
1127
+ Args:
1128
+ input_ids: Input token ids
1129
+ unnorm_key: Key for unnormalization statistics
1130
+ proprio: Proprioceptive features
1131
+ proprio_projector: Projector for proprioceptive features
1132
+ action_head: Optional head for L1 regression or diffusion-based prediction
1133
+ noisy_action_projector: Projector for noisy actions in diffusion-based prediction
1134
+ use_film: Whether to use FiLM conditioning
1135
+ **kwargs: Additional arguments including pixel_values and attention_mask
1136
+ Returns:
1137
+ Tuple of (unnormalized_actions, action_hidden_states)
1138
+ """
1139
+ # If the special empty token ('') does not already appear after the colon (':') token in the prompt
1140
+ # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
1141
+ if not torch.all(input_ids[:, -1] == 29871):
1142
+ input_ids = torch.cat(
1143
+ (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
1144
+ )
1145
+
1146
+ pixel_values = kwargs["pixel_values"]
1147
+ attention_mask = kwargs["attention_mask"]
1148
+
1149
+ # Create fake labels tensor (needed for action mask)
1150
+ labels = input_ids.clone()
1151
+ labels[:] = IGNORE_INDEX
1152
+
1153
+ # Get number of tokens in prompt (excluding the start token)
1154
+ NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token
1155
+
1156
+ # Prepare inputs by adding necessary tokens
1157
+ input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
1158
+
1159
+ # Update labels tensor for action mask computation later
1160
+ labels = self._prepare_labels_for_action_prediction(labels, input_ids)
1161
+
1162
+ # Get input embeddings and action masks
1163
+ input_embeddings = self.get_input_embeddings()(input_ids)
1164
+ all_actions_mask = self._process_action_masks(labels)
1165
+
1166
+ # Extract language embeddings
1167
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
1168
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
1169
+ )
1170
+
1171
+ # Process vision features
1172
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
1173
+
1174
+ # Add proprioceptive features if provided
1175
+ use_proprio = proprio_projector is not None and proprio is not None
1176
+ if use_proprio:
1177
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1178
+ projected_patch_embeddings = self._process_proprio_features(
1179
+ projected_patch_embeddings, proprio, proprio_projector
1180
+ )
1181
+
1182
+ # Use diffusion if provided, otherwise use regression or discrete prediction
1183
+ use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
1184
+
1185
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
1186
+ NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
1187
+ if use_proprio:
1188
+ NUM_PATCHES += 1
1189
+ if use_diffusion:
1190
+ raise ValueError
1191
+ else:
1192
+ # Run regression or discrete token-based prediction
1193
+ normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
1194
+ input_embeddings,
1195
+ all_actions_mask,
1196
+ projected_patch_embeddings,
1197
+ attention_mask,
1198
+ labels,
1199
+ NUM_PATCHES,
1200
+ NUM_PROMPT_TOKENS,
1201
+ action_head,
1202
+ )
1203
+
1204
+ # Unnormalize predicted actions
1205
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1206
+
1207
+ return actions, actions_hidden_states
1208
+
1209
+ def generate_action_verl(
1210
+ self,
1211
+ input_ids: Optional[torch.LongTensor] = None,
1212
+ unnorm_key: Optional[str] = None,
1213
+ proprio=None,
1214
+ proprio_projector=None,
1215
+ action_head=None,
1216
+ noisy_action_projector=None,
1217
+ use_film: bool = False,
1218
+ **kwargs: str,
1219
+ ) -> np.ndarray:
1220
+ """Predict actions from input sequence, with options for different prediction methods.
1221
+ Args:
1222
+ input_ids: Input token ids
1223
+ unnorm_key: Key for unnormalization statistics
1224
+ proprio: Proprioceptive features
1225
+ proprio_projector: Projector for proprioceptive features
1226
+ action_head: Optional head for L1 regression or diffusion-based prediction
1227
+ noisy_action_projector: Projector for noisy actions in diffusion-based prediction
1228
+ use_film: Whether to use FiLM conditioning
1229
+ **kwargs: Additional arguments including pixel_values and attention_mask
1230
+ Returns:
1231
+ Tuple of (unnormalized_actions, action_hidden_states)
1232
+ """
1233
+
1234
+ pixel_values = kwargs["pixel_values"]
1235
+ attention_mask = kwargs["attention_mask"]
1236
+ do_sample = kwargs["do_sample"]
1237
+ temperature = kwargs["temperature"]
1238
+
1239
+ # Create fake labels tensor (needed for action mask)
1240
+ labels = input_ids.clone()
1241
+ labels[:] = IGNORE_INDEX
1242
+
1243
+ # Get number of tokens in prompt (excluding the start token)
1244
+ #NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token
1245
+ padding_idx = kwargs["padding_idx"]
1246
+ num_prompt_tokens = input_ids.ne(padding_idx).sum(dim=1) - 1
1247
+
1248
+ # Prepare inputs by adding necessary tokens
1249
+ input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
1250
+
1251
+ # Update labels tensor for action mask computation later
1252
+ labels = self._prepare_labels_for_action_prediction(labels, input_ids)
1253
+
1254
+ #here to convert padding from before to last
1255
+ padding_mask = input_ids.ne(padding_idx)
1256
+ assert torch.all(padding_mask==attention_mask.ne(0))
1257
+ #print("in predict_action padding_mask:", padding_mask)
1258
+ padding_mask = padding_mask.int()
1259
+ sorted_indices = torch.argsort(padding_mask, dim=1, descending=True, stable=True)
1260
+ input_ids = torch.gather(input_ids, 1, sorted_indices)
1261
+ attention_mask = torch.gather(attention_mask, 1, sorted_indices)
1262
+ labels = torch.gather(labels, 1, sorted_indices)
1263
+ assert use_film==False
1264
+
1265
+ # Get input embeddings and action masks
1266
+ input_embeddings = self.get_input_embeddings()(input_ids)
1267
+ all_actions_mask = self._process_action_masks(labels)
1268
+
1269
+ # Extract language embeddings
1270
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
1271
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
1272
+ )
1273
+
1274
+ # Process vision features
1275
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
1276
+
1277
+ # Add proprioceptive features if provided
1278
+ use_proprio = proprio_projector is not None and proprio is not None
1279
+ if use_proprio:
1280
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1281
+ projected_patch_embeddings = self._process_proprio_features(
1282
+ projected_patch_embeddings, proprio, proprio_projector
1283
+ )
1284
+
1285
+ # Use diffusion if provided, otherwise use regression or discrete prediction
1286
+ use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
1287
+
1288
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
1289
+ NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
1290
+ if use_proprio:
1291
+ NUM_PATCHES += 1
1292
+ if use_diffusion:
1293
+ NUM_PATCHES += 1
1294
+
1295
+ if use_diffusion:
1296
+ raise ValueError
1297
+ else:
1298
+ # Run regression or discrete token-based prediction
1299
+ normalized_actions, reponse_ids = self._verl_discrete_prediction(
1300
+ input_embeddings,
1301
+ all_actions_mask,
1302
+ projected_patch_embeddings,
1303
+ attention_mask,
1304
+ labels,
1305
+ NUM_PATCHES,
1306
+ num_prompt_tokens,
1307
+ action_head,
1308
+ do_sample=do_sample,
1309
+ temperature=temperature,
1310
+ )
1311
+
1312
+ # Unnormalize predicted actions
1313
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1314
+ #verl add!
1315
+ actions = actions.reshape(-1 ,NUM_ACTIONS_CHUNK, ACTION_DIM)
1316
+ #
1317
+ return actions, reponse_ids
1318
+
1319
+ @staticmethod
1320
+ def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
1321
+ """Validate and resolve the unnormalization key for action statistics"""
1322
+ if unnorm_key is None:
1323
+ assert len(norm_stats) == 1, (
1324
+ f"Your model was trained on more than one dataset, "
1325
+ f"please pass a `unnorm_key` from the following options to choose the statistics "
1326
+ f"used for un-normalizing actions: {norm_stats.keys()}"
1327
+ )
1328
+ unnorm_key = next(iter(norm_stats.keys()))
1329
+
1330
+ assert unnorm_key in norm_stats, (
1331
+ f"The `unnorm_key` you chose is not in the set of available dataset statistics, "
1332
+ f"please choose from: {norm_stats.keys()}"
1333
+ )
1334
+ return unnorm_key
1335
+
1336
+ def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
1337
+ """Get the dimensionality of the policy's action space."""
1338
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1339
+ return len(self.norm_stats[unnorm_key]["action"]["min"])
1340
+
1341
+ def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
1342
+ """Get all the logged statistics for the given dataset."""
1343
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1344
+ return self.norm_stats[unnorm_key]["action"]
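For orientation, here is a minimal, hypothetical sketch of how the `generate_action_verl` entry point above might be invoked. Everything outside the method's own signature is an assumption for illustration: the `vla` model instance, the `processor`, the prompt, the `image`, and the `bridge_orig` statistics key are not defined by this checkpoint. The `<PAD>` id 32000 matches the `tokenizer_config.json` further below.

```python
# Hypothetical usage sketch -- `vla`, `processor`, and `image` are assumed to be loaded elsewhere.
import torch

prompt = "In: What action should the robot take to pick up the cup?\nOut:"
inputs = processor(prompt, image)  # `image` is an assumed PIL.Image camera frame

with torch.no_grad():
    actions, response_ids = vla.generate_action_verl(
        input_ids=inputs["input_ids"],
        unnorm_key="bridge_orig",              # assumed key into `norm_stats`
        pixel_values=inputs["pixel_values"],
        attention_mask=inputs["attention_mask"],
        do_sample=False,                       # greedy decoding
        temperature=1.0,
        padding_idx=32000,                     # <PAD> token id (see tokenizer_config.json below)
    )

# Per the reshape above, `actions` has shape (batch, NUM_ACTIONS_CHUNK, ACTION_DIM)
```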
openvla-7b/preprocessor_config.json ADDED
@@ -0,0 +1,114 @@
+ {
+   "auto_map": {
+     "AutoImageProcessor": "processing_prismatic.PrismaticImageProcessor",
+     "AutoProcessor": "processing_prismatic.PrismaticProcessor"
+   },
+   "image_processor_type": "PrismaticImageProcessor",
+   "image_resize_strategy": "resize-naive",
+   "input_sizes": [
+     [
+       3,
+       224,
+       224
+     ],
+     [
+       3,
+       224,
+       224
+     ]
+   ],
+   "interpolations": [
+     "bicubic",
+     "bicubic"
+   ],
+   "means": [
+     [
+       0.485,
+       0.456,
+       0.406
+     ],
+     [
+       0.5,
+       0.5,
+       0.5
+     ]
+   ],
+   "processor_class": "PrismaticProcessor",
+   "stds": [
+     [
+       0.229,
+       0.224,
+       0.225
+     ],
+     [
+       0.5,
+       0.5,
+       0.5
+     ]
+   ],
+   "tvf_crop_params": [
+     {
+       "output_size": [
+         224,
+         224
+       ]
+     },
+     {
+       "output_size": [
+         224,
+         224
+       ]
+     }
+   ],
+   "tvf_do_letterbox": false,
+   "tvf_letterbox_fill": null,
+   "tvf_normalize_params": [
+     {
+       "inplace": false,
+       "mean": [
+         0.484375,
+         0.455078125,
+         0.40625
+       ],
+       "std": [
+         0.228515625,
+         0.2236328125,
+         0.224609375
+       ]
+     },
+     {
+       "inplace": false,
+       "mean": [
+         0.5,
+         0.5,
+         0.5
+       ],
+       "std": [
+         0.5,
+         0.5,
+         0.5
+       ]
+     }
+   ],
+   "tvf_resize_params": [
+     {
+       "antialias": true,
+       "interpolation": 3,
+       "max_size": null,
+       "size": [
+         224,
+         224
+       ]
+     },
+     {
+       "antialias": true,
+       "interpolation": 3,
+       "max_size": null,
+       "size": [
+         224,
+         224
+       ]
+     }
+   ],
+   "use_fused_vision_backbone": true
+ }
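The `auto_map` block above is what lets `transformers` resolve the custom `PrismaticImageProcessor` at load time. A minimal sketch, assuming the files are fetched from the `openvla/openvla-7b` repo (note `trust_remote_code=True`, which executes `processing_prismatic.py`):

```python
from PIL import Image
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)

# Dummy 640x480 frame; the processor resizes naively to 224x224 per the config above
pixel_values = image_processor(Image.new("RGB", (640, 480)), return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # torch.Size([1, 6, 224, 224]) -- two channel-stacked 3x224x224 views (fused backbone)
```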
openvla-7b/processing_prismatic.py ADDED
@@ -0,0 +1,258 @@
+ """
+ processing_prismatic.py
+
+ HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration
+ specifies `siglip-224px+7b`.
+ """
+
+ from typing import Any, ClassVar, List, Optional, Tuple, Union
+
+ import timm.data
+ import torch
+ import torchvision.transforms.functional as TVF
+ from PIL import Image
+ from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
+ from transformers import PreTrainedTokenizerBase
+ from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin
+ from transformers.processing_utils import ProcessorMixin
+ from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+ from transformers.utils import TensorType
+
+
+ # === Image Processing ===
+ def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
+     """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
+     (w, h), max_wh = image.size, max(image.size)
+     horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
+     padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
+
+     return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant")
+
+
+ class PrismaticImageProcessor(ImageProcessingMixin):
+     model_input_names: ClassVar[List[str]] = ["pixel_values"]
+
+     def __init__(
+         self,
+         use_fused_vision_backbone: bool = False,
+         image_resize_strategy: str = "letterbox",
+         input_sizes: Optional[List[Tuple[int, int, int]]] = None,
+         interpolations: Optional[List[str]] = None,
+         means: Optional[List[Tuple[float, float, float]]] = None,
+         stds: Optional[List[Tuple[float, float, float]]] = None,
+         **kwargs: str,
+     ) -> None:
+         """
+         Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be
+         created by TIMM, and edited to follow our custom `image_resize_strategy` logic.
+
+         @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone
+         @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >
+         @param input_sizes: [TIMM :: `data_cfg`] Input image sizes as tuples of (channels, width, height)
+         @param interpolations: [TIMM :: `data_cfg`] Interpolations as strings (default: "bicubic")
+         @param means: [TIMM :: `data_cfg`] Normalization means as float tuples (two entries if `fused_backbone`)
+         @param stds: [TIMM :: `data_cfg`] Normalization stds as float tuples (two entries if `fused_backbone`)
+         """
+         self.use_fused_vision_backbone = use_fused_vision_backbone
+         self.image_resize_strategy = image_resize_strategy
+
+         # Handle `None` default values
+         input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes
+         interpolations = ["bicubic"] * len(input_sizes) if interpolations is None else interpolations
+         means = [(0.5, 0.5, 0.5)] if means is None else means
+         stds = [(0.5, 0.5, 0.5)] if stds is None else stds
+
+         # TIMM `data_cfg` Parameters
+         self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds
+
+         # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
+         self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
+         self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
+
+         for idx in range(len(input_sizes)):
+             transform = timm.data.create_transform(
+                 input_size=self.input_sizes[idx],
+                 interpolation=self.interpolations[idx],
+                 mean=self.means[idx],
+                 std=self.stds[idx],
+                 crop_pct=1.0,        # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)
+                 crop_mode="center",  # Default crop mode -- no-op when `crop_pct == 1.0`
+                 is_training=False,   # No image augmentations when loading the transform!
+             )
+
+             # [Validation] Ensure appropriate transform structure, expected sizes
+             if not (
+                 isinstance(transform, Compose)
+                 and (len(transform.transforms) == 4)
+                 and isinstance(transform.transforms[0], Resize)
+                 and isinstance(transform.transforms[1], CenterCrop)
+                 and isinstance(transform.transforms[2], ToTensor)
+                 and isinstance(transform.transforms[3], Normalize)
+                 and (transform.transforms[0].size == self.input_sizes[idx][-1])
+                 and (transform.transforms[1].size == self.input_sizes[idx][-2:])
+             ):
+                 raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")
+
+             # HF Image Processors *must* be JSON-serializable; as such, we cannot store torchvision transforms
+             # as attributes. => Instead, we parse the transform and call "torchvision.transforms.functional" (`tvf`)
+             resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
+             self.tvf_resize_params.append(
+                 {
+                     "size": resize_t.size,
+                     "interpolation": TVF.pil_modes_mapping[resize_t.interpolation],
+                     "max_size": None,
+                     "antialias": True,
+                 }
+             )
+             self.tvf_crop_params.append({"output_size": crop_t.size})
+             self.tvf_normalize_params.append(
+                 {
+                     "mean": norm_t.mean.float().numpy().tolist(),
+                     "std": norm_t.std.float().numpy().tolist(),
+                     "inplace": False,
+                 }
+             )
+             self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
+
+             # Handle Prismatic `image_resize_strategy`
+             if self.image_resize_strategy == "resize-naive":
+                 self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
+             elif self.image_resize_strategy == "letterbox":
+                 self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])
+             elif self.image_resize_strategy == "resize-crop":
+                 pass
+             else:
+                 raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")
+
+         # Dispatch **kwargs to super()
+         super().__init__(**kwargs)
+
+     def apply_transform(self, img: Image.Image) -> torch.Tensor:
+         """Apply the `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])."""
+         if self.tvf_do_letterbox:
+             img = letterbox_pad_transform(img, self.tvf_letterbox_fill)
+
+         # [Contract] Fused backbones expect "channel-stacked" inputs; we unpack on the model side!
+         imgs_t = []
+         for idx in range(len(self.input_sizes)):
+             img_idx = TVF.resize(img, **self.tvf_resize_params[idx])
+             img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])
+             img_idx_t = TVF.to_tensor(img_idx)
+             img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])
+             imgs_t.append(img_idx_t)
+
+         # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0
+         img_t = torch.vstack(imgs_t)
+
+         return img_t
+
+     def preprocess(
+         self,
+         images: Union[Image.Image, List[Image.Image]],
+         return_tensors: Optional[Union[str, TensorType]] = None,
+         **_: str,
+     ) -> BatchFeature:
+         """
+         Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor`, we
+         explicitly only handle PIL.Image.Image instances for simplicity.
+
+         @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
+         @param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray
+
+         @return: Instance of `transformers :: BatchFeature` with a single key "pixel_values"
+         """
+         if not isinstance(images, list):
+             images = [images]
+
+         # Apply `self.apply_transform` to each image (returns a list of torch.Tensors); stack into a "batched" Tensor
+         pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])
+
+         # Return BatchFeature =>> note that for compatibility, the constructor expects Dict[str, np.ndarray], so we convert
+         return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)
+
+     def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
+         return self.preprocess(images, **kwargs)
+
+
+ # === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===
+ #   =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
+ class PrismaticProcessor(ProcessorMixin):
+     attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"]
+     image_processor_class: str = "AutoImageProcessor"
+     tokenizer_class: str = "AutoTokenizer"
+
+     def __init__(
+         self,
+         image_processor: Optional[ImageProcessingMixin] = None,
+         tokenizer: Optional[PreTrainedTokenizerBase] = None,
+     ) -> None:
+         super().__init__(image_processor, tokenizer)
+
+     def __call__(
+         self,
+         text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+         images: Union[Image.Image, List[Image.Image]],
+         padding: Union[bool, str, PaddingStrategy] = False,
+         truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
+         max_length: Optional[int] = None,
+         return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+     ) -> BatchFeature:
+         """
+         Preprocess a given (batch of) text/images for a Prismatic VLM; forwards text to the underlying LLM's
+         tokenizer, and forwards images to PrismaticImageProcessor.
+
+         @param text: The (batch of) text to encode; must be a string or list of strings.
+         @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
+         @param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False >
+         @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified
+         @param max_length: Maximum length (in tokens) to truncate to
+         @param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH)
+
+         @return: BatchFeature with keys for `input_ids`, `attention_mask`, and `pixel_values`.
+         """
+         pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
+         text_inputs = self.tokenizer(
+             text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
+         )
+
+         # [Validate] Need the same number of image and text inputs!
+         if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:
+             raise ValueError("Batch is malformed; expected the same number of image and text inputs!")
+
+         return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
+
+     # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
+     def batch_decode(
+         self,
+         sequences: Union[List[int], List[List[int]], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
+         skip_special_tokens: bool = False,
+         clean_up_tokenization_spaces: Optional[bool] = None,
+         **kwargs: str,
+     ) -> List[str]:
+         return self.tokenizer.batch_decode(
+             sequences=sequences,
+             skip_special_tokens=skip_special_tokens,
+             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+             **kwargs,
+         )
+
+     def decode(
+         self,
+         token_ids: Union[int, List[int], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
+         skip_special_tokens: bool = False,
+         clean_up_tokenization_spaces: Optional[bool] = None,
+         **kwargs: str,
+     ) -> str:
+         return self.tokenizer.decode(
+             token_ids=token_ids,
+             skip_special_tokens=skip_special_tokens,
+             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+             **kwargs,
+         )
+
+     @property
+     def model_input_names(self) -> List[str]:
+         tokenizer_input_names = self.tokenizer.model_input_names
+         image_processor_input_names = self.image_processor.model_input_names
+
+         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
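A minimal usage sketch for the processor defined above, assuming it is loaded from the `openvla/openvla-7b` repo: it tokenizes text via the underlying LLaMA tokenizer and routes images through `PrismaticImageProcessor`.

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
batch = processor(
    "In: What action should the robot take to pick up the cup?\nOut:",
    Image.new("RGB", (224, 224)),  # dummy image for illustration
)
print(sorted(batch.keys()))  # ['attention_mask', 'input_ids', 'pixel_values']
```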
openvla-7b/processor_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "auto_map": {
+     "AutoProcessor": "processing_prismatic.PrismaticProcessor"
+   },
+   "processor_class": "PrismaticProcessor"
+ }
openvla-7b/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<PAD>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
openvla-7b/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
openvla-7b/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+ size 499723
openvla-7b/tokenizer_config.json ADDED
@@ -0,0 +1,53 @@
+ {
+   "add_bos_token": true,
+   "add_eos_token": false,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "</s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "32000": {
+       "content": "<PAD>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "auto_map": {
+     "AutoProcessor": "processing_prismatic.PrismaticProcessor"
+   },
+   "bos_token": "<s>",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "</s>",
+   "legacy": false,
+   "model_max_length": 2048,
+   "pad_token": "<PAD>",
+   "padding_side": "right",
+   "processor_class": "PrismaticProcessor",
+   "sp_model_kwargs": {},
+   "tokenizer_class": "LlamaTokenizer",
+   "unk_token": "<unk>",
+   "use_default_system_prompt": false
+ }
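A small sanity-check sketch of what this configuration implies, assuming the tokenizer is loaded from the `openvla/openvla-7b` repo: `<PAD>` is added as id 32000 on top of the 32,000-token LLaMA-2 vocabulary, with right-sided padding and a 2048-token context window.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openvla/openvla-7b")
assert tokenizer.pad_token == "<PAD>" and tokenizer.pad_token_id == 32000
assert tokenizer.padding_side == "right" and tokenizer.model_max_length == 2048
```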
prismatic/config.json ADDED
@@ -0,0 +1,59 @@
+ {
+   "dataset": {
+     "align_stage_components": [
+       "download/llava-laion-cc-sbu-558k/chat.json",
+       "download/llava-laion-cc-sbu-558k"
+     ],
+     "dataset_id": "llava-lvis4v-lrv",
+     "dataset_root_dir": "/opt/ml/input/data/training/skaramcheti/datasets/prismatic-vlms",
+     "finetune_stage_components": [
+       "download/llava-v1.5-instruct/llava_v1_5_lvis4v_lrv_mix1231k.json",
+       "download/llava-v1.5-instruct"
+     ],
+     "type": "llava-lvis4v-lrv"
+   },
+   "hf_token": ".hf_token",
+   "model": {
+     "align_epochs": 1,
+     "align_global_batch_size": 256,
+     "align_learning_rate": 0.001,
+     "align_lr_scheduler_type": "linear-warmup+cosine-decay",
+     "align_max_grad_norm": 1.0,
+     "align_max_steps": null,
+     "align_per_device_batch_size": 16,
+     "align_train_strategy": "fsdp-shard-grad-op",
+     "align_warmup_ratio": 0.03,
+     "align_weight_decay": 0.0,
+     "arch_specifier": "no-align+fused-gelu-mlp",
+     "enable_gradient_checkpointing": true,
+     "enable_mixed_precision_training": true,
+     "finetune_epochs": 2,
+     "finetune_global_batch_size": 128,
+     "finetune_learning_rate": 2e-05,
+     "finetune_lr_scheduler_type": "linear-warmup+cosine-decay",
+     "finetune_max_grad_norm": 1.0,
+     "finetune_max_steps": null,
+     "finetune_per_device_batch_size": 16,
+     "finetune_train_strategy": "fsdp-full-shard",
+     "finetune_warmup_ratio": 0.03,
+     "finetune_weight_decay": 0.1,
+     "image_resize_strategy": "resize-naive",
+     "llm_backbone_id": "llama2-7b-pure",
+     "llm_max_length": 2048,
+     "model_id": "prism-dinosiglip-224px+7b",
+     "reduce_in_full_precision": false,
+     "type": "prism-dinosiglip-224px+7b",
+     "vision_backbone_id": "dinosiglip-vit-so-224px"
+   },
+   "pretrained_checkpoint": null,
+   "run_id": "llava-lvis4v-lrv+prism-dinosiglip-224px+7b+stage-finetune+x7",
+   "run_root_dir": "/opt/ml/input/data/training/x-prismatic-vlms/runs",
+   "seed": 7,
+   "stage": "finetune",
+   "trackers": [
+     "jsonl",
+     "wandb"
+   ],
+   "wandb_entity": "stanford-voltron",
+   "wandb_project": "onyx-vlms"
+ }
prismatic/config.yaml ADDED
@@ -0,0 +1,52 @@
+ dataset:
+   align_stage_components:
+   - download/llava-laion-cc-sbu-558k/chat.json
+   - download/llava-laion-cc-sbu-558k
+   dataset_id: llava-lvis4v-lrv
+   dataset_root_dir: /opt/ml/input/data/training/skaramcheti/datasets/prismatic-vlms
+   finetune_stage_components:
+   - download/llava-v1.5-instruct/llava_v1_5_lvis4v_lrv_mix1231k.json
+   - download/llava-v1.5-instruct
+   type: llava-lvis4v-lrv
+ hf_token: .hf_token
+ model:
+   align_epochs: 1
+   align_global_batch_size: 256
+   align_learning_rate: 0.001
+   align_lr_scheduler_type: linear-warmup+cosine-decay
+   align_max_grad_norm: 1.0
+   align_max_steps: null
+   align_per_device_batch_size: 16
+   align_train_strategy: fsdp-shard-grad-op
+   align_warmup_ratio: 0.03
+   align_weight_decay: 0.0
+   arch_specifier: no-align+fused-gelu-mlp
+   enable_gradient_checkpointing: true
+   enable_mixed_precision_training: true
+   finetune_epochs: 2
+   finetune_global_batch_size: 128
+   finetune_learning_rate: 2.0e-05
+   finetune_lr_scheduler_type: linear-warmup+cosine-decay
+   finetune_max_grad_norm: 1.0
+   finetune_max_steps: null
+   finetune_per_device_batch_size: 16
+   finetune_train_strategy: fsdp-full-shard
+   finetune_warmup_ratio: 0.03
+   finetune_weight_decay: 0.1
+   image_resize_strategy: resize-naive
+   llm_backbone_id: llama2-7b-pure
+   llm_max_length: 2048
+   model_id: prism-dinosiglip-224px+7b
+   reduce_in_full_precision: false
+   type: prism-dinosiglip-224px+7b
+   vision_backbone_id: dinosiglip-vit-so-224px
+ pretrained_checkpoint: null
+ run_id: llava-lvis4v-lrv+prism-dinosiglip-224px+7b+stage-finetune+x7
+ run_root_dir: /opt/ml/input/data/training/x-prismatic-vlms/runs
+ seed: 7
+ stage: finetune
+ trackers:
+ - jsonl
+ - wandb
+ wandb_entity: stanford-voltron
+ wandb_project: onyx-vlms
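One way to read the batch-size fields above: under the usual convention, the global batch size equals per-device batch size x number of GPUs x gradient-accumulation steps. The harness details are an assumption; the arithmetic below is only illustrative.

```python
# Illustrative sanity check relating the finetune batch-size fields above
finetune_global_batch_size = 128
finetune_per_device_batch_size = 16

# With 8 GPUs, no gradient accumulation is needed; fewer GPUs imply more accumulation steps
for num_gpus in (1, 2, 4, 8):
    accum_steps = finetune_global_batch_size // (finetune_per_device_batch_size * num_gpus)
    print(f"{num_gpus} GPU(s) -> {accum_steps} gradient-accumulation step(s)")
```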
prismatic/latest-checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b0281cb5c9e37d08e1a9bde840522518b02bad916b96058ee9bb99e0956e0159
+ size 13620752058
prismatic/run-metrics.jsonl ADDED
@@ -0,0 +1 @@
+ {"hparams": {"dataset": {"align_stage_components": ["download/llava-laion-cc-sbu-558k/chat.json", "download/llava-laion-cc-sbu-558k"], "dataset_id": "llava-lvis4v-lrv", "dataset_root_dir": "/opt/ml/input/data/training/skaramcheti/datasets/prismatic-vlms", "finetune_stage_components": ["download/llava-v1.5-instruct/llava_v1_5_lvis4v_lrv_mix1231k.json", "download/llava-v1.5-instruct"], "type": "llava-lvis4v-lrv"}, "hf_token": ".hf_token", "model": {"align_epochs": 1, "align_global_batch_size": 256, "align_learning_rate": 0.001, "align_lr_scheduler_type": "linear-warmup+cosine-decay", "align_max_grad_norm": 1.0, "align_max_steps": null, "align_per_device_batch_size": 16, "align_train_strategy": "fsdp-shard-grad-op", "align_warmup_ratio": 0.03, "align_weight_decay": 0.0, "arch_specifier": "no-align+fused-gelu-mlp", "enable_gradient_checkpointing": true, "enable_mixed_precision_training": true, "finetune_epochs": 2, "finetune_global_batch_size": 128, "finetune_learning_rate": 2e-05, "finetune_lr_scheduler_type": "linear-warmup+cosine-decay", "finetune_max_grad_norm": 1.0, "finetune_max_steps": null, "finetune_per_device_batch_size": 16, "finetune_train_strategy": "fsdp-full-shard", "finetune_warmup_ratio": 0.03, "finetune_weight_decay": 0.1, "image_resize_strategy": "resize-naive", "llm_backbone_id": "llama2-7b-pure", "llm_max_length": 2048, "model_id": "prism-dinosiglip-224px+7b", "reduce_in_full_precision": false, "type": "prism-dinosiglip-224px+7b", "vision_backbone_id": "dinosiglip-vit-so-224px"}, "pretrained_checkpoint": null, "run_id": "llava-lvis4v-lrv+prism-dinosiglip-224px+7b+stage-finetune+x7", "run_root_dir": "/opt/ml/input/data/training/x-prismatic-vlms/runs", "seed": 7, "stage": "finetune", "trackers": ["jsonl", "wandb"], "wandb_entity": "stanford-voltron", "wandb_project": "onyx-vlms"}, "run_id": "llava-lvis4v-lrv+prism-dinosiglip-224px+7b+stage-finetune+x7"}