juyil committed
Commit d76927c · verified · 1 Parent(s): c8d1052

Initial model upload

action_head--latest_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a1cf1fb5304e026b1e02584cde50ff53ec12279c9d5be2e2b4a657ce8028a478
+ size 101219314
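
The checkpoint above is stored via Git LFS, so the diff only shows the pointer file: the actual weights are downloaded separately and should match the oid (SHA-256) and size recorded in the pointer. A minimal verification sketch in Python, assuming the checkpoint has already been downloaded next to this pointer (the local path is an assumption, not part of the commit):

    import hashlib
    import os

    # Assumed local path of the downloaded checkpoint (hypothetical; adjust as needed).
    path = "action_head--latest_checkpoint.pt"

    # Values copied from the LFS pointer above.
    expected_sha256 = "a1cf1fb5304e026b1e02584cde50ff53ec12279c9d5be2e2b4a657ce8028a478"
    expected_size = 101219314

    # Hash in chunks so the ~100 MB file never has to sit in memory at once.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)

    assert os.path.getsize(path) == expected_size, "size does not match the LFS pointer"
    assert digest.hexdigest() == expected_sha256, "sha256 does not match the LFS pointer"
    print("checkpoint matches the LFS pointer")
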
added_tokens.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "<ACT>": 32001,
+   "<PAD>": 32000
+ }
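
added_tokens.json extends the base Llama-2 vocabulary with two special tokens, <PAD> at id 32000 and <ACT> at id 32001. A short sketch of how those ids could be sanity-checked once the repository is downloaded; the repo path is a placeholder, not the actual Hub id:

    from transformers import AutoTokenizer

    # Placeholder path; substitute the actual Hub id or a local clone of this repo.
    tokenizer = AutoTokenizer.from_pretrained("path/to/this/repo")

    # These should line up with added_tokens.json.
    assert tokenizer.convert_tokens_to_ids("<PAD>") == 32000
    assert tokenizer.convert_tokens_to_ids("<ACT>") == 32001
    print("special token ids match added_tokens.json")
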
config.json ADDED
@@ -0,0 +1,304 @@
+ {
+   "norm_stats": {
+     "libero_10_no_noops": {
+       "action": {
+         "mean": [
+           0.01820324920117855,
+           0.05858374014496803,
+           -0.05592384561896324,
+           0.004626928828656673,
+           0.00289608770981431,
+           -0.007673131301999092,
+           0.5457824468612671
+         ],
+         "std": [
+           0.2825464606285095,
+           0.35904666781425476,
+           0.3673802614212036,
+           0.03770702704787254,
+           0.05429719388484955,
+           0.08725254982709885,
+           0.49815231561660767
+         ],
+         "max": [
+           0.9375,
+           0.9375,
+           0.9375,
+           0.30000001192092896,
+           0.29357144236564636,
+           0.375,
+           1.0
+         ],
+         "min": [
+           -0.9375,
+           -0.9375,
+           -0.9375,
+           -0.23642857372760773,
+           -0.3053571283817291,
+           -0.3675000071525574,
+           0.0
+         ],
+         "q01": [
+           -0.6348214149475098,
+           -0.7741071581840515,
+           -0.7633928656578064,
+           -0.09749999642372131,
+           -0.14819999992847435,
+           -0.2742857038974762,
+           0.0
+         ],
+         "q99": [
+           0.7714285850524902,
+           0.8464285731315613,
+           0.9375,
+           0.13928571343421936,
+           0.15964286029338837,
+           0.3246428668498993,
+           1.0
+         ],
+         "mask": [
+           true,
+           true,
+           true,
+           true,
+           true,
+           true,
+           false
+         ]
+       },
+       "proprio": {
+         "mean": [
+           -0.04190658777952194,
+           0.03539430722594261,
+           0.8257141709327698,
+           2.908308267593384,
+           -0.5562185049057007,
+           -0.16649018228054047,
+           0.028316624462604523,
+           -0.028561657294631004
+         ],
+         "std": [
+           0.10743364691734314,
+           0.14424669742584229,
+           0.2572328448295593,
+           0.3441362977027893,
+           1.234421730041504,
+           0.3579835891723633,
+           0.013308707624673843,
+           0.013174631632864475
+         ],
+         "max": [
+           0.21031762659549713,
+           0.39128610491752625,
+           1.3332009315490723,
+           3.6714255809783936,
+           3.560650587081909,
+           1.386339545249939,
+           0.04160946607589722,
+           0.0013633022317662835
+         ],
+         "min": [
+           -0.4828203022480011,
+           -0.3255046010017395,
+           0.445506751537323,
+           1.1321442127227783,
+           -3.641430377960205,
+           -1.842738389968872,
+           -0.0010040868073701859,
+           -0.04111652821302414
+         ],
+         "q01": [
+           -0.3899900782108307,
+           -0.2838300323486328,
+           0.44795057058334353,
+           1.8810229921340942,
+           -2.886677579879761,
+           -1.1599004411697387,
+           0.002066459748893976,
+           -0.04001387819647789
+         ],
+         "q99": [
+           0.1530261474847791,
+           0.32915401458740223,
+           1.2546923208236693,
+           3.303542451858519,
+           2.7496529006957933,
+           0.6893712210655194,
+           0.040048558115959164,
+           -0.0017598449345678235
+         ]
+       },
+       "num_transitions": 101469,
+       "num_trajectories": 379
+     }
+   },
+   "n_action_bins": 256,
+   "vision_backbone_id": "dinosiglip-vit-so-224px",
+   "llm_backbone_id": "llama2-7b-pure",
+   "arch_specifier": "no-align+fused-gelu-mlp",
+   "output_projector_states": false,
+   "use_fused_vision_backbone": true,
+   "timm_model_ids": [
+     "vit_large_patch14_reg4_dinov2.lvd142m",
+     "vit_so400m_patch14_siglip_224"
+   ],
+   "timm_override_act_layers": [
+     null,
+     null
+   ],
+   "image_sizes": [
+     224,
+     224
+   ],
+   "image_resize_strategy": "resize-naive",
+   "hf_llm_id": "meta-llama/Llama-2-7b-hf",
+   "llm_max_length": 2048,
+   "pad_token_id": 32000,
+   "pad_to_multiple_of": 64,
+   "text_config": {
+     "vocab_size": 32064,
+     "max_position_embeddings": 2048,
+     "hidden_size": 4096,
+     "intermediate_size": 11008,
+     "num_hidden_layers": 32,
+     "num_attention_heads": 32,
+     "num_key_value_heads": 32,
+     "hidden_act": "silu",
+     "initializer_range": 0.02,
+     "rms_norm_eps": 1e-06,
+     "pretraining_tp": 1,
+     "use_cache": true,
+     "rope_theta": 10000.0,
+     "rope_scaling": null,
+     "attention_bias": false,
+     "attention_dropout": 0.0,
+     "return_dict": true,
+     "output_hidden_states": false,
+     "output_attentions": false,
+     "torchscript": false,
+     "torch_dtype": "bfloat16",
+     "use_bfloat16": false,
+     "tf_legacy_loss": false,
+     "pruned_heads": {},
+     "tie_word_embeddings": false,
+     "chunk_size_feed_forward": 0,
+     "is_encoder_decoder": false,
+     "is_decoder": false,
+     "cross_attention_hidden_size": null,
+     "add_cross_attention": false,
+     "tie_encoder_decoder": false,
+     "max_length": 20,
+     "min_length": 0,
+     "do_sample": false,
+     "early_stopping": false,
+     "num_beams": 1,
+     "num_beam_groups": 1,
+     "diversity_penalty": 0.0,
+     "temperature": 1.0,
+     "top_k": 50,
+     "top_p": 1.0,
+     "typical_p": 1.0,
+     "repetition_penalty": 1.0,
+     "length_penalty": 1.0,
+     "no_repeat_ngram_size": 0,
+     "encoder_no_repeat_ngram_size": 0,
+     "bad_words_ids": null,
+     "num_return_sequences": 1,
+     "output_scores": false,
+     "return_dict_in_generate": false,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "remove_invalid_values": false,
+     "exponential_decay_length_penalty": null,
+     "suppress_tokens": null,
+     "begin_suppress_tokens": null,
+     "architectures": null,
+     "finetuning_task": null,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "tokenizer_class": null,
+     "prefix": null,
+     "bos_token_id": 1,
+     "pad_token_id": 32000,
+     "eos_token_id": 2,
+     "sep_token_id": null,
+     "decoder_start_token_id": null,
+     "task_specific_params": null,
+     "problem_type": null,
+     "_name_or_path": "",
+     "model_type": "llama"
+   },
+   "return_dict": true,
+   "output_hidden_states": false,
+   "output_attentions": false,
+   "torchscript": false,
+   "torch_dtype": "bfloat16",
+   "use_bfloat16": false,
+   "tf_legacy_loss": false,
+   "pruned_heads": {},
+   "tie_word_embeddings": true,
+   "chunk_size_feed_forward": 0,
+   "is_encoder_decoder": false,
+   "is_decoder": false,
+   "cross_attention_hidden_size": null,
+   "add_cross_attention": false,
+   "tie_encoder_decoder": false,
+   "max_length": 20,
+   "min_length": 0,
+   "do_sample": false,
+   "early_stopping": false,
+   "num_beams": 1,
+   "num_beam_groups": 1,
+   "diversity_penalty": 0.0,
+   "temperature": 1.0,
+   "top_k": 50,
+   "top_p": 1.0,
+   "typical_p": 1.0,
+   "repetition_penalty": 1.0,
+   "length_penalty": 1.0,
+   "no_repeat_ngram_size": 0,
+   "encoder_no_repeat_ngram_size": 0,
+   "bad_words_ids": null,
+   "num_return_sequences": 1,
+   "output_scores": false,
+   "return_dict_in_generate": false,
+   "forced_bos_token_id": null,
+   "forced_eos_token_id": null,
+   "remove_invalid_values": false,
+   "exponential_decay_length_penalty": null,
+   "suppress_tokens": null,
+   "begin_suppress_tokens": null,
+   "architectures": [
+     "OpenVLAForActionPrediction"
+   ],
+   "finetuning_task": null,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1"
+   },
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1
+   },
+   "tokenizer_class": null,
+   "prefix": null,
+   "bos_token_id": null,
+   "eos_token_id": null,
+   "sep_token_id": null,
+   "decoder_start_token_id": null,
+   "task_specific_params": null,
+   "problem_type": null,
+   "_name_or_path": "/home/user1/.cache/huggingface/hub/models--openvla--openvla-7b/snapshots/31f090d05236101ebfc381b61c674dd4746d4ce0",
+   "transformers_version": "4.40.1",
+   "auto_map": {
+     "AutoConfig": "configuration_prismatic.OpenVLAConfig",
+     "AutoModelForVision2Seq": "modeling_prismatic.OpenVLAForActionPrediction"
+   },
+   "model_type": "openvla"
+ }
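
Beyond the per-dataset norm_stats, the fields that matter for loading are auto_map, which routes AutoConfig and AutoModelForVision2Seq to the custom configuration_prismatic.OpenVLAConfig and modeling_prismatic.OpenVLAForActionPrediction classes shipped with the repo, and n_action_bins, whose 256 bins together with the q01/q99 action statistics are what an OpenVLA-style action head uses to turn predicted tokens back into continuous 7-dimensional actions for libero_10_no_noops. A minimal loading sketch, assuming this checkpoint follows the upstream OpenVLA interface (predict_action and unnorm_key come from that implementation) and using a placeholder repo id and image path:

    import torch
    from PIL import Image
    from transformers import AutoModelForVision2Seq, AutoProcessor

    repo = "path/to/this/repo"  # placeholder; substitute the actual Hub id or a local clone

    # trust_remote_code is needed because auto_map points at custom classes
    # (configuration_prismatic.OpenVLAConfig, modeling_prismatic.OpenVLAForActionPrediction).
    processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
    vla = AutoModelForVision2Seq.from_pretrained(
        repo, torch_dtype=torch.bfloat16, trust_remote_code=True
    ).to("cuda:0")

    image = Image.open("frame.png")  # assumed third-person camera frame
    prompt = "In: What action should the robot take to pick up the black bowl?\nOut:"

    inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
    # unnorm_key selects which norm_stats entry is used to un-normalize the action;
    # this config only contains "libero_10_no_noops".
    action = vla.predict_action(**inputs, unnorm_key="libero_10_no_noops", do_sample=False)
    print(action)  # 7-dimensional action, matching the action statistics above
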
config.json.back.20250403_191017 ADDED
@@ -0,0 +1,304 @@
(304 lines, identical in content to config.json above)
config.json.back.20250403_194905 ADDED
@@ -0,0 +1,304 @@
(304 lines, identical in content to config.json above)
config.json.back.20250404_145345 ADDED
@@ -0,0 +1,304 @@
(304 lines, identical in content to config.json above)
config.json.back.20250404_145429 ADDED
@@ -0,0 +1,304 @@
(304 lines, identical in content to config.json above)
config.json.back.20250404_150035 ADDED
@@ -0,0 +1,304 @@
(304 lines, identical in content to config.json above)
config.json.back.20250405_141300 ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "norm_stats": {
3
+ "libero_10_no_noops": {
4
+ "action": {
5
+ "mean": [
6
+ 0.01820324920117855,
7
+ 0.05858374014496803,
8
+ -0.05592384561896324,
9
+ 0.004626928828656673,
10
+ 0.00289608770981431,
11
+ -0.007673131301999092,
12
+ 0.5457824468612671
13
+ ],
14
+ "std": [
15
+ 0.2825464606285095,
16
+ 0.35904666781425476,
17
+ 0.3673802614212036,
18
+ 0.03770702704787254,
19
+ 0.05429719388484955,
20
+ 0.08725254982709885,
21
+ 0.49815231561660767
22
+ ],
23
+ "max": [
24
+ 0.9375,
25
+ 0.9375,
26
+ 0.9375,
27
+ 0.30000001192092896,
28
+ 0.29357144236564636,
29
+ 0.375,
30
+ 1.0
31
+ ],
32
+ "min": [
33
+ -0.9375,
34
+ -0.9375,
35
+ -0.9375,
36
+ -0.23642857372760773,
37
+ -0.3053571283817291,
38
+ -0.3675000071525574,
39
+ 0.0
40
+ ],
41
+ "q01": [
42
+ -0.6348214149475098,
43
+ -0.7741071581840515,
44
+ -0.7633928656578064,
45
+ -0.09749999642372131,
46
+ -0.14819999992847435,
47
+ -0.2742857038974762,
48
+ 0.0
49
+ ],
50
+ "q99": [
51
+ 0.7714285850524902,
52
+ 0.8464285731315613,
53
+ 0.9375,
54
+ 0.13928571343421936,
55
+ 0.15964286029338837,
56
+ 0.3246428668498993,
57
+ 1.0
58
+ ],
59
+ "mask": [
60
+ true,
61
+ true,
62
+ true,
63
+ true,
64
+ true,
65
+ true,
66
+ false
67
+ ]
68
+ },
69
+ "proprio": {
70
+ "mean": [
71
+ -0.04190658777952194,
72
+ 0.03539430722594261,
73
+ 0.8257141709327698,
74
+ 2.908308267593384,
75
+ -0.5562185049057007,
76
+ -0.16649018228054047,
77
+ 0.028316624462604523,
78
+ -0.028561657294631004
79
+ ],
80
+ "std": [
81
+ 0.10743364691734314,
82
+ 0.14424669742584229,
83
+ 0.2572328448295593,
84
+ 0.3441362977027893,
85
+ 1.234421730041504,
86
+ 0.3579835891723633,
87
+ 0.013308707624673843,
88
+ 0.013174631632864475
89
+ ],
90
+ "max": [
91
+ 0.21031762659549713,
92
+ 0.39128610491752625,
93
+ 1.3332009315490723,
94
+ 3.6714255809783936,
95
+ 3.560650587081909,
96
+ 1.386339545249939,
97
+ 0.04160946607589722,
98
+ 0.0013633022317662835
99
+ ],
100
+ "min": [
101
+ -0.4828203022480011,
102
+ -0.3255046010017395,
103
+ 0.445506751537323,
104
+ 1.1321442127227783,
105
+ -3.641430377960205,
106
+ -1.842738389968872,
107
+ -0.0010040868073701859,
108
+ -0.04111652821302414
109
+ ],
110
+ "q01": [
111
+ -0.3899900782108307,
112
+ -0.2838300323486328,
113
+ 0.44795057058334353,
114
+ 1.8810229921340942,
115
+ -2.886677579879761,
116
+ -1.1599004411697387,
117
+ 0.002066459748893976,
118
+ -0.04001387819647789
119
+ ],
120
+ "q99": [
121
+ 0.1530261474847791,
122
+ 0.32915401458740223,
123
+ 1.2546923208236693,
124
+ 3.303542451858519,
125
+ 2.7496529006957933,
126
+ 0.6893712210655194,
127
+ 0.040048558115959164,
128
+ -0.0017598449345678235
129
+ ]
130
+ },
131
+ "num_transitions": 101469,
132
+ "num_trajectories": 379
133
+ }
134
+ },
135
+ "n_action_bins": 256,
136
+ "vision_backbone_id": "dinosiglip-vit-so-224px",
137
+ "llm_backbone_id": "llama2-7b-pure",
138
+ "arch_specifier": "no-align+fused-gelu-mlp",
139
+ "output_projector_states": false,
140
+ "use_fused_vision_backbone": true,
141
+ "timm_model_ids": [
142
+ "vit_large_patch14_reg4_dinov2.lvd142m",
143
+ "vit_so400m_patch14_siglip_224"
144
+ ],
145
+ "timm_override_act_layers": [
146
+ null,
147
+ null
148
+ ],
149
+ "image_sizes": [
150
+ 224,
151
+ 224
152
+ ],
153
+ "image_resize_strategy": "resize-naive",
154
+ "hf_llm_id": "meta-llama/Llama-2-7b-hf",
155
+ "llm_max_length": 2048,
156
+ "pad_token_id": 32000,
157
+ "pad_to_multiple_of": 64,
158
+ "text_config": {
159
+ "vocab_size": 32064,
160
+ "max_position_embeddings": 2048,
161
+ "hidden_size": 4096,
162
+ "intermediate_size": 11008,
163
+ "num_hidden_layers": 32,
164
+ "num_attention_heads": 32,
165
+ "num_key_value_heads": 32,
166
+ "hidden_act": "silu",
167
+ "initializer_range": 0.02,
168
+ "rms_norm_eps": 1e-06,
169
+ "pretraining_tp": 1,
170
+ "use_cache": true,
171
+ "rope_theta": 10000.0,
172
+ "rope_scaling": null,
173
+ "attention_bias": false,
174
+ "attention_dropout": 0.0,
175
+ "return_dict": true,
176
+ "output_hidden_states": false,
177
+ "output_attentions": false,
178
+ "torchscript": false,
179
+ "torch_dtype": "bfloat16",
180
+ "use_bfloat16": false,
181
+ "tf_legacy_loss": false,
182
+ "pruned_heads": {},
183
+ "tie_word_embeddings": false,
184
+ "chunk_size_feed_forward": 0,
185
+ "is_encoder_decoder": false,
186
+ "is_decoder": false,
187
+ "cross_attention_hidden_size": null,
188
+ "add_cross_attention": false,
189
+ "tie_encoder_decoder": false,
190
+ "max_length": 20,
191
+ "min_length": 0,
192
+ "do_sample": false,
193
+ "early_stopping": false,
194
+ "num_beams": 1,
195
+ "num_beam_groups": 1,
196
+ "diversity_penalty": 0.0,
197
+ "temperature": 1.0,
198
+ "top_k": 50,
199
+ "top_p": 1.0,
200
+ "typical_p": 1.0,
201
+ "repetition_penalty": 1.0,
202
+ "length_penalty": 1.0,
203
+ "no_repeat_ngram_size": 0,
204
+ "encoder_no_repeat_ngram_size": 0,
205
+ "bad_words_ids": null,
206
+ "num_return_sequences": 1,
207
+ "output_scores": false,
208
+ "return_dict_in_generate": false,
209
+ "forced_bos_token_id": null,
210
+ "forced_eos_token_id": null,
211
+ "remove_invalid_values": false,
212
+ "exponential_decay_length_penalty": null,
213
+ "suppress_tokens": null,
214
+ "begin_suppress_tokens": null,
215
+ "architectures": null,
216
+ "finetuning_task": null,
217
+ "id2label": {
218
+ "0": "LABEL_0",
219
+ "1": "LABEL_1"
220
+ },
221
+ "label2id": {
222
+ "LABEL_0": 0,
223
+ "LABEL_1": 1
224
+ },
225
+ "tokenizer_class": null,
226
+ "prefix": null,
227
+ "bos_token_id": 1,
228
+ "pad_token_id": 32000,
229
+ "eos_token_id": 2,
230
+ "sep_token_id": null,
231
+ "decoder_start_token_id": null,
232
+ "task_specific_params": null,
233
+ "problem_type": null,
234
+ "_name_or_path": "",
235
+ "model_type": "llama"
236
+ },
237
+ "return_dict": true,
238
+ "output_hidden_states": false,
239
+ "output_attentions": false,
240
+ "torchscript": false,
241
+ "torch_dtype": "bfloat16",
242
+ "use_bfloat16": false,
243
+ "tf_legacy_loss": false,
244
+ "pruned_heads": {},
245
+ "tie_word_embeddings": true,
246
+ "chunk_size_feed_forward": 0,
247
+ "is_encoder_decoder": false,
248
+ "is_decoder": false,
249
+ "cross_attention_hidden_size": null,
250
+ "add_cross_attention": false,
251
+ "tie_encoder_decoder": false,
252
+ "max_length": 20,
253
+ "min_length": 0,
254
+ "do_sample": false,
255
+ "early_stopping": false,
256
+ "num_beams": 1,
257
+ "num_beam_groups": 1,
258
+ "diversity_penalty": 0.0,
259
+ "temperature": 1.0,
260
+ "top_k": 50,
261
+ "top_p": 1.0,
262
+ "typical_p": 1.0,
263
+ "repetition_penalty": 1.0,
264
+ "length_penalty": 1.0,
265
+ "no_repeat_ngram_size": 0,
266
+ "encoder_no_repeat_ngram_size": 0,
267
+ "bad_words_ids": null,
268
+ "num_return_sequences": 1,
269
+ "output_scores": false,
270
+ "return_dict_in_generate": false,
271
+ "forced_bos_token_id": null,
272
+ "forced_eos_token_id": null,
273
+ "remove_invalid_values": false,
274
+ "exponential_decay_length_penalty": null,
275
+ "suppress_tokens": null,
276
+ "begin_suppress_tokens": null,
277
+ "architectures": [
278
+ "OpenVLAForActionPrediction"
279
+ ],
280
+ "finetuning_task": null,
281
+ "id2label": {
282
+ "0": "LABEL_0",
283
+ "1": "LABEL_1"
284
+ },
285
+ "label2id": {
286
+ "LABEL_0": 0,
287
+ "LABEL_1": 1
288
+ },
289
+ "tokenizer_class": null,
290
+ "prefix": null,
291
+ "bos_token_id": null,
292
+ "eos_token_id": null,
293
+ "sep_token_id": null,
294
+ "decoder_start_token_id": null,
295
+ "task_specific_params": null,
296
+ "problem_type": null,
297
+ "_name_or_path": "/home/user1/.cache/huggingface/hub/models--openvla--openvla-7b/snapshots/31f090d05236101ebfc381b61c674dd4746d4ce0",
298
+ "transformers_version": "4.40.1",
299
+ "auto_map": {
300
+ "AutoConfig": "configuration_prismatic.OpenVLAConfig",
301
+ "AutoModelForVision2Seq": "modeling_prismatic.OpenVLAForActionPrediction"
302
+ },
303
+ "model_type": "openvla"
304
+ }
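The `auto_map` block above routes `AutoConfig`/`AutoModelForVision2Seq` through the bundled `configuration_prismatic.py` and `modeling_prismatic.py`, so the checkpoint must be loaded with `trust_remote_code=True`. A minimal sketch for inspecting this config locally (assuming the repository has been downloaded to a hypothetical `MODEL_DIR`):

```python
# Minimal sketch: load the bundled OpenVLAConfig and inspect its fields.
# MODEL_DIR is a placeholder for a local copy of this repository.
from transformers import AutoConfig

MODEL_DIR = "."

config = AutoConfig.from_pretrained(MODEL_DIR, trust_remote_code=True)
print(config.model_type)        # "openvla"
print(config.n_action_bins)     # 256
print(list(config.norm_stats))  # ["libero_10_no_noops"]
print(config.timm_model_ids)    # the two fused vision backbones listed above
```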
config.json.back.20250405_163457 ADDED
@@ -0,0 +1,304 @@
config.json.back.20250406_084655 ADDED
@@ -0,0 +1,304 @@
configuration_prismatic.py ADDED
@@ -0,0 +1,140 @@
1
+ """
2
+ configuration_prismatic.py
3
+
4
+ HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`.
5
+ Default configuration specifies `siglip-224px+7b`.
6
+ """
7
+
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ from transformers import PretrainedConfig
11
+ from transformers.models.auto import CONFIG_MAPPING
12
+
13
+ # === Utilities for Mapping Prismatic names to HF names ===
14
+ # fmt: off
15
+ VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = {
16
+ "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224],
17
+
18
+ "clip-vit-l-336px": [336],
19
+ "siglip-vit-so400m-384px": [384],
20
+
21
+ "dinoclip-vit-l-336px": [336, 336],
22
+ "dinosiglip-vit-so-224px": [224, 224],
23
+ "dinosiglip-vit-so-384px": [384, 384],
24
+ }
25
+ VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = {
26
+ "clip-vit-l": ["vit_large_patch14_clip_224.openai"],
27
+ "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"],
28
+
29
+ "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"],
30
+ "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"],
31
+
32
+ "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"],
33
+ "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"],
34
+
35
+ "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"],
36
+ "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"],
37
+ "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"],
38
+ }
39
+ TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = {
40
+ "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"],
41
+ "dinov2-vit-l": [None], "in1k-vit-l": [None],
42
+ "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None],
43
+ "dinoclip-vit-l-336px": [None, "quick_gelu"],
44
+ "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None]
45
+ }
46
+
47
+ LLM_BACKBONE_TO_HF_PATH = {
48
+ "llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf",
49
+ "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
50
+
51
+ "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5",
52
+
53
+ "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1",
54
+ "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1",
55
+
56
+ "phi-2-3b": "microsoft/phi-2",
57
+ }
58
+ LLM_BACKBONE_TO_HF_METACLASS = {
59
+ "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama",
60
+ "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama",
61
+
62
+ "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral",
63
+
64
+ "phi-2-3b": "phi",
65
+ }
66
+
67
+ VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys())
68
+ VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH)
69
+ # fmt: on
70
+
71
+
72
+ class PrismaticConfig(PretrainedConfig):
73
+ model_type: str = "prismatic"
74
+ is_composition: bool = False
75
+
76
+ def __init__(
77
+ self,
78
+ vision_backbone_id: str = "siglip-vit-so400m",
79
+ llm_backbone_id: str = "vicuna-v15-7b",
80
+ arch_specifier: str = "no-align+gelu-mlp",
81
+ use_fused_vision_backbone: Optional[bool] = None,
82
+ image_resize_strategy: str = "letterbox",
83
+ text_config: Optional[Dict[str, Any]] = None,
84
+ llm_max_length: int = 2048,
85
+ pad_token_id: int = 32000,
86
+ pad_to_multiple_of: int = 64,
87
+ output_projector_states: bool = False,
88
+ **kwargs: str,
89
+ ) -> None:
90
+ if vision_backbone_id not in VALID_VISION_BACKBONES:
91
+ raise ValueError(f"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }")
92
+
93
+ if llm_backbone_id not in VALID_LLM_BACKBONES:
94
+ raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }")
95
+
96
+ # Set Prismatic Configuration Fields
97
+ self.vision_backbone_id = vision_backbone_id
98
+ self.llm_backbone_id = llm_backbone_id
99
+ self.arch_specifier = arch_specifier
100
+ self.output_projector_states = output_projector_states
101
+
102
+ # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing
103
+ self.use_fused_vision_backbone = (
104
+ use_fused_vision_backbone
105
+ if use_fused_vision_backbone is not None
106
+ else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"])
107
+ )
108
+
109
+ self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id]
110
+ self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id]
111
+ self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id]
112
+ self.image_resize_strategy = image_resize_strategy
113
+
114
+ self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id]
115
+ self.llm_max_length = llm_max_length
116
+ self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of
117
+
118
+ # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming!
119
+ self.text_config = (
120
+ CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config)
121
+ if text_config is not None
122
+ else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]()
123
+ )
124
+
125
+ # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well...
126
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
127
+
128
+
129
+ class OpenVLAConfig(PrismaticConfig):
130
+ model_type: str = "openvla"
131
+
132
+ def __init__(
133
+ self,
134
+ norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None,
135
+ n_action_bins: int = 256,
136
+ **kwargs: str,
137
+ ) -> None:
138
+ self.norm_stats, self.n_action_bins = norm_stats, n_action_bins
139
+
140
+ super().__init__(**kwargs)
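The lookup tables at the top of this file are what expand the short backbone identifiers in `config.json` into concrete TIMM model IDs, image sizes, and the HF LLM path. A small usage sketch, assuming `configuration_prismatic.py` is importable from the working directory:

```python
# Sketch: build an OpenVLAConfig with the same identifiers used in config.json
# and check the derived fields (assumes this file is on the Python path).
from configuration_prismatic import OpenVLAConfig

cfg = OpenVLAConfig(
    vision_backbone_id="dinosiglip-vit-so-224px",
    llm_backbone_id="llama2-7b-pure",
    arch_specifier="no-align+fused-gelu-mlp",
    image_resize_strategy="resize-naive",
    n_action_bins=256,
)
assert cfg.use_fused_vision_backbone  # inferred from the "dinosiglip" prefix
print(cfg.timm_model_ids)  # ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"]
print(cfg.image_sizes)     # [224, 224]
print(cfg.hf_llm_id)       # "meta-llama/Llama-2-7b-hf"
```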
dataset_statistics.json ADDED
@@ -0,0 +1,133 @@
1
+ {
2
+ "libero_10_no_noops": {
3
+ "action": {
4
+ "mean": [
5
+ 0.01820324920117855,
6
+ 0.05858374014496803,
7
+ -0.05592384561896324,
8
+ 0.004626928828656673,
9
+ 0.00289608770981431,
10
+ -0.007673131301999092,
11
+ 0.5457824468612671
12
+ ],
13
+ "std": [
14
+ 0.2825464606285095,
15
+ 0.35904666781425476,
16
+ 0.3673802614212036,
17
+ 0.03770702704787254,
18
+ 0.05429719388484955,
19
+ 0.08725254982709885,
20
+ 0.49815231561660767
21
+ ],
22
+ "max": [
23
+ 0.9375,
24
+ 0.9375,
25
+ 0.9375,
26
+ 0.30000001192092896,
27
+ 0.29357144236564636,
28
+ 0.375,
29
+ 1.0
30
+ ],
31
+ "min": [
32
+ -0.9375,
33
+ -0.9375,
34
+ -0.9375,
35
+ -0.23642857372760773,
36
+ -0.3053571283817291,
37
+ -0.3675000071525574,
38
+ 0.0
39
+ ],
40
+ "q01": [
41
+ -0.6348214149475098,
42
+ -0.7741071581840515,
43
+ -0.7633928656578064,
44
+ -0.09749999642372131,
45
+ -0.14819999992847435,
46
+ -0.2742857038974762,
47
+ 0.0
48
+ ],
49
+ "q99": [
50
+ 0.7714285850524902,
51
+ 0.8464285731315613,
52
+ 0.9375,
53
+ 0.13928571343421936,
54
+ 0.15964286029338837,
55
+ 0.3246428668498993,
56
+ 1.0
57
+ ],
58
+ "mask": [
59
+ true,
60
+ true,
61
+ true,
62
+ true,
63
+ true,
64
+ true,
65
+ false
66
+ ]
67
+ },
68
+ "proprio": {
69
+ "mean": [
70
+ -0.04190658777952194,
71
+ 0.03539430722594261,
72
+ 0.8257141709327698,
73
+ 2.908308267593384,
74
+ -0.5562185049057007,
75
+ -0.16649018228054047,
76
+ 0.028316624462604523,
77
+ -0.028561657294631004
78
+ ],
79
+ "std": [
80
+ 0.10743364691734314,
81
+ 0.14424669742584229,
82
+ 0.2572328448295593,
83
+ 0.3441362977027893,
84
+ 1.234421730041504,
85
+ 0.3579835891723633,
86
+ 0.013308707624673843,
87
+ 0.013174631632864475
88
+ ],
89
+ "max": [
90
+ 0.21031762659549713,
91
+ 0.39128610491752625,
92
+ 1.3332009315490723,
93
+ 3.6714255809783936,
94
+ 3.560650587081909,
95
+ 1.386339545249939,
96
+ 0.04160946607589722,
97
+ 0.0013633022317662835
98
+ ],
99
+ "min": [
100
+ -0.4828203022480011,
101
+ -0.3255046010017395,
102
+ 0.445506751537323,
103
+ 1.1321442127227783,
104
+ -3.641430377960205,
105
+ -1.842738389968872,
106
+ -0.0010040868073701859,
107
+ -0.04111652821302414
108
+ ],
109
+ "q01": [
110
+ -0.3899900782108307,
111
+ -0.2838300323486328,
112
+ 0.44795057058334353,
113
+ 1.8810229921340942,
114
+ -2.886677579879761,
115
+ -1.1599004411697387,
116
+ 0.002066459748893976,
117
+ -0.04001387819647789
118
+ ],
119
+ "q99": [
120
+ 0.1530261474847791,
121
+ 0.32915401458740223,
122
+ 1.2546923208236693,
123
+ 3.303542451858519,
124
+ 2.7496529006957933,
125
+ 0.6893712210655194,
126
+ 0.040048558115959164,
127
+ -0.0017598449345678235
128
+ ]
129
+ },
130
+ "num_transitions": 101469,
131
+ "num_trajectories": 379
132
+ }
133
+ }
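These statistics duplicate the `norm_stats` block in `config.json` and are what maps between raw robot actions and the normalized range the model works in. The sketch below shows the usual OpenVLA-style q01/q99 bounds convention, with the gripper dimension (where `mask` is false) passed through unchanged; this is an illustration of how such stats are typically consumed, not code from this repository:

```python
# Hedged sketch of q01/q99 bounds (un)normalization; the exact scheme used by
# this checkpoint is governed by ACTION_PROPRIO_NORMALIZATION_TYPE in the code.
import json
import numpy as np

stats = json.load(open("dataset_statistics.json"))["libero_10_no_noops"]["action"]
q01, q99 = np.array(stats["q01"]), np.array(stats["q99"])
mask = np.array(stats["mask"])  # last (gripper) dim is excluded from normalization

def normalize(action: np.ndarray) -> np.ndarray:
    scaled = 2.0 * (action - q01) / (q99 - q01 + 1e-8) - 1.0
    return np.where(mask, np.clip(scaled, -1.0, 1.0), action)

def unnormalize(normalized: np.ndarray) -> np.ndarray:
    unscaled = 0.5 * (normalized + 1.0) * (q99 - q01 + 1e-8) + q01
    return np.where(mask, unscaled, normalized)
```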
lora_adapter/README.md ADDED
@@ -0,0 +1,202 @@
1
+ ---
2
+ base_model: /home/user1/.cache/huggingface/hub/models--openvla--openvla-7b/snapshots/31f090d05236101ebfc381b61c674dd4746d4ce0
3
+ library_name: peft
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+ ### Framework versions
201
+
202
+ - PEFT 0.11.1
lora_adapter/adapter_config.json ADDED
@@ -0,0 +1,45 @@
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": {
4
+ "base_model_class": "OpenVLAForActionPrediction",
5
+ "parent_library": "transformers_modules.31f090d05236101ebfc381b61c674dd4746d4ce0.modeling_prismatic"
6
+ },
7
+ "base_model_name_or_path": "/home/user1/.cache/huggingface/hub/models--openvla--openvla-7b/snapshots/31f090d05236101ebfc381b61c674dd4746d4ce0",
8
+ "bias": "none",
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": "gaussian",
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 16,
17
+ "lora_dropout": 0.0,
18
+ "megatron_config": null,
19
+ "megatron_core": "megatron.core",
20
+ "modules_to_save": null,
21
+ "peft_type": "LORA",
22
+ "r": 32,
23
+ "rank_pattern": {},
24
+ "revision": null,
25
+ "target_modules": [
26
+ "proj",
27
+ "q_proj",
28
+ "fc2",
29
+ "fc1",
30
+ "q",
31
+ "k_proj",
32
+ "qkv",
33
+ "v_proj",
34
+ "lm_head",
35
+ "up_proj",
36
+ "o_proj",
37
+ "kv",
38
+ "gate_proj",
39
+ "fc3",
40
+ "down_proj"
41
+ ],
42
+ "task_type": null,
43
+ "use_dora": false,
44
+ "use_rslora": false
45
+ }
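The adapter uses rank 32 / alpha 16 LoRA and targets projection layers in both the language model and the vision backbone (plus `lm_head`). A hedged sketch of attaching it to a base OpenVLA checkpoint with PEFT; `BASE_MODEL` and `ADAPTER_DIR` are placeholders for local paths, and merging is optional:

```python
# Sketch: attach this LoRA adapter to a base OpenVLA model with PEFT.
import torch
from transformers import AutoModelForVision2Seq
from peft import PeftModel

BASE_MODEL = "openvla/openvla-7b"  # or a local snapshot path
ADAPTER_DIR = "lora_adapter"       # this repository's adapter directory

base = AutoModelForVision2Seq.from_pretrained(
    BASE_MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True
)
model = PeftModel.from_pretrained(base, ADAPTER_DIR)
model = model.merge_and_unload()   # optional: fold the LoRA weights into the base
```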
lora_adapter/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf86f7d0d63ee04b53f473cc822c3e09d5776b01f70c1c6b2292296717349320
3
+ size 484458600
modeling_prismatic.py ADDED
@@ -0,0 +1,1338 @@
1
+ """
2
+ modeling_prismatic.py
3
+
4
+ Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.
5
+ Inherits from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained,
6
+ but exactly replicate the logic in `prismatic.models.vlms.prismatic.py`.
7
+ """
8
+
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from functools import partial
12
+ from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
13
+
14
+ import numpy as np
15
+ import timm
16
+ import tokenizers
17
+ import torch
18
+ import torch.nn as nn
19
+ import transformers
20
+ from timm.models.vision_transformer import LayerScale
21
+ from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import ModelOutput
23
+ from prismatic.models.action_heads import L1RegressionActionHead
24
+ import time
25
+ from prismatic.training.train_utils import (
26
+ get_current_action_mask,
27
+ get_next_actions_mask,
28
+ )
29
+ from prismatic.vla.constants import (
30
+ ACTION_DIM,
31
+ ACTION_PROPRIO_NORMALIZATION_TYPE,
32
+ ACTION_TOKEN_BEGIN_IDX,
33
+ IGNORE_INDEX,
34
+ NUM_ACTIONS_CHUNK,
35
+ STOP_INDEX,
36
+ ACTION_TOKEN_IDX,
37
+ NormalizationType,
38
+ )
39
+
40
+ from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
41
+
42
+ # Set up logger
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
+ # === Utility Functions for Monkey-Patching ===
47
+ def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
48
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
49
+ result = fn(*args, **kwargs)
50
+ return result[0] if isinstance(result, tuple) else result
51
+
52
+ return wrapper
53
+
54
+
55
+ # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
56
+ # =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
57
+ # =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
58
+ def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
60
+
61
+
62
+ def ls_apply_patch(ls_module: LayerScale):
63
+ ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
64
+ ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
65
+ del ls_module.gamma
66
+
67
+
68
+ # === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
69
+ class PrismaticVisionBackbone(nn.Module):
70
+ """
71
+ Vision backbone for Prismatic models that handles image feature extraction.
72
+
73
+ Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.
74
+ For fused backbones, features from both models are concatenated along the feature dimension.
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ use_fused_vision_backbone: bool,
80
+ image_sizes: List[int],
81
+ timm_model_ids: List[str],
82
+ timm_override_act_layers: List[Optional[str]],
83
+ ) -> None:
84
+ """
85
+ Initialize the vision backbone.
86
+
87
+ Args:
88
+ use_fused_vision_backbone: Whether to use two backbones and fuse their features
89
+ image_sizes: List of image sizes for each backbone
90
+ timm_model_ids: List of TIMM model IDs to use for each backbone
91
+ timm_override_act_layers: List of activation layer overrides for each backbone
92
+ """
93
+ super().__init__()
94
+ self.use_fused_vision_backbone = use_fused_vision_backbone
95
+ self.num_images_in_input = 1 # Default value, can be overridden later
96
+
97
+ # Validate number of (fused) vision backbones
98
+ if len(timm_model_ids) > 2:
99
+ raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")
100
+
101
+ # Create primary featurizer
102
+ self.featurizer = self._create_featurizer(
103
+ model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]
104
+ )
105
+ self.embed_dim = self.featurizer.embed_dim
106
+
107
+ # Create secondary featurizer if using fused backbone
108
+ if self.use_fused_vision_backbone:
109
+ self.fused_featurizer = self._create_featurizer(
110
+ model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]
111
+ )
112
+ self.embed_dim += self.fused_featurizer.embed_dim
113
+
114
+ # Patch LayerScale modules for HF compatibility
115
+ self._patch_layer_scales()
116
+
117
+ def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
118
+ """
119
+ Create a TIMM-based featurizer model with appropriate configurations.
120
+
121
+ Args:
122
+ model_id: The TIMM model ID to load
123
+ img_size: Input image size for the model
124
+ act_layer: Override for the activation layer type
125
+
126
+ Returns:
127
+ A configured featurizer model
128
+ """
129
+ featurizer = timm.create_model(
130
+ model_id,
131
+ pretrained=False,
132
+ num_classes=0,
133
+ img_size=img_size,
134
+ act_layer=act_layer,
135
+ )
136
+
137
+ # Monkey-patch the forward function to extract the second-to-last layer features
138
+ num_blocks = len(featurizer.blocks)
139
+ featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))
140
+
141
+ return featurizer
142
+
143
+ def _patch_layer_scales(self) -> None:
144
+ """
145
+ Patch all LayerScale modules to be compatible with HF's parameter naming.
146
+
147
+ HF Transformers overwrites parameters with names containing 'gamma',
148
+ so we need to rename and modify the forward method.
149
+ """
150
+ # Patch primary featurizer
151
+ for module in self.featurizer.modules():
152
+ if isinstance(module, LayerScale):
153
+ ls_apply_patch(module)
154
+
155
+ # Patch secondary featurizer if it exists
156
+ if self.use_fused_vision_backbone:
157
+ for module in self.fused_featurizer.modules():
158
+ if isinstance(module, LayerScale):
159
+ ls_apply_patch(module)
160
+
161
+ def get_num_patches(self) -> int:
162
+ """
163
+ Returns the number of vision patches output by the vision backbone.
164
+
165
+ Returns:
166
+ Number of patches per image
167
+ """
168
+ return self.featurizer.patch_embed.num_patches
169
+
170
+ def get_num_images_in_input(self) -> int:
171
+ """
172
+ Returns the number of input images for the vision backbone.
173
+
174
+ Returns:
175
+ Number of images expected in the input
176
+ """
177
+ return self.num_images_in_input
178
+
179
+ def set_num_images_in_input(self, num_images_in_input: int) -> None:
180
+ """
181
+ Sets the number of input images for the vision backbone.
182
+
183
+ Args:
184
+ num_images_in_input: Number of images to expect in the input
185
+ """
186
+ self.num_images_in_input = num_images_in_input
187
+
188
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
189
+ """
190
+ Implements the forward pass for the vision backbone.
191
+
192
+ If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features
193
+ (otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).
194
+
195
+ Args:
196
+ pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W).
197
+ """
198
+ if self.num_images_in_input == 1:
199
+ if not self.use_fused_vision_backbone:
200
+ return self.featurizer(pixel_values)
201
+
202
+ # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
203
+ img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
204
+ patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)
205
+
206
+ return torch.cat([patches, patches_fused], dim=2)
207
+
208
+ else:
209
+ assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!"
210
+
211
+ # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2)
212
+ images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)
213
+
214
+ # Process each image and collect patches
215
+ all_patches = []
216
+ for img in images:
217
+ # Split each image further into two stacks of channels (each with 3 channels)
218
+ img_regular, img_fused = torch.split(img, [3, 3], dim=1)
219
+
220
+ # Get patches from both SigLIP and DINOv2 vision transformers
221
+ patches = self.featurizer(img_regular)
222
+ patches_fused = self.fused_featurizer(img_fused)
223
+
224
+ # Concatenate SigLIP and DINOv2 patches along the hidden dimension
225
+ combined_patches = torch.cat([patches, patches_fused], dim=2)
226
+ all_patches.append(combined_patches)
227
+
228
+ # Concatenate all patches along the patch dimension
229
+ return torch.cat(all_patches, dim=1)
230
+
231
+
232
+ # === Prismatic Projector (nn.Module) Definitions ===
233
+ class PrismaticProjector(nn.Module):
234
+ def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
235
+ super().__init__()
236
+ self.use_fused_vision_backbone = use_fused_vision_backbone
237
+ self.vision_dim, self.llm_dim = vision_dim, llm_dim
238
+
239
+ # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
240
+ if not self.use_fused_vision_backbone:
241
+ self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
242
+ self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
243
+ self.act_fn1 = nn.GELU()
244
+ else:
245
+ initial_projection_dim = 4 * vision_dim
246
+ self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
247
+ self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
248
+ self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
249
+ self.act_fn1 = nn.GELU()
250
+ self.act_fn2 = nn.GELU()
251
+
252
+ def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
253
+ if not self.use_fused_vision_backbone:
254
+ projected_features = self.fc1(img_patches)
255
+ projected_features = self.act_fn1(projected_features)
256
+ projected_features = self.fc2(projected_features)
257
+ else:
258
+ projected_features = self.fc1(img_patches)
259
+ projected_features = self.act_fn1(projected_features)
260
+ projected_features = self.fc2(projected_features)
261
+ projected_features = self.act_fn2(projected_features)
262
+ projected_features = self.fc3(projected_features)
263
+
264
+ return projected_features
265
+
266
+
267
+ # === Main HF Class Definitions ===
268
+ @dataclass
269
+ class PrismaticCausalLMOutputWithPast(ModelOutput):
270
+ """Base class for Prismatic casual (visually-conditioned) language model outputs; also exposes visual features."""
271
+
272
+ loss: Optional[torch.FloatTensor] = None
273
+ logits: torch.FloatTensor = None
274
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
275
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
276
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
277
+
278
+ # Additions for VLMs
279
+ projector_features: Optional[torch.FloatTensor] = None
280
+
281
+
282
+ class PrismaticPreTrainedModel(PreTrainedModel):
283
+ config_class: PretrainedConfig = PrismaticConfig
284
+ base_model_prefix: str = "model"
285
+ supports_gradient_checkpointing: bool = True
286
+
287
+ _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
288
+ _skip_keys_device_placement: str = "past_key_values"
289
+ _supports_flash_attn_2: bool = True
290
+
291
+ def _init_weights(self, module: nn.Module) -> None:
292
+ # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
293
+ # => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
294
+ # https://github.com/TRI-ML/prismatic-vlms
295
+ std = (
296
+ self.config.initializer_range
297
+ if hasattr(self.config, "initializer_range")
298
+ else self.config.text_config.initializer_range
299
+ )
300
+
301
+ if hasattr(module, "class_embedding"):
302
+ module.class_embedding.data.normal_(mean=0.0, std=std)
303
+
304
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
305
+ module.weight.data.normal_(mean=0.0, std=std)
306
+ if module.bias is not None:
307
+ module.bias.data.zero_()
308
+ elif isinstance(module, nn.Embedding):
309
+ module.weight.data.normal_(mean=0.0, std=std)
310
+ if module.padding_idx is not None:
311
+ module.weight.data[module.padding_idx].zero_()
312
+
313
+ @property
314
+ def _supports_sdpa(self) -> bool:
315
+ """Check LLM supports SDPA Attention"""
316
+ return self.language_model._supports_sdpa
317
+
318
+
319
+ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
320
+ def __init__(self, config: PrismaticConfig) -> None:
321
+ super().__init__(config)
322
+
323
+ # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
324
+ if config.use_fused_vision_backbone is None:
325
+ raise ValueError("Missing config field `use_fused_vision_backbone`")
326
+
327
+ if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
328
+ raise NotImplementedError(
329
+ "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
330
+ "if you urgently need support for latest TIMM versions."
331
+ )
332
+
333
+ if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
334
+ logger.warning(
335
+ f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
336
+ f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
337
+ f"there might be inference-time regressions due to dependency changes. If in doubt, please"
338
+ f"use the above versions."
339
+ )
340
+
341
+ # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
342
+ self.vision_backbone = PrismaticVisionBackbone(
343
+ config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
344
+ )
345
+
346
+ # Create Multimodal Projector
347
+ self.projector = PrismaticProjector(
348
+ config.use_fused_vision_backbone,
349
+ vision_dim=self.vision_backbone.embed_dim,
350
+ llm_dim=config.text_config.hidden_size,
351
+ )
352
+
353
+ # Instantiate LLM Backbone
354
+ self.language_model = AutoModelForCausalLM.from_config(
355
+ config.text_config, attn_implementation=config._attn_implementation
356
+ )
357
+ self.vocab_size = config.text_config.vocab_size
358
+ self.pad_token_id = config.pad_token_id
359
+ self.llm_dim = config.text_config.hidden_size
360
+
361
+ # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
362
+ self.post_init()
363
+
364
+ # === `PreTrainedModel` Boilerplate ===
365
+ def get_input_embeddings(self) -> nn.Module:
366
+ return self.language_model.get_input_embeddings()
367
+
368
+ def set_input_embeddings(self, value: nn.Module) -> None:
369
+ self.language_model.set_input_embeddings(value)
370
+
371
+ def get_output_embeddings(self) -> nn.Module:
372
+ return self.language_model.get_output_embeddings()
373
+
374
+ def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
375
+ self.language_model.set_output_embeddings(new_embeddings)
376
+
377
+ def get_decoder(self) -> nn.Module:
378
+ return self.language_model.get_decoder()
379
+
380
+ def set_decoder(self, decoder: nn.Module) -> None:
381
+ self.language_model.set_decoder(decoder)
382
+
383
+ def tie_weights(self) -> None:
384
+ self.language_model.tie_weights() # Note: `Llama-2` and `Mistral` don't tie weights (no-op)
385
+
386
+ def resize_token_embeddings(
387
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
388
+ ) -> nn.Embedding:
389
+ updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
390
+
391
+ # Update config/instance variables
392
+ self.config.text_config.vocab_size = updated_embeddings.num_embeddings
393
+ self.vocab_size = updated_embeddings.num_embeddings
394
+
395
+ return updated_embeddings
396
+
397
+ def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):
398
+ """
399
+ Replace embeddings in input_embeddings at positions where all_actions_mask is True
400
+ with embeddings from noisy_action_features, using vectorized operations.
401
+
402
+ Args:
403
+ input_embeddings: Tensor of shape (B, S, D)
404
+ all_actions_mask: Boolean tensor of shape (B, S)
405
+ noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample
406
+
407
+ Returns:
408
+ Modified input_embeddings tensor
409
+ """
410
+ # Clone input to avoid modifying the original tensor
411
+ new_input_embeddings = input_embeddings.clone()
412
+
413
+ # Create a tensor with the same shape of input_embeddings to hold the noisy action features
414
+ repositioned_noisy_action_features = torch.zeros_like(input_embeddings)
415
+
416
+ # Create batch indices for splicing
417
+ batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)
418
+ batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])
419
+
420
+ # Get indices where mask is True for each sample
421
+ masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])
422
+
423
+ # Move the noisy action features into their correct positions
424
+ repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features
425
+
426
+ # Combine original input embeddings and noisy action embeddings using the mask
427
+ new_input_embeddings = torch.where(
428
+ all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings
429
+ )
430
+
431
+ return new_input_embeddings
432
+
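+ # Illustrative sketch (values assumed): with B=1, S=4, and all_actions_mask = [[False, True, True, False]],
+ # the two rows of noisy_action_features land at sequence positions 1 and 2, while positions 0 and 3 keep
+ # their original embeddings. In this codebase K is the number of action tokens, i.e. ACTION_DIM * NUM_ACTIONS_CHUNK.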
433
+ def _process_action_masks(self, labels):
434
+ """Helper to get action masks from labels"""
435
+ current_action_mask = get_current_action_mask(labels) # (B, seq_len)
436
+ next_actions_mask = get_next_actions_mask(labels) # (B, seq_len)
437
+ all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len)
438
+ return all_actions_mask
439
+
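+ # Sketch of what the combined mask covers: get_current_action_mask marks the current action's tokens and
+ # get_next_actions_mask marks the remaining actions in the chunk, so all_actions_mask is True over every
+ # action-token position (ACTION_DIM * NUM_ACTIONS_CHUNK positions when a full chunk is present).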
440
+ def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False):
441
+ """Process vision features; `language_embeddings` and `use_film` are accepted for call-site compatibility, but FiLM conditioning is not applied in this variant"""
442
+ patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D)
443
+
444
+ # Project patch embeddings into language embedding space
445
+ return self.projector(patch_features)
446
+
447
+ def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):
448
+ """Process proprioceptive features and append to vision features"""
449
+ if proprio_projector is not None and proprio is not None:
450
+ # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim)
451
+ # proprio: (bsz, proprio_dim) or (proprio_dim,)
452
+ proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1) # (bsz, proprio_dim)
453
+ proprio_features = proprio_projector(proprio) # (bsz, llm_dim)
454
+ proprio_features = proprio_features.unsqueeze(dim=1) # (bsz, 1, llm_dim)
455
+ # For simplicity, just append proprio token to the end of projected vision patch tokens
456
+ return torch.cat((projected_patch_embeddings, proprio_features), dim=1)
457
+ return projected_patch_embeddings
458
+
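+ # Shape sketch (assumed values): with bsz=1, one input image of 256 patches, and llm_dim=4096,
+ # projected_patch_embeddings is (1, 256, 4096); appending the single proprio token gives (1, 257, 4096).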
459
+ def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
460
+ """Build multimodal embeddings and attention mask"""
461
+ # juyi: Should the attention mask be made lower-triangular here? No need, because generate() applies causal masking automatically.
462
+ projected_patch_attention_mask = None
463
+ if attention_mask is not None:
464
+ projected_patch_attention_mask = torch.full(
465
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
466
+ fill_value=True,
467
+ dtype=attention_mask.dtype,
468
+ device=attention_mask.device,
469
+ )
470
+
471
+ # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
472
+ multimodal_embeddings = torch.cat(
473
+ [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
474
+ )
475
+
476
+ multimodal_attention_mask = None
477
+ if attention_mask is not None:
478
+ multimodal_attention_mask = torch.cat(
479
+ [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
480
+ )
481
+
482
+ return multimodal_embeddings, multimodal_attention_mask
483
+
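+ # Resulting layout (sketch): [<BOS>] + [projected patch embeddings (plus any proprio / diffusion timestep
+ # tokens already appended)] + [remaining text and action embeddings]; the attention mask is extended with
+ # all-True entries over the inserted patch positions so they are never masked out.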
484
+ def _build_multimodal_labels(self, labels, projected_patch_embeddings):
485
+ """Build multimodal labels with IGNORE_INDEX for patch embeddings"""
486
+ if labels is not None:
487
+ projected_patch_labels = torch.full(
488
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
489
+ fill_value=IGNORE_INDEX, # No loss is computed at these positions.
490
+ dtype=labels.dtype,
491
+ device=labels.device,
492
+ )
493
+ return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1) # The first token is <BOS>.
494
+ return None
495
+
496
+ # === Core Prismatic VLM `forward()` Logic ===
497
+ def forward(
498
+ self,
499
+ input_ids: Optional[torch.LongTensor] = None,
500
+ attention_mask: Optional[torch.Tensor] = None,
501
+ pixel_values: Optional[torch.FloatTensor] = None,
502
+ labels: Optional[torch.LongTensor] = None,
503
+ inputs_embeds: Optional[torch.FloatTensor] = None,
504
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
505
+ use_cache: Optional[bool] = None,
506
+ output_attentions: Optional[bool] = None,
507
+ output_hidden_states: Optional[bool] = None,
508
+ output_projector_features: Optional[bool] = None,
509
+ return_dict: Optional[bool] = None,
510
+ proprio=None,
511
+ proprio_projector=None,
512
+ noisy_actions=None,
513
+ noisy_action_projector=None,
514
+ diffusion_timestep_embeddings=None,
515
+ use_film: bool = False,
516
+ ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
517
+ """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
518
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
519
+ output_hidden_states = (
520
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
521
+ )
522
+ output_projector_features = output_projector_features if output_projector_features is not None else False
523
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
524
+
525
+ # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
526
+ use_cache = use_cache and not self.training
527
+
528
+ # Instantiate Placeholder for Projector Features
529
+ projected_patch_embeddings = None
530
+
531
+ # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
532
+ if input_ids.shape[1] == 1:
533
+ assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
534
+ assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
535
+ assert labels is None, "Unexpected key `labels` provided during cached generation!"
536
+
537
+ language_model_output = self.language_model(
538
+ input_ids=input_ids,
539
+ attention_mask=None,
540
+ position_ids=None,
541
+ past_key_values=past_key_values,
542
+ inputs_embeds=None,
543
+ labels=None,
544
+ use_cache=use_cache,
545
+ output_attentions=output_attentions,
546
+ output_hidden_states=output_hidden_states,
547
+ return_dict=return_dict,
548
+ )
549
+
550
+ # === Handle Unimodal Forward ===
551
+ elif pixel_values is None:
552
+ assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
553
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
554
+
555
+ language_model_output = self.language_model(
556
+ input_ids=input_ids,
557
+ attention_mask=attention_mask,
558
+ position_ids=None,
559
+ past_key_values=None,
560
+ inputs_embeds=None,
561
+ labels=labels,
562
+ use_cache=use_cache,
563
+ output_attentions=output_attentions,
564
+ output_hidden_states=output_hidden_states,
565
+ return_dict=return_dict,
566
+ )
567
+
568
+ # === Handle Multimodal Forward ===
569
+ elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
570
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"
571
+
572
+ # Get input embeddings (from language model embeddings)
573
+ input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)
574
+
575
+ # Extract action masks
576
+ all_actions_mask = self._process_action_masks(labels)
577
+
578
+ # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
579
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
580
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
581
+ ) # (B, lang_seq_len, llm_dim)
582
+
583
+ # Get visual features
584
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
585
+
586
+ # Add proprioceptive state if provided
587
+ projected_patch_embeddings = self._process_proprio_features(
588
+ projected_patch_embeddings, proprio, proprio_projector
589
+ )
590
+
591
+ # [Diffusion] Add diffusion timestep embedding if provided
592
+ if diffusion_timestep_embeddings is not None:
593
+ # For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens
594
+ projected_patch_embeddings = torch.cat(
595
+ (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
596
+ )
597
+
598
+ # Process action embeddings
599
+ if noisy_actions is not None:
600
+ # Get mask corresponding to all action tokens
601
+ all_actions_mask = self._process_action_masks(labels)
602
+
603
+ # Reshape noisy actions into individual action tokens
604
+ # noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1)
605
+ B = noisy_actions.shape[0]
606
+ noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1)
607
+
608
+ # Project noisy action tokens into language model embedding space
609
+ noisy_action_features = noisy_action_projector(noisy_actions) # (B, chunk_len * action_dim, llm_dim)
610
+
611
+ # Replace embeddings of the action tokens with noisy action embeddings
612
+ input_embeddings = self._replace_input_embeddings(
613
+ input_embeddings, all_actions_mask, noisy_action_features
614
+ )
615
+ else:
616
+ # Replace the embeddings of the action tokens with zeros
617
+ # (Later on, the positional embeddings will be added to them)
618
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
619
+ input_embeddings = input_embeddings * ~all_actions_mask
620
+
621
+ # Build multimodal embeddings & attention mask
622
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
623
+ input_embeddings, projected_patch_embeddings, attention_mask
624
+ )
625
+
626
+ # Build labels for multimodal sequence if needed
627
+ multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
628
+
629
+ # Dispatch to language model
630
+ language_model_output = self.language_model(
631
+ input_ids=None,
632
+ attention_mask=multimodal_attention_mask,
633
+ position_ids=None,
634
+ past_key_values=None,
635
+ inputs_embeds=multimodal_embeddings,
636
+ labels=multimodal_labels,
637
+ use_cache=use_cache,
638
+ output_attentions=output_attentions,
639
+ output_hidden_states=output_hidden_states,
640
+ return_dict=return_dict,
641
+ )
642
+
643
+ # === Otherwise =>> Assume Invalid! ===
644
+ elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
645
+ raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
646
+
647
+ else:
648
+ raise ValueError(
649
+ "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
650
+ f"=> `input_ids` = {input_ids is not None}\n"
651
+ f"=> `attention_mask` = {attention_mask is not None}\n"
652
+ f"=> `pixel_values` = {pixel_values is not None}\n"
653
+ f"=> `labels` = {labels is not None}\n"
654
+ f"=> `input_embeds` = {inputs_embeds is not None}\n"
655
+ f"=> `past_key_values` = {past_key_values is not None}\n"
656
+ f"=> `use_cache` = {use_cache}"
657
+ )
658
+
659
+ # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
660
+ if not return_dict:
661
+ if output_projector_features and (projected_patch_embeddings is not None):
662
+ return *language_model_output, projected_patch_embeddings
663
+
664
+ return language_model_output
665
+
666
+ return PrismaticCausalLMOutputWithPast(
667
+ loss=language_model_output.loss,
668
+ logits=language_model_output.logits,
669
+ past_key_values=language_model_output.past_key_values,
670
+ hidden_states=language_model_output.hidden_states,
671
+ attentions=language_model_output.attentions,
672
+ projector_features=projected_patch_embeddings,
673
+ )
674
+
675
+ # === GenerationMixin Methods ===
676
+ def prepare_inputs_for_generation(
677
+ self,
678
+ input_ids: Optional[torch.Tensor] = None,
679
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
680
+ inputs_embeds: Optional[torch.FloatTensor] = None,
681
+ pixel_values: Optional[torch.FloatTensor] = None,
682
+ attention_mask: Optional[torch.Tensor] = None,
683
+ **kwargs: str,
684
+ ) -> Dict[str, torch.Tensor]:
685
+ """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
686
+ if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
687
+ (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
688
+ ):
689
+ raise ValueError("Generation with batch size > 1 is not currently supported!")
690
+
691
+ # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
692
+ if past_key_values is not None:
693
+ input_ids = input_ids[:, -1:]
694
+
695
+ # If `input_embeds` are passed, we only want to use them in the 1st generation step
696
+ if inputs_embeds is not None and past_key_values is None:
697
+ model_inputs = {"input_embeds": inputs_embeds}
698
+ else:
699
+ model_inputs = {"input_ids": input_ids}
700
+
701
+ # Make sure `pixel_values` are preserved in `model_inputs`
702
+ model_inputs.update(
703
+ {
704
+ "attention_mask": attention_mask,
705
+ "pixel_values": pixel_values,
706
+ "past_key_values": past_key_values,
707
+ "use_cache": kwargs.get("use_cache"),
708
+ }
709
+ )
710
+
711
+ return model_inputs
712
+
713
+ # Defer to Language Model (all handle this differently, with different return types)
714
+ def _reorder_cache(self, *args, **kwargs) -> Any:
715
+ return self.language_model._reorder_cache(*args, **kwargs)
716
+
717
+
718
+ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
719
+ config_class: PretrainedConfig = OpenVLAConfig
720
+
721
+ def __init__(self, config: OpenVLAConfig) -> None:
722
+ super().__init__(config)
723
+ self.norm_stats = config.norm_stats
724
+
725
+ # Compute action bins
726
+ self.bins = np.linspace(-1, 1, config.n_action_bins)
727
+ self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
728
+
729
+ # Compute vocab size for de-tokenization -- revert added "multiple of"
730
+ self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
731
+
732
+ def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
733
+ # Used during eval.
734
+ """Prepares input for action prediction by adding necessary tokens"""
735
+ # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens
736
+ placeholder_action_token_ids = (
737
+ torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
738
+ )
739
+ input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1) # torch.Size([1, 35 + 56 = 91])
740
+
741
+ # Extend the attention mask to fit the new shape of input
742
+ # Note: Only batch size == 1 supported right now
743
+ mask_extension = (
744
+ torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
745
+ .to(attention_mask.device)
746
+ .to(attention_mask.dtype)
747
+ )
748
+ attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
749
+
750
+ return input_ids, attention_mask
751
+
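+ # Illustrative example (assumed sizes): for a 35-token prompt with ACTION_DIM=7 and NUM_ACTIONS_CHUNK=8,
+ # 56 placeholder action tokens are appended, giving input_ids of shape (1, 91); the attention mask is
+ # padded with ones over those same 56 positions.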
752
+ def _prepare_labels_for_action_prediction(self, labels, input_ids):
753
+ """Creates labels tensor for action prediction if not provided"""
754
+ # Used during eval.
755
+ # Extends label tensors with fake action labels
756
+ # Adds stop tokens at the end of sequences
757
+ # Handles label preparation for action prediction tasks
758
+ # Why can the fill value be arbitrary? Per xuan: just pick a custom value and use it consistently. Would the PAD token also work?
759
+ # TODO: Does this need to change? Probably not: any value works, since these labels only exist to build a mask. Why is this function needed at all? It ensures the action-prediction labels match the model's input length and that sequence termination is handled correctly.
760
+ # Extend labels tensor with fake action labels
761
+ ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_IDX # Filled with ACTION_TOKEN_IDX so the action mask is built correctly: action_tokens_only_mask = (labels == ACTION_TOKEN_IDX)
762
+ labels_extension = (
763
+ torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
764
+ * ARBITRARY_ACTION_TOKEN_IDX
765
+ ) # torch.Size([1, 57]), all entries are ARBITRARY_ACTION_TOKEN_IDX
766
+ labels = torch.cat([labels, labels_extension], dim=-1)
767
+
768
+ return labels
769
+
770
+ def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
771
+ """Unnormalize actions using dataset statistics"""
772
+ action_norm_stats = self.get_action_stats(unnorm_key)
773
+
774
+ if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
775
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool))
776
+ action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"])
777
+ elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
778
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
779
+ action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
780
+ else:
781
+ raise ValueError("Unsupported action/proprio normalization type detected!")
782
+
783
+ actions = np.where(
784
+ mask,
785
+ 0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
786
+ normalized_actions,
787
+ )
788
+
789
+ return actions
790
+
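+ # Worked example (hypothetical stats): under BOUNDS_Q99 with q01=-0.5 and q99=1.0 for one dimension,
+ # a normalized prediction of 0.0 maps to 0.5 * (0.0 + 1) * (1.0 - (-0.5)) + (-0.5) = 0.25 (ignoring the
+ # 1e-8 epsilon); dimensions whose mask entry is False are passed through unchanged.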
791
+ def _run_diffusion_prediction(
792
+ self,
793
+ input_embeddings,
794
+ all_actions_mask,
795
+ noise,
796
+ action_head,
797
+ projected_patch_embeddings,
798
+ labels,
799
+ attention_mask,
800
+ NUM_PATCHES,
801
+ NUM_PROMPT_TOKENS,
802
+ noisy_action_projector,
803
+ ):
804
+ """Run diffusion-based action prediction"""
805
+ # Set diffusion timestep values
806
+ action_head.noise_scheduler.set_timesteps(action_head.num_diffusion_steps)
807
+ # Clone embedding for reuse in each timestep
808
+ orig_projected_patch_embeddings = projected_patch_embeddings.clone()
809
+ curr_noisy_actions = noise
810
+
811
+ # Reverse diffusion: Iteratively denoise to generate action prediction
812
+ for t in action_head.noise_scheduler.timesteps:
813
+ # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
814
+ # embedding, and diffusion timestep embedding)
815
+ timesteps = torch.Tensor([t]).to(labels.device)
816
+ diffusion_timestep_embeddings = (
817
+ action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
818
+ ) # (B, llm_dim)
819
+ diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
820
+
821
+ # [Diffusion] Replace the embeddings of the action tokens with noisy actions
822
+ # (Later on, the positional embeddings will be added to them)
823
+
824
+ # For simplicity, append diffusion timestep embedding to the end of projected vision tokens
825
+ projected_patch_embeddings = torch.cat(
826
+ (orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
827
+ )
828
+
829
+ # Reshape and project noisy actions into language embedding space
830
+ B = curr_noisy_actions.shape[0]
831
+ orig_curr_noisy_actions_shape = curr_noisy_actions.shape
832
+ curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
833
+ noisy_action_features = noisy_action_projector(curr_noisy_actions)
834
+ curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
835
+
836
+ # Replace action token embeddings with noisy action embeddings
837
+ input_embeddings = self._replace_input_embeddings(
838
+ input_embeddings.clone(), all_actions_mask, noisy_action_features
839
+ )
840
+
841
+ # Build multimodal embeddings and attention mask
842
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
843
+ input_embeddings, projected_patch_embeddings, attention_mask
844
+ )
845
+
846
+ # Forward pass through language model
847
+ language_model_output = self.language_model(
848
+ input_ids=None,
849
+ attention_mask=multimodal_attention_mask,
850
+ position_ids=None,
851
+ past_key_values=None,
852
+ inputs_embeds=multimodal_embeddings,
853
+ labels=None,
854
+ use_cache=None,
855
+ output_attentions=False,
856
+ output_hidden_states=True,
857
+ return_dict=True,
858
+ )
859
+
860
+ # Extract hidden states for action portion of response
861
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
862
+ actions_hidden_states = last_hidden_states[
863
+ :,
864
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
865
+ :,
866
+ ] # (B, act_chunk_len, D)
867
+
868
+ # Predict noise and update noisy actions: x_t -> x_{t-1}
869
+ noise_pred = action_head.predict_noise(actions_hidden_states)
870
+ curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
871
+
872
+ curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
873
+
874
+ # Return final actions
875
+ return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
876
+
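+ # Summary of the loop above: at each scheduler timestep t, the current noisy action chunk is projected
+ # into the LLM embedding space, spliced into the action-token positions, a full forward pass yields the
+ # hidden states at those positions, the action head predicts the noise, and noise_scheduler.step() maps
+ # x_t to x_{t-1}; after the final step the denoised chunk is reshaped to (NUM_ACTIONS_CHUNK, ACTION_DIM).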
877
+ def _regression_or_discrete_prediction(
878
+ self,
879
+ input_embeddings: torch.FloatTensor, # Embeddings of the language instruction.
880
+ all_actions_mask: Optional[torch.BoolTensor], # Only used to extract the preceding (non-action) embeddings, i.e. to strip out the action portion.
881
+ projected_patch_embeddings: torch.FloatTensor,
882
+ attention_mask: torch.BoolTensor,
883
+ labels: torch.LongTensor,
884
+ NUM_PATCHES: int,
885
+ NUM_PROMPT_TOKENS: int,
886
+ action_head: L1RegressionActionHead,
887
+ ):
888
+ """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
889
+ # Extract hidden states for action tokens
890
+ # last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
891
+
892
+ # from transformers import AutoProcessor
893
+ # processor = AutoProcessor.from_pretrained("/data/juyi/openvla-7b+fractal20220817_data+b32+lr-5e-05+lora-r32+dropout-0.0--image_aug--test")
894
+ # tokenizer=processor.tokenizer
895
+ # tokenizer.decode(language_model_output.sequences[0])
896
+
897
+ # actions_hidden_states = last_hidden_states[:, NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_ACTIONS_CHUNK * tokennum, :]# (B, act_chunk_len, D)
898
+ # No manual slicing is needed anymore: generate() already returns the hidden state for each generated token, which is very convenient.
899
+ # Why is the first entry torch.Size([1, 535, 4096])? Which one should be used? https://discuss.huggingface.co/t/get-each-generated-token-last-layer-hidden-state/145921
900
+ # language_model_output.sequences tensor([[29871, 32001, 32001, 32001, 32001, 32001, 32001, 32001, 32001, 2]], device='cuda:0')
901
+
902
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
903
+ input_embeddings, projected_patch_embeddings, attention_mask
904
+ )
905
+ # Example multimodal_embeddings: '<s> <512 image tokens> <proprio token> In: What action should the robot take to open the middle drawer of the cabinet?\nOut:'
906
+ if self.preset:
907
+ # start_prefill = torch.cuda.Event(enable_timing=True)
908
+ # end_prefill = torch.cuda.Event(enable_timing=True)
909
+ # start_prefill.record()
910
+ language_model_output = self.language_model.generate(inputs_embeds=multimodal_embeddings, max_new_tokens=1, output_hidden_states=True, return_dict_in_generate=True)
911
+ # is tuple (1 token, 33 layers, torch.Size([1, 314, 4096]))
912
+ hidden_states = language_model_output.hidden_states[0][-1]
913
+ actions_hidden_states = hidden_states[:, -NUM_ACTIONS_CHUNK:]
914
+ # end_prefill.record()
915
+ # torch.cuda.synchronize()
916
+ # prefill_time = start_prefill.elapsed_time(end_prefill) / 1000
917
+ # print(f"Prefill time: {prefill_time:.4f} seconds")
918
+ else:
919
+ # start_generate = torch.cuda.Event(enable_timing=True)
920
+ # end_generate = torch.cuda.Event(enable_timing=True)
921
+ # start_generate.record()
922
+ language_model_output = self.language_model.generate(inputs_embeds=multimodal_embeddings, max_new_tokens=2048, output_hidden_states=True, return_dict_in_generate=True, use_cache=True)
923
+ # end_generate.record()
924
+ # torch.cuda.synchronize()
925
+ # generate_time = start_generate.elapsed_time(end_generate) / 1000
926
+ # print(f"prefill + Generate time: {generate_time:.4f} seconds")
927
+ actions_hidden_states = torch.stack([language_model_output.hidden_states[i][-1] for i in range(1, NUM_ACTIONS_CHUNK + 1)], dim=0) # (action_chunk, batch_size, sequence_length, hidden_dim)
928
+ actions_hidden_states = actions_hidden_states.transpose(0, 1).squeeze(2) # torch.Size([batch_size, action_chunk, hidden_dim])
929
+
930
+ normalized_actions = action_head.predict_action(actions_hidden_states)
931
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
932
+ normalized_actions = normalized_actions.float().cpu().detach().numpy()
933
+
934
+ return normalized_actions, actions_hidden_states
935
+
936
+
937
+
938
+ def mul_regression_or_discrete_prediction(
939
+ self,
940
+ input_embeddings: torch.FloatTensor, # Embeddings of the language instruction.
941
+ all_actions_mask: Optional[torch.BoolTensor], # Only used to extract the preceding (non-action) embeddings, i.e. to strip out the action portion.
942
+ projected_patch_embeddings: torch.FloatTensor,
943
+ attention_mask: torch.BoolTensor,
944
+ labels: torch.LongTensor,
945
+ NUM_PATCHES: int,
946
+ NUM_PROMPT_TOKENS: int,
947
+ action_head: L1RegressionActionHead,
948
+ **kwargs,
949
+ ):
950
+ cfg = kwargs.get("cfg", None)
951
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
952
+ input_embeddings, projected_patch_embeddings, attention_mask
953
+ )
954
+ # Example multimodal_embeddings: '<s> <512 image tokens> <proprio token> In: What action should the robot take to open the middle drawer of the cabinet?\nOut:'
955
+ # The first element of language_model_output.hidden_states is a tuple over (1 generated token, 33 layers), each entry torch.Size([1, 314, 4096])
956
+ if self.preset:
957
+ language_model_output = self.language_model.generate(inputs_embeds=multimodal_embeddings, max_new_tokens=1, output_hidden_states=True, return_dict_in_generate=True)
958
+ # assert language_model_output.sequences == torch.tensor([[32001]], device=multimodal_embeddings.device)
959
+ actions_hidden_states = language_model_output.hidden_states[0][-1]
960
+ actions_hidden_states = actions_hidden_states[:, -1]
961
+ else:
962
+ language_model_output = self.language_model.generate(inputs_embeds=multimodal_embeddings, max_new_tokens=2, output_hidden_states=True, return_dict_in_generate=True)
963
+ actions_hidden_states = language_model_output.hidden_states[1][-1]
964
+ actions_hidden_states = actions_hidden_states[:, -1]
965
+
966
+ normalized_actions = action_head.predict_action(actions_hidden_states)
967
+ normalized_actions = normalized_actions.reshape(cfg.num_actions_chunk, ACTION_DIM)
968
+ normalized_actions = normalized_actions.float().cpu().detach().numpy()
969
+
970
+ return normalized_actions, actions_hidden_states
971
+
972
+ def predict_action(
973
+ self,
974
+ input_ids: Optional[torch.LongTensor] = None,
975
+ unnorm_key: Optional[str] = None,
976
+ proprio=None,
977
+ proprio_projector=None,
978
+ action_head=None,
979
+ noisy_action_projector=None,
980
+ use_film: bool = False,
981
+ **kwargs: str,
982
+ ) -> np.ndarray:
983
+ """Predict actions from input sequence, with options for different prediction methods.
984
+
985
+ Args:
986
+ input_ids: Input token ids
987
+ unnorm_key: Key for unnormalization statistics
988
+ proprio: Proprioceptive features
989
+ proprio_projector: Projector for proprioceptive features
990
+ action_head: Optional head for L1 regression or diffusion-based prediction
991
+ noisy_action_projector: Projector for noisy actions in diffusion-based prediction
992
+ use_film: Whether to use FiLM conditioning
993
+ **kwargs: Additional arguments including pixel_values and attention_mask
994
+
995
+ Returns:
996
+ Tuple of (unnormalized_actions, action_hidden_states)
997
+ """
998
+ # If the special empty token ('') does not already appear after the colon (':') token in the prompt
999
+ # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
1000
+ if not torch.all(input_ids[:, -1] == 29871):
1001
+ input_ids = torch.cat(
1002
+ (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
1003
+ )
1004
+
1005
+ pixel_values = kwargs["pixel_values"]
1006
+ attention_mask = kwargs["attention_mask"]
1007
+
1008
+ # Create fake labels tensor (needed for action mask)
1009
+ labels = input_ids.clone()
1010
+ labels[:] = IGNORE_INDEX # Ignore every input position; IGNORE_INDEX = -100
1011
+
1012
+ # Get number of tokens in prompt (excluding the start token)
1013
+ NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Exclude the leading <BOS> token
1014
+
1015
+ # Prepare inputs by adding necessary tokens
1016
+ input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
1017
+
1018
+ # Update labels tensor for action mask computation later
1019
+ labels = self._prepare_labels_for_action_prediction(labels, input_ids)
1020
+
1021
+ # Get input embeddings and action masks
1022
+ input_embeddings = self.get_input_embeddings()(input_ids)
1023
+ all_actions_mask = self._process_action_masks(labels)
1024
+
1025
+ # Extract language embeddings
1026
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
1027
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
1028
+ )
1029
+
1030
+ # Process vision features
1031
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
1032
+
1033
+ # Add proprioceptive features if provided
1034
+ use_proprio = proprio_projector is not None and proprio is not None
1035
+ if use_proprio:
1036
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1037
+ projected_patch_embeddings = self._process_proprio_features(
1038
+ projected_patch_embeddings, proprio, proprio_projector
1039
+ )
1040
+
1041
+ # Use diffusion if provided, otherwise use regression or discrete prediction
1042
+ use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
1043
+
1044
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
1045
+ NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
1046
+ if use_proprio:
1047
+ NUM_PATCHES += 1
1048
+
1049
+ normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
1050
+ input_embeddings,
1051
+ all_actions_mask,
1052
+ projected_patch_embeddings,
1053
+ attention_mask,
1054
+ labels,
1055
+ NUM_PATCHES,
1056
+ NUM_PROMPT_TOKENS,
1057
+ action_head,
1058
+ )
1059
+
1060
+ # Unnormalize predicted actions
1061
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1062
+
1063
+ return actions, actions_hidden_states
1064
+
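+ # Minimal usage sketch (hedged; `processor`, `vla`, and `action_head` are assumed to be loaded elsewhere,
+ # and the prompt text and unnorm key are placeholders):
+ # inputs = processor("In: What action should the robot take to <task>?\nOut:", image).to(device)
+ # actions, _ = vla.predict_action(**inputs, unnorm_key="<dataset_key>", action_head=action_head)
+ # `actions` is a numpy array of shape (NUM_ACTIONS_CHUNK, ACTION_DIM) after unnormalization.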
1065
+ @staticmethod
1066
+ def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
1067
+ """Validate and resolve the unnormalization key for action statistics"""
1068
+ if unnorm_key is None:
1069
+ assert len(norm_stats) == 1, (
1070
+ f"Your model was trained on more than one dataset, "
1071
+ f"please pass a `unnorm_key` from the following options to choose the statistics "
1072
+ f"used for un-normalizing actions: {norm_stats.keys()}"
1073
+ )
1074
+ unnorm_key = next(iter(norm_stats.keys()))
1075
+ # Why aren't the LIBERO norm stats loaded here?
1076
+ assert unnorm_key in norm_stats, (
1077
+ f"The `unnorm_key` you chose is not in the set of available dataset statistics, "
1078
+ f"please choose from: {norm_stats.keys()}"
1079
+ )
1080
+ return unnorm_key
1081
+
1082
+ def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
1083
+ """Get the dimensionality of the policy's action space."""
1084
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1085
+ return len(self.norm_stats[unnorm_key]["action"]["min"])
1086
+
1087
+ def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
1088
+ """Get all the logged statistics for the given dataset."""
1089
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1090
+ return self.norm_stats[unnorm_key]["action"]
1091
+
1092
+
1093
+ def lisa_forward(
1094
+ self,
1095
+ input_ids: Optional[torch.LongTensor] = None,
1096
+ attention_mask: Optional[torch.Tensor] = None,
1097
+ pixel_values: Optional[torch.FloatTensor] = None,
1098
+ labels: Optional[torch.LongTensor] = None,
1099
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1100
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1101
+ use_cache: Optional[bool] = None,
1102
+ output_attentions: Optional[bool] = None,
1103
+ output_hidden_states: Optional[bool] = None,
1104
+ output_projector_features: Optional[bool] = None,
1105
+ return_dict: Optional[bool] = None,
1106
+ proprio=None,
1107
+ proprio_projector=None,
1108
+ **kwargs
1109
+ ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
1110
+ """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
1111
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1112
+ output_hidden_states = (
1113
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1114
+ )
1115
+ output_projector_features = output_projector_features if output_projector_features is not None else False
1116
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1117
+
1118
+ # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
1119
+ use_cache = use_cache and not self.training
1120
+
1121
+ # Instantiate Placeholder for Projector Features
1122
+ projected_patch_embeddings = None
1123
+
1124
+ # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
1125
+ if input_ids.shape[1] == 1:
1126
+ assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
1127
+ assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
1128
+ assert labels is None, "Unexpected key `labels` provided during cached generation!"
1129
+
1130
+ language_model_output = self.language_model(
1131
+ input_ids=input_ids,
1132
+ attention_mask=None,
1133
+ position_ids=None,
1134
+ past_key_values=past_key_values,
1135
+ inputs_embeds=None,
1136
+ labels=None,
1137
+ use_cache=use_cache,
1138
+ output_attentions=output_attentions,
1139
+ output_hidden_states=output_hidden_states,
1140
+ return_dict=return_dict,
1141
+ )
1142
+
1143
+ # === Handle Unimodal Forward ===
1144
+ elif pixel_values is None:
1145
+ raise NotImplementedError
1146
+
1147
+ # === Handle Multimodal Forward ===
1148
+ elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
1149
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"
1150
+
1151
+ # Get input embeddings (from language model embeddings)
1152
+ input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)
1153
+ # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
1154
+ # language_embeddings = input_embeddings[~all_actions_mask].reshape(
1155
+ # input_embeddings.shape[0], -1, input_embeddings.shape[2]
1156
+ # ) # (B, lang_seq_len, llm_dim) This would also include the trailing stop index and padding. Is that a problem? No, because those positions are ignored. Deleted entirely here since FiLM is not used.
1157
+ # Get visual features
1158
+ projected_patch_embeddings = self._process_vision_features(pixel_values)
1159
+
1160
+ # Add proprioceptive state if provided
1161
+ projected_patch_embeddings = self._process_proprio_features(
1162
+ projected_patch_embeddings, proprio, proprio_projector
1163
+ )
1164
+
1165
+ all_actions_mask = (labels == ACTION_TOKEN_IDX) # Unlike the regular forward pass, which has to count tokens manually to locate the action offset.
1166
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
1167
+ input_embeddings = input_embeddings * ~all_actions_mask
1168
+
1169
+ # Build multimodal embeddings & attention mask
1170
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1171
+ input_embeddings, projected_patch_embeddings, attention_mask
1172
+ )
1173
+
1174
+ # Build labels for multimodal sequence if needed
1175
+ multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
1176
+
1177
+ # Dispatch to language model
1178
+ language_model_output = self.language_model(
1179
+ input_ids=None,
1180
+ attention_mask=multimodal_attention_mask,
1181
+ position_ids=None,
1182
+ past_key_values=None,
1183
+ inputs_embeds=multimodal_embeddings,
1184
+ labels=multimodal_labels,
1185
+ use_cache=use_cache,
1186
+ output_attentions=output_attentions,
1187
+ output_hidden_states=output_hidden_states,
1188
+ return_dict=return_dict,
1189
+ )
1190
+
1191
+
1192
+ # === Otherwise =>> Assume Invalid! ===
1193
+ elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
1194
+ raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
1195
+
1196
+ else:
1197
+ raise ValueError(
1198
+ "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
1199
+ f"=> `input_ids` = {input_ids is not None}\n"
1200
+ f"=> `attention_mask` = {attention_mask is not None}\n"
1201
+ f"=> `pixel_values` = {pixel_values is not None}\n"
1202
+ f"=> `labels` = {labels is not None}\n"
1203
+ f"=> `input_embeds` = {inputs_embeds is not None}\n"
1204
+ f"=> `past_key_values` = {past_key_values is not None}\n"
1205
+ f"=> `use_cache` = {use_cache}"
1206
+ )
1207
+
1208
+ # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
1209
+ if not return_dict:
1210
+ if output_projector_features and (projected_patch_embeddings is not None):
1211
+ return *language_model_output, projected_patch_embeddings
1212
+
1213
+ return language_model_output
1214
+
1215
+ return PrismaticCausalLMOutputWithPast(
1216
+ loss=language_model_output.loss,
1217
+ logits=language_model_output.logits,
1218
+ past_key_values=language_model_output.past_key_values,
1219
+ hidden_states=language_model_output.hidden_states,
1220
+ attentions=language_model_output.attentions,
1221
+ projector_features=projected_patch_embeddings,
1222
+ )
1223
+
1224
+ def lisa_predict_action(
1225
+ self,
1226
+ input_ids: Optional[torch.LongTensor] = None, # The language instruction tokens.
1227
+ unnorm_key: Optional[str] = None,
1228
+ proprio=None,
1229
+ proprio_projector=None,
1230
+ action_head:L1RegressionActionHead=None,
1231
+ noisy_action_projector=None,
1232
+ use_film: bool = False,
1233
+ **kwargs: str,
1234
+ ) -> np.ndarray:
1235
+
1236
+ pixel_values = kwargs["pixel_values"]
1237
+ attention_mask = kwargs["attention_mask"]
1238
+
1239
+ # Get number of tokens in prompt (excluding the start token)
1240
+ # NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token
1241
+
1242
+ # input id '<s> In: What action should the robot take to open the middle drawer of the cabinet?\nOut:'
1243
+ # What are labels for at prediction time? Only for building the mask; they are not needed for autoregressive decoding.
1244
+ cfg = kwargs.get("cfg", None) # Extract cfg from kwargs
1245
+ if cfg.preset:
1246
+ special_tensor = torch.tensor([[29871]], device=input_ids.device, dtype=input_ids.dtype)
1247
+ output_tensor = torch.tensor([[32001] * NUM_ACTIONS_CHUNK], device=input_ids.device, dtype=input_ids.dtype)
1248
+ input_ids = torch.cat([input_ids, special_tensor, output_tensor], dim=1) # preset action tokens, only forward once.
1249
+ self.preset = True
1250
+ else:
1251
+ self.preset = False
1252
+ input_embeddings = self.get_input_embeddings()(input_ids)
1253
+
1254
+ projected_patch_embeddings = self._process_vision_features(pixel_values)
1255
+
1256
+ use_proprio = proprio_projector is not None and proprio is not None
1257
+ if use_proprio:
1258
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1259
+ projected_patch_embeddings = self._process_proprio_features(
1260
+ projected_patch_embeddings, proprio, proprio_projector
1261
+ )
1262
+
1263
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
1264
+ # NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
1265
+ # if use_proprio:
1266
+ # NUM_PATCHES += 1
1267
+ normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
1268
+ input_embeddings,
1269
+ None,
1270
+ projected_patch_embeddings,
1271
+ attention_mask,
1272
+ None, # labels are not needed at inference
1273
+ None, # NUM_PATCHES is not needed at inference
1274
+ None, # NUM_PROMPT_TOKENS is not needed at inference
1275
+ action_head,
1276
+ )
1277
+
1278
+ # Unnormalize predicted actions
1279
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1280
+
1281
+ return actions, actions_hidden_states
1282
+
1283
+
1284
+ def mul_predict_action(
1285
+ self,
1286
+ input_ids: Optional[torch.LongTensor] = None, # The language instruction tokens.
1287
+ unnorm_key: Optional[str] = None,
1288
+ proprio=None,
1289
+ proprio_projector=None,
1290
+ action_head:L1RegressionActionHead=None,
1291
+ noisy_action_projector=None,
1292
+ use_film: bool = False,
1293
+ **kwargs: str,
1294
+ ) -> np.ndarray:
1295
+ cfg = kwargs.get("cfg", None) # Extract cfg from kwargs
1296
+ if cfg.preset:
1297
+ self.preset = True
1298
+ if not torch.all(input_ids[:, -1] == 29871):
1299
+ input_ids = torch.cat(
1300
+ (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
1301
+ )
1302
+ else:
1303
+ self.preset = False
1304
+
1305
+
1306
+ pixel_values = kwargs["pixel_values"]
1307
+ attention_mask = kwargs["attention_mask"]
1308
+
1309
+ # input id '<s> In: What action should the robot take to open the middle drawer of the cabinet?\nOut:'
1310
+
1311
+ input_embeddings = self.get_input_embeddings()(input_ids)
1312
+
1313
+ projected_patch_embeddings = self._process_vision_features(pixel_values)
1314
+
1315
+ use_proprio = proprio_projector is not None and proprio is not None
1316
+ if use_proprio:
1317
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1318
+ projected_patch_embeddings = self._process_proprio_features(
1319
+ projected_patch_embeddings, proprio, proprio_projector
1320
+ )
1321
+
1322
+ normalized_actions, actions_hidden_states = self.mul_regression_or_discrete_prediction(
1323
+ input_embeddings,
1324
+ None,
1325
+ projected_patch_embeddings,
1326
+ attention_mask,
1327
+ None, # labels are not needed at inference
1328
+ None, # NUM_PATCHES is not needed at inference
1329
+ None, # NUM_PROMPT_TOKENS is not needed at inference
1330
+ action_head,
1331
+ cfg=cfg,
1332
+ )
1333
+
1334
+ # Unnormalize predicted actions
1335
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1336
+
1337
+ return actions, actions_hidden_states
1338
+
modeling_prismatic.py.back.20250404_145345 ADDED
@@ -0,0 +1,1337 @@
1
+ """
2
+ modeling_prismatic.py
3
+
4
+ Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.
5
+ Inherits from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained,
6
+ but exactly replicate the logic in `prismatic.models.vlms.prismatic.py`.
7
+ """
8
+
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from functools import partial
12
+ from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
13
+
14
+ import numpy as np
15
+ import timm
16
+ import tokenizers
17
+ import torch
18
+ import torch.nn as nn
19
+ import transformers
20
+ from timm.models.vision_transformer import LayerScale
21
+ from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import ModelOutput
23
+ from prismatic.models.action_heads import L1RegressionActionHead
24
+ import time
25
+ from prismatic.training.train_utils import (
26
+ get_current_action_mask,
27
+ get_next_actions_mask,
28
+ )
29
+ from prismatic.vla.constants import (
30
+ ACTION_DIM,
31
+ ACTION_PROPRIO_NORMALIZATION_TYPE,
32
+ ACTION_TOKEN_BEGIN_IDX,
33
+ IGNORE_INDEX,
34
+ NUM_ACTIONS_CHUNK,
35
+ STOP_INDEX,
36
+ ACTION_TOKEN_IDX,
37
+ NormalizationType,
38
+ )
39
+
40
+ from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
41
+
42
+ # Set up logger
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
+ # === Utility Functions for Monkey-Patching ===
47
+ def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
48
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
49
+ result = fn(*args, **kwargs)
50
+ return result[0] if isinstance(result, tuple) else result
51
+
52
+ return wrapper
53
+
54
+
55
+ # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
56
+ # =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
57
+ # =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
58
+ def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
60
+
61
+
62
+ def ls_apply_patch(ls_module: LayerScale):
63
+ ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
64
+ ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
65
+ del ls_module.gamma
66
+
67
+
68
+ # === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
69
+ class PrismaticVisionBackbone(nn.Module):
70
+ """
71
+ Vision backbone for Prismatic models that handles image feature extraction.
72
+
73
+ Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.
74
+ For fused backbones, features from both models are concatenated along the feature dimension.
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ use_fused_vision_backbone: bool,
80
+ image_sizes: List[int],
81
+ timm_model_ids: List[str],
82
+ timm_override_act_layers: List[Optional[str]],
83
+ ) -> None:
84
+ """
85
+ Initialize the vision backbone.
86
+
87
+ Args:
88
+ use_fused_vision_backbone: Whether to use two backbones and fuse their features
89
+ image_sizes: List of image sizes for each backbone
90
+ timm_model_ids: List of TIMM model IDs to use for each backbone
91
+ timm_override_act_layers: List of activation layer overrides for each backbone
92
+ """
93
+ super().__init__()
94
+ self.use_fused_vision_backbone = use_fused_vision_backbone
95
+ self.num_images_in_input = 1 # Default value, can be overridden later
96
+
97
+ # Validate number of (fused) vision backbones
98
+ if len(timm_model_ids) > 2:
99
+ raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")
100
+
101
+ # Create primary featurizer
102
+ self.featurizer = self._create_featurizer(
103
+ model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]
104
+ )
105
+ self.embed_dim = self.featurizer.embed_dim
106
+
107
+ # Create secondary featurizer if using fused backbone
108
+ if self.use_fused_vision_backbone:
109
+ self.fused_featurizer = self._create_featurizer(
110
+ model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]
111
+ )
112
+ self.embed_dim += self.fused_featurizer.embed_dim
113
+
114
+ # Patch LayerScale modules for HF compatibility
115
+ self._patch_layer_scales()
116
+
117
+ def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
118
+ """
119
+ Create a TIMM-based featurizer model with appropriate configurations.
120
+
121
+ Args:
122
+ model_id: The TIMM model ID to load
123
+ img_size: Input image size for the model
124
+ act_layer: Override for the activation layer type
125
+
126
+ Returns:
127
+ A configured featurizer model
128
+ """
129
+ featurizer = timm.create_model(
130
+ model_id,
131
+ pretrained=False,
132
+ num_classes=0,
133
+ img_size=img_size,
134
+ act_layer=act_layer,
135
+ )
136
+
137
+ # Monkey-patch the forward function to extract the second-to-last layer features
138
+ num_blocks = len(featurizer.blocks)
139
+ featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))
140
+
141
+ return featurizer
142
+
143
+ def _patch_layer_scales(self) -> None:
144
+ """
145
+ Patch all LayerScale modules to be compatible with HF's parameter naming.
146
+
147
+ HF Transformers overwrites parameters with names containing 'gamma',
148
+ so we need to rename and modify the forward method.
149
+ """
150
+ # Patch primary featurizer
151
+ for module in self.featurizer.modules():
152
+ if isinstance(module, LayerScale):
153
+ ls_apply_patch(module)
154
+
155
+ # Patch secondary featurizer if it exists
156
+ if self.use_fused_vision_backbone:
157
+ for module in self.fused_featurizer.modules():
158
+ if isinstance(module, LayerScale):
159
+ ls_apply_patch(module)
160
+
161
+ def get_num_patches(self) -> int:
162
+ """
163
+ Returns the number of vision patches output by the vision backbone.
164
+
165
+ Returns:
166
+ Number of patches per image
167
+ """
168
+ return self.featurizer.patch_embed.num_patches
169
+
170
+ def get_num_images_in_input(self) -> int:
171
+ """
172
+ Returns the number of input images for the vision backbone.
173
+
174
+ Returns:
175
+ Number of images expected in the input
176
+ """
177
+ return self.num_images_in_input
178
+
179
+ def set_num_images_in_input(self, num_images_in_input: int) -> None:
180
+ """
181
+ Sets the number of input images for the vision backbone.
182
+
183
+ Args:
184
+ num_images_in_input: Number of images to expect in the input
185
+ """
186
+ self.num_images_in_input = num_images_in_input
187
+
188
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
189
+ """
190
+ Implements the forward pass for the vision backbone.
191
+
192
+ If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features
193
+ (otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).
194
+
195
+ Args:
196
+ pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W).
197
+ """
198
+ if self.num_images_in_input == 1:
199
+ if not self.use_fused_vision_backbone:
200
+ return self.featurizer(pixel_values)
201
+
202
+ # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
203
+ img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
204
+ patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)
205
+
206
+ return torch.cat([patches, patches_fused], dim=2)
207
+
208
+ else:
209
+ assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!"
210
+
211
+ # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2)
212
+ images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)
213
+
214
+ # Process each image and collect patches
215
+ all_patches = []
216
+ for img in images:
217
+ # Split each image further into two stacks of channels (each with 3 channels)
218
+ img_regular, img_fused = torch.split(img, [3, 3], dim=1)
219
+
220
+ # Get patches from both SigLIP and DINOv2 vision transformers
221
+ patches = self.featurizer(img_regular)
222
+ patches_fused = self.fused_featurizer(img_fused)
223
+
224
+ # Concatenate SigLIP and DINOv2 patches along the hidden dimension
225
+ combined_patches = torch.cat([patches, patches_fused], dim=2)
226
+ all_patches.append(combined_patches)
227
+
228
+ # Concatenate all patches along the patch dimension
229
+ return torch.cat(all_patches, dim=1)
230
+
231
+
232
+ # === Prismatic Projector (nn.Module) Definitions ===
233
+ class PrismaticProjector(nn.Module):
234
+ def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
235
+ super().__init__()
236
+ self.use_fused_vision_backbone = use_fused_vision_backbone
237
+ self.vision_dim, self.llm_dim = vision_dim, llm_dim
238
+
239
+ # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
240
+ if not self.use_fused_vision_backbone:
241
+ self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
242
+ self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
243
+ self.act_fn1 = nn.GELU()
244
+ else:
245
+ initial_projection_dim = 4 * vision_dim
246
+ self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
247
+ self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
248
+ self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
249
+ self.act_fn1 = nn.GELU()
250
+ self.act_fn2 = nn.GELU()
251
+
252
+ def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
253
+ if not self.use_fused_vision_backbone:
254
+ projected_features = self.fc1(img_patches)
255
+ projected_features = self.act_fn1(projected_features)
256
+ projected_features = self.fc2(projected_features)
257
+ else:
258
+ projected_features = self.fc1(img_patches)
259
+ projected_features = self.act_fn1(projected_features)
260
+ projected_features = self.fc2(projected_features)
261
+ projected_features = self.act_fn2(projected_features)
262
+ projected_features = self.fc3(projected_features)
263
+
264
+ return projected_features
265
+
266
+
267
+ # === Main HF Class Definitions ===
268
+ @dataclass
269
+ class PrismaticCausalLMOutputWithPast(ModelOutput):
270
+ """Base class for Prismatic casual (visually-conditioned) language model outputs; also exposes visual features."""
271
+
272
+ loss: Optional[torch.FloatTensor] = None
273
+ logits: torch.FloatTensor = None
274
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
275
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
276
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
277
+
278
+ # Additions for VLMs
279
+ projector_features: Optional[torch.FloatTensor] = None
280
+
281
+
282
+ class PrismaticPreTrainedModel(PreTrainedModel):
283
+ config_class: PretrainedConfig = PrismaticConfig
284
+ base_model_prefix: str = "model"
285
+ supports_gradient_checkpointing: bool = True
286
+
287
+ _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
288
+ _skip_keys_device_placement: str = "past_key_values"
289
+ _supports_flash_attn_2: bool = True
290
+
291
+ def _init_weights(self, module: nn.Module) -> None:
292
+ # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
293
+ # => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
294
+ # https://github.com/TRI-ML/prismatic-vlms
295
+ std = (
296
+ self.config.initializer_range
297
+ if hasattr(self.config, "initializer_range")
298
+ else self.config.text_config.initializer_range
299
+ )
300
+
301
+ if hasattr(module, "class_embedding"):
302
+ module.class_embedding.data.normal_(mean=0.0, std=std)
303
+
304
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
305
+ module.weight.data.normal_(mean=0.0, std=std)
306
+ if module.bias is not None:
307
+ module.bias.data.zero_()
308
+ elif isinstance(module, nn.Embedding):
309
+ module.weight.data.normal_(mean=0.0, std=std)
310
+ if module.padding_idx is not None:
311
+ module.weight.data[module.padding_idx].zero_()
312
+
313
+ @property
314
+ def _supports_sdpa(self) -> bool:
315
+ """Check LLM supports SDPA Attention"""
316
+ return self.language_model._supports_sdpa
317
+
318
+
319
+ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
320
+ def __init__(self, config: PrismaticConfig) -> None:
321
+ super().__init__(config)
322
+
323
+ # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
324
+ if config.use_fused_vision_backbone is None:
325
+ raise ValueError("Missing config field `use_fused_vision_backbone`")
326
+
327
+ if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
328
+ raise NotImplementedError(
329
+ "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
330
+ "if you urgently need support for latest TIMM versions."
331
+ )
332
+
333
+ if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
334
+ logger.warning(
335
+ f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
336
+ f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
337
+ f"there might be inference-time regressions due to dependency changes. If in doubt, please"
338
+ f"use the above versions."
339
+ )
340
+
341
+ # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
342
+ self.vision_backbone = PrismaticVisionBackbone(
343
+ config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
344
+ )
345
+
346
+ # Create Multimodal Projector
347
+ self.projector = PrismaticProjector(
348
+ config.use_fused_vision_backbone,
349
+ vision_dim=self.vision_backbone.embed_dim,
350
+ llm_dim=config.text_config.hidden_size,
351
+ )
352
+
353
+ # Instantiate LLM Backbone
354
+ self.language_model = AutoModelForCausalLM.from_config(
355
+ config.text_config, attn_implementation=config._attn_implementation
356
+ )
357
+ self.vocab_size = config.text_config.vocab_size
358
+ self.pad_token_id = config.pad_token_id
359
+ self.llm_dim = config.text_config.hidden_size
360
+
361
+ # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
362
+ self.post_init()
363
+
364
+ # === `PreTrainedModel` Boilerplate ===
365
+ def get_input_embeddings(self) -> nn.Module:
366
+ return self.language_model.get_input_embeddings()
367
+
368
+ def set_input_embeddings(self, value: nn.Module) -> None:
369
+ self.language_model.set_input_embeddings(value)
370
+
371
+ def get_output_embeddings(self) -> nn.Module:
372
+ return self.language_model.get_output_embeddings()
373
+
374
+ def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
375
+ self.language_model.set_output_embeddings(new_embeddings)
376
+
377
+ def get_decoder(self) -> nn.Module:
378
+ return self.language_model.get_decoder()
379
+
380
+ def set_decoder(self, decoder: nn.Module) -> None:
381
+ self.language_model.set_decoder(decoder)
382
+
383
+ def tie_weights(self) -> None:
384
+ self.language_model.tie_weights() # Note: `Llama-2` and `Mistral` don't tie weights (no-op)
385
+
386
+ def resize_token_embeddings(
387
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
388
+ ) -> nn.Embedding:
389
+ updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
390
+
391
+ # Update config/instance variables
392
+ self.config.text_config.vocab_size = updated_embeddings.num_embeddings
393
+ self.vocab_size = updated_embeddings.num_embeddings
394
+
395
+ return updated_embeddings
396
+
397
+ def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):
398
+ """
399
+ Replace embeddings in input_embeddings at positions where all_actions_mask is True
400
+ with embeddings from noisy_action_features, using vectorized operations.
401
+
402
+ Args:
403
+ input_embeddings: Tensor of shape (B, S, D)
404
+ all_actions_mask: Boolean tensor of shape (B, S)
405
+ noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample
406
+
407
+ Returns:
408
+ Modified input_embeddings tensor
409
+ """
410
+ # Clone input to avoid modifying the original tensor
411
+ new_input_embeddings = input_embeddings.clone()
412
+
413
+ # Create a tensor with the same shape of input_embeddings to hold the noisy action features
414
+ repositioned_noisy_action_features = torch.zeros_like(input_embeddings)
415
+
416
+ # Create batch indices for splicing
417
+ batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)
418
+ batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])
419
+
420
+ # Get indices where mask is True for each sample
421
+ masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])
422
+
423
+ # Move the noisy action features into their correct positions
424
+ repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features
425
+
426
+ # Combine original input embeddings and noisy action embeddings using the mask
427
+ new_input_embeddings = torch.where(
428
+ all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings
429
+ )
430
+
431
+ return new_input_embeddings
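+
+ # Minimal usage sketch (hypothetical shapes): with B = 1, S = 4, D = 2 and
+ # all_actions_mask = [[False, True, True, False]], the two rows of noisy_action_features
+ # (shape (1, 2, 2)) are written into positions 1 and 2 of the returned (1, 4, 2) tensor,
+ # while positions 0 and 3 keep their original embeddings.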
432
+
433
+ def _process_action_masks(self, labels):
434
+ """Helper to get action masks from labels"""
435
+ current_action_mask = get_current_action_mask(labels) # (B, seq_len)
436
+ next_actions_mask = get_next_actions_mask(labels) # (B, seq_len)
437
+ all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len)
438
+ return all_actions_mask
439
+
440
+ def _process_vision_features(self, pixel_values, language_embeddings=None, use_film: bool = False):
441
+ """Process vision features with optional FiLM conditioning"""
442
+ patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D)
443
+
444
+ # Project patch embeddings into language embedding space
445
+ return self.projector(patch_features)
446
+
447
+ def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):
448
+ """Process proprioceptive features and append to vision features"""
449
+ if proprio_projector is not None and proprio is not None:
450
+ # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim)
451
+ # proprio: (bsz, proprio_dim) or (proprio_dim,)
452
+ proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1) # (bsz, proprio_dim)
453
+ proprio_features = proprio_projector(proprio) # (bsz, llm_dim)
454
+ proprio_features = proprio_features.unsqueeze(dim=1) # (bsz, 1, llm_dim)
455
+ # For simplicity, just append proprio token to the end of projected vision patch tokens
456
+ return torch.cat((projected_patch_embeddings, proprio_features), dim=1)
457
+ return projected_patch_embeddings
458
+
459
+ def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
460
+ """Build multimodal embeddings and attention mask"""
461
+ # juyi: Should the attention mask be made lower-triangular here? No -- generate() applies causal masking automatically.
462
+ projected_patch_attention_mask = None
463
+ if attention_mask is not None:
464
+ projected_patch_attention_mask = torch.full(
465
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
466
+ fill_value=True,
467
+ dtype=attention_mask.dtype,
468
+ device=attention_mask.device,
469
+ )
470
+
471
+ # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
472
+ multimodal_embeddings = torch.cat(
473
+ [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
474
+ )
475
+
476
+ multimodal_attention_mask = None
477
+ if attention_mask is not None:
478
+ multimodal_attention_mask = torch.cat(
479
+ [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
480
+ )
481
+
482
+ return multimodal_embeddings, multimodal_attention_mask
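+
+ # Resulting sequence layout (sketch): [<BOS>] + [projected patch tokens, plus any proprio / diffusion
+ # timestep tokens appended by the callers] + [remaining prompt and action-placeholder tokens], with the
+ # attention mask extended by all-True entries over the inserted patch positions.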
483
+
484
+ def _build_multimodal_labels(self, labels, projected_patch_embeddings):
485
+ """Build multimodal labels with IGNORE_INDEX for patch embeddings"""
486
+ if labels is not None:
487
+ projected_patch_labels = torch.full(
488
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
489
+ fill_value=IGNORE_INDEX,  # No loss is computed at these positions
490
+ dtype=labels.dtype,
491
+ device=labels.device,
492
+ )
493
+ return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)  # The first token is <BOS>
494
+ return None
495
+
496
+ # === Core Prismatic VLM `forward()` Logic ===
497
+ def forward(
498
+ self,
499
+ input_ids: Optional[torch.LongTensor] = None,
500
+ attention_mask: Optional[torch.Tensor] = None,
501
+ pixel_values: Optional[torch.FloatTensor] = None,
502
+ labels: Optional[torch.LongTensor] = None,
503
+ inputs_embeds: Optional[torch.FloatTensor] = None,
504
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
505
+ use_cache: Optional[bool] = None,
506
+ output_attentions: Optional[bool] = None,
507
+ output_hidden_states: Optional[bool] = None,
508
+ output_projector_features: Optional[bool] = None,
509
+ return_dict: Optional[bool] = None,
510
+ proprio=None,
511
+ proprio_projector=None,
512
+ noisy_actions=None,
513
+ noisy_action_projector=None,
514
+ diffusion_timestep_embeddings=None,
515
+ use_film: bool = False,
516
+ ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
517
+ """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
518
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
519
+ output_hidden_states = (
520
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
521
+ )
522
+ output_projector_features = output_projector_features if output_projector_features is not None else False
523
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
524
+
525
+ # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
526
+ use_cache = use_cache and not self.training
527
+
528
+ # Instantiate Placeholder for Projector Features
529
+ projected_patch_embeddings = None
530
+
531
+ # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
532
+ if input_ids.shape[1] == 1:
533
+ assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
534
+ assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
535
+ assert labels is None, "Unexpected key `labels` provided during cached generation!"
536
+
537
+ language_model_output = self.language_model(
538
+ input_ids=input_ids,
539
+ attention_mask=None,
540
+ position_ids=None,
541
+ past_key_values=past_key_values,
542
+ inputs_embeds=None,
543
+ labels=None,
544
+ use_cache=use_cache,
545
+ output_attentions=output_attentions,
546
+ output_hidden_states=output_hidden_states,
547
+ return_dict=return_dict,
548
+ )
549
+
550
+ # === Handle Unimodal Forward ===
551
+ elif pixel_values is None:
552
+ assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
553
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
554
+
555
+ language_model_output = self.language_model(
556
+ input_ids=input_ids,
557
+ attention_mask=attention_mask,
558
+ position_ids=None,
559
+ past_key_values=None,
560
+ inputs_embeds=None,
561
+ labels=labels,
562
+ use_cache=use_cache,
563
+ output_attentions=output_attentions,
564
+ output_hidden_states=output_hidden_states,
565
+ return_dict=return_dict,
566
+ )
567
+
568
+ # === Handle Multimodal Forward ===
569
+ elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
570
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"
571
+
572
+ # Get input embeddings (from language model embeddings)
573
+ input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)
574
+
575
+ # Extract action masks
576
+ all_actions_mask = self._process_action_masks(labels)
577
+
578
+ # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
579
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
580
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
581
+ ) # (B, lang_seq_len, llm_dim)
582
+
583
+ # Get visual features
584
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
585
+
586
+ # Add proprioceptive state if provided
587
+ projected_patch_embeddings = self._process_proprio_features(
588
+ projected_patch_embeddings, proprio, proprio_projector
589
+ )
590
+
591
+ # [Diffusion] Add diffusion timestep embedding if provided
592
+ if diffusion_timestep_embeddings is not None:
593
+ # For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens
594
+ projected_patch_embeddings = torch.cat(
595
+ (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
596
+ )
597
+
598
+ # Process action embeddings
599
+ if noisy_actions is not None:
600
+ # Get mask corresponding to all action tokens
601
+ all_actions_mask = self._process_action_masks(labels)
602
+
603
+ # Reshape noisy actions into individual action tokens
604
+ # noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1)
605
+ B = noisy_actions.shape[0]
606
+ noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1)
607
+
608
+ # Project noisy action tokens into language model embedding space
609
+ noisy_action_features = noisy_action_projector(noisy_actions) # (B, chunk_len * action_dim, llm_dim)
610
+
611
+ # Replace embeddings of the action tokens with noisy action embeddings
612
+ input_embeddings = self._replace_input_embeddings(
613
+ input_embeddings, all_actions_mask, noisy_action_features
614
+ )
615
+ else:
616
+ # Replace the embeddings of the action tokens with zeros
617
+ # (Later on, the positional embeddings will be added to them)
618
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
619
+ input_embeddings = input_embeddings * ~all_actions_mask
620
+
621
+ # Build multimodal embeddings & attention mask
622
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
623
+ input_embeddings, projected_patch_embeddings, attention_mask
624
+ )
625
+
626
+ # Build labels for multimodal sequence if needed
627
+ multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
628
+
629
+ # Dispatch to language model
630
+ language_model_output = self.language_model(
631
+ input_ids=None,
632
+ attention_mask=multimodal_attention_mask,
633
+ position_ids=None,
634
+ past_key_values=None,
635
+ inputs_embeds=multimodal_embeddings,
636
+ labels=multimodal_labels,
637
+ use_cache=use_cache,
638
+ output_attentions=output_attentions,
639
+ output_hidden_states=output_hidden_states,
640
+ return_dict=return_dict,
641
+ )
642
+
643
+ # === Otherwise =>> Assume Invalid! ===
644
+ elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
645
+ raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
646
+
647
+ else:
648
+ raise ValueError(
649
+ "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
650
+ f"=> `input_ids` = {input_ids is not None}\n"
651
+ f"=> `attention_mask` = {attention_mask is not None}\n"
652
+ f"=> `pixel_values` = {pixel_values is not None}\n"
653
+ f"=> `labels` = {labels is not None}\n"
654
+ f"=> `input_embeds` = {inputs_embeds is not None}\n"
655
+ f"=> `past_key_values` = {past_key_values is not None}\n"
656
+ f"=> `use_cache` = {use_cache}"
657
+ )
658
+
659
+ # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
660
+ if not return_dict:
661
+ if output_projector_features and (projected_patch_embeddings is not None):
662
+ return *language_model_output, projected_patch_embeddings
663
+
664
+ return language_model_output
665
+
666
+ return PrismaticCausalLMOutputWithPast(
667
+ loss=language_model_output.loss,
668
+ logits=language_model_output.logits,
669
+ past_key_values=language_model_output.past_key_values,
670
+ hidden_states=language_model_output.hidden_states,
671
+ attentions=language_model_output.attentions,
672
+ projector_features=projected_patch_embeddings,
673
+ )
674
+
675
+ # === GenerationMixin Methods ===
676
+ def prepare_inputs_for_generation(
677
+ self,
678
+ input_ids: Optional[torch.Tensor] = None,
679
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
680
+ inputs_embeds: Optional[torch.FloatTensor] = None,
681
+ pixel_values: Optional[torch.FloatTensor] = None,
682
+ attention_mask: Optional[torch.Tensor] = None,
683
+ **kwargs: str,
684
+ ) -> Dict[str, torch.Tensor]:
685
+ """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
686
+ if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
687
+ (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
688
+ ):
689
+ raise ValueError("Generation with batch size > 1 is not currently supported!")
690
+
691
+ # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
692
+ if past_key_values is not None:
693
+ input_ids = input_ids[:, -1:]
694
+
695
+ # If `input_embeds` are passed, we only want to use them in the 1st generation step
696
+ if inputs_embeds is not None and past_key_values is None:
697
+ model_inputs = {"input_embeds": inputs_embeds}
698
+ else:
699
+ model_inputs = {"input_ids": input_ids}
700
+
701
+ # Make sure `pixel_values` are preserved in `model_inputs`
702
+ model_inputs.update(
703
+ {
704
+ "attention_mask": attention_mask,
705
+ "pixel_values": pixel_values,
706
+ "past_key_values": past_key_values,
707
+ "use_cache": kwargs.get("use_cache"),
708
+ }
709
+ )
710
+
711
+ return model_inputs
712
+
713
+ # Defer to Language Model (all handle this differently, with different return types)
714
+ def _reorder_cache(self, *args, **kwargs) -> Any:
715
+ return self.language_model._reorder_cache(*args, **kwargs)
716
+
717
+
718
+ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
719
+ config_class: PretrainedConfig = OpenVLAConfig
720
+
721
+ def __init__(self, config: OpenVLAConfig) -> None:
722
+ super().__init__(config)
723
+ self.norm_stats = config.norm_stats
724
+
725
+ # Compute action bins
726
+ self.bins = np.linspace(-1, 1, config.n_action_bins)
727
+ self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
728
+
729
+ # Compute vocab size for de-tokenization -- revert added "multiple of"
730
+ self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
731
+
732
+ def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
733
+ # Used at evaluation/inference time
734
+ """Prepares input for action prediction by adding necessary tokens"""
735
+ # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens
736
+ placeholder_action_token_ids = (
737
+ torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
738
+ )
739
+ input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1) # torch.Size([1, 35 + 56= 91])
740
+
741
+ # Extend the attention mask to fit the new shape of input
742
+ # Note: Only batch size == 1 supported right now
743
+ mask_extension = (
744
+ torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
745
+ .to(attention_mask.device)
746
+ .to(attention_mask.dtype)
747
+ )
748
+ attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
749
+
750
+ return input_ids, attention_mask
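+
+ # Example (numbers illustrative; ACTION_DIM and NUM_ACTIONS_CHUNK come from prismatic.vla.constants):
+ # a 35-token prompt plus ACTION_DIM * NUM_ACTIONS_CHUNK = 56 placeholder tokens yields a 91-token
+ # input_ids tensor, and attention_mask is extended with ones to the same length.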
751
+
752
+ def _prepare_labels_for_action_prediction(self, labels, input_ids):
753
+ """Creates labels tensor for action prediction if not provided"""
754
+ # Used at evaluation/inference time.
755
+ # Extends the labels tensor with fake action labels
756
+ # Adds stop tokens at the end of sequences
757
+ # Handles label preparation for action prediction tasks
758
+ # Why can an arbitrary value be used here? Per xuan: define a custom value and use it consistently -- could the PAD token also work?
759
+ # TODO: Does this need to change? Probably not -- any value works because the labels are only used to build a mask. Why does this function exist? To make the action-prediction labels match the model's input length and to handle sequence termination correctly.
760
+ # Extend labels tensor with fake action labels
761
+ ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_IDX # Use ACTION_TOKEN_IDX so the mask is built correctly: action_tokens_only_mask = (labels == ACTION_TOKEN_IDX)
762
+ labels_extension = (
763
+ torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
764
+ * ARBITRARY_ACTION_TOKEN_IDX
765
+ ) # e.g. torch.Size([1, 57]), filled entirely with ARBITRARY_ACTION_TOKEN_IDX
766
+ labels = torch.cat([labels, labels_extension], dim=-1)
767
+
768
+ return labels
769
+
770
+ def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
771
+ """Unnormalize actions using dataset statistics"""
772
+ action_norm_stats = self.get_action_stats(unnorm_key)
773
+
774
+ if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
775
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool))
776
+ action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"])
777
+ elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
778
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
779
+ action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
780
+ else:
781
+ raise ValueError("Unsupported action/proprio normalization type detected!")
782
+
783
+ actions = np.where(
784
+ mask,
785
+ 0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
786
+ normalized_actions,
787
+ )
788
+
789
+ return actions
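+
+ # Worked example (illustrative numbers): under BOUNDS_Q99 normalization with q01 = -0.5 and q99 = 0.5
+ # for some action dimension, a predicted value of 0.0 maps to
+ # 0.5 * (0.0 + 1) * (0.5 - (-0.5) + 1e-8) + (-0.5) ≈ 0.0, and +1.0 maps to ≈ 0.5;
+ # masked-out dimensions are returned unchanged.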
790
+
791
+ def _run_diffusion_prediction(
792
+ self,
793
+ input_embeddings,
794
+ all_actions_mask,
795
+ noise,
796
+ action_head,
797
+ projected_patch_embeddings,
798
+ labels,
799
+ attention_mask,
800
+ NUM_PATCHES,
801
+ NUM_PROMPT_TOKENS,
802
+ noisy_action_projector,
803
+ ):
804
+ """Run diffusion-based action prediction"""
805
+ # Set diffusion timestep values
806
+ action_head.noise_scheduler.set_timesteps(action_head.num_diffusion_steps)
807
+ # Clone embedding for reuse in each timestep
808
+ orig_projected_patch_embeddings = projected_patch_embeddings.clone()
809
+ curr_noisy_actions = noise
810
+
811
+ # Reverse diffusion: Iteratively denoise to generate action prediction
812
+ for t in action_head.noise_scheduler.timesteps:
813
+ # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
814
+ # embedding, and diffusion timestep embedding)
815
+ timesteps = torch.Tensor([t]).to(labels.device)
816
+ diffusion_timestep_embeddings = (
817
+ action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
818
+ ) # (B, llm_dim)
819
+ diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
820
+
821
+ # [Diffusion] Replace the embeddings of the action tokens with noisy actions
822
+ # (Later on, the positional embeddings will be added to them)
823
+
824
+ # For simplicity, append diffusion timestep embedding to the end of projected vision tokens
825
+ projected_patch_embeddings = torch.cat(
826
+ (orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
827
+ )
828
+
829
+ # Reshape and project noisy actions into language embedding space
830
+ B = curr_noisy_actions.shape[0]
831
+ orig_curr_noisy_actions_shape = curr_noisy_actions.shape
832
+ curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
833
+ noisy_action_features = noisy_action_projector(curr_noisy_actions)
834
+ curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
835
+
836
+ # Replace action token embeddings with noisy action embeddings
837
+ input_embeddings = self._replace_input_embeddings(
838
+ input_embeddings.clone(), all_actions_mask, noisy_action_features
839
+ )
840
+
841
+ # Build multimodal embeddings and attention mask
842
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
843
+ input_embeddings, projected_patch_embeddings, attention_mask
844
+ )
845
+
846
+ # Forward pass through language model
847
+ language_model_output = self.language_model(
848
+ input_ids=None,
849
+ attention_mask=multimodal_attention_mask,
850
+ position_ids=None,
851
+ past_key_values=None,
852
+ inputs_embeds=multimodal_embeddings,
853
+ labels=None,
854
+ use_cache=None,
855
+ output_attentions=False,
856
+ output_hidden_states=True,
857
+ return_dict=True,
858
+ )
859
+
860
+ # Extract hidden states for action portion of response
861
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
862
+ actions_hidden_states = last_hidden_states[
863
+ :,
864
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
865
+ :,
866
+ ] # (B, act_chunk_len, D)
867
+
868
+ # Predict noise and update noisy actions: x_t -> x_{t-1}
869
+ noise_pred = action_head.predict_noise(actions_hidden_states)
870
+ curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
871
+
872
+ curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
873
+
874
+ # Return final actions
875
+ return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
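+
+ # Loop summary (sketch of the code above): starting from pure noise, each timestep t (i) embeds t and
+ # appends it to the vision tokens, (ii) projects the current noisy action chunk into the LLM embedding
+ # space and splices it into the action-token positions, (iii) runs the LLM and reads the action hidden
+ # states, and (iv) applies the action head's noise prediction via the scheduler step to go x_t -> x_{t-1}.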
876
+
877
+ def _regression_or_discrete_prediction(
878
+ self,
879
+ input_embeddings: torch.FloatTensor,  # embeddings of the language instruction (plus action placeholders)
880
+ all_actions_mask: Optional[torch.BoolTensor],  # only used to extract the preceding embeddings, i.e. to strip the action-token portion
881
+ projected_patch_embeddings: torch.FloatTensor,
882
+ attention_mask: torch.BoolTensor,
883
+ labels: torch.LongTensor,
884
+ NUM_PATCHES: int,
885
+ NUM_PROMPT_TOKENS: int,
886
+ action_head: L1RegressionActionHead,
887
+ ):
888
+ """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
889
+ # Extract hidden states for action tokens
890
+ # last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
891
+
892
+ # from transformers import AutoProcessor
893
+ # processor = AutoProcessor.from_pretrained("/data/juyi/openvla-7b+fractal20220817_data+b32+lr-5e-05+lora-r32+dropout-0.0--image_aug--test")
894
+ # tokenizer=processor.tokenizer
895
+ # tokenizer.decode(language_model_output.sequences[0])
896
+
897
+ # actions_hidden_states = last_hidden_states[:, NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_ACTIONS_CHUNK * tokennum, :]# (B, act_chunk_len, D)
898
+ # No manual slicing needed anymore: generate() returns the hidden state for each generated token directly.
899
+ # Why is the first entry torch.Size([1, 535, 4096]), and which one should be used? See https://discuss.huggingface.co/t/get-each-generated-token-last-layer-hidden-state/145921
900
+ # language_model_output.sequences tensor([[29871, 32001, 32001, 32001, 32001, 32001, 32001, 32001, 32001, 2]], device='cuda:0')
901
+
902
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
903
+ input_embeddings, projected_patch_embeddings, attention_mask
904
+ )
905
+ # Example multimodal_embeddings layout: '<s> <512 image tokens> <proprio token> In: What action should the robot take to open the middle drawer of the cabinet?\nOut:'
906
+ if self.preset:
907
+ # start_prefill = torch.cuda.Event(enable_timing=True)
908
+ # end_prefill = torch.cuda.Event(enable_timing=True)
909
+ # start_prefill.record()
910
+ language_model_output = self.language_model.generate(inputs_embeds=multimodal_embeddings, max_new_tokens=1, output_hidden_states=True, return_dict_in_generate=True)
911
+ # hidden_states is a tuple of shape (1 generated token) x (33 layers), each entry torch.Size([1, 314, 4096])
912
+ hidden_states = language_model_output.hidden_states[0][-1]
913
+ actions_hidden_states = hidden_states[:, -NUM_ACTIONS_CHUNK:]
914
+ # end_prefill.record()
915
+ # torch.cuda.synchronize()
916
+ # prefill_time = start_prefill.elapsed_time(end_prefill) / 1000
917
+ # print(f"Prefill time: {prefill_time:.4f} seconds")
918
+ else:
919
+ # start_generate = torch.cuda.Event(enable_timing=True)
920
+ # end_generate = torch.cuda.Event(enable_timing=True)
921
+ # start_generate.record()
922
+ language_model_output = self.language_model.generate(inputs_embeds=multimodal_embeddings, max_new_tokens=2048, output_hidden_states=True, return_dict_in_generate=True, use_cache=True)
923
+ # end_generate.record()
924
+ # torch.cuda.synchronize()
925
+ # generate_time = start_generate.elapsed_time(end_generate) / 1000
926
+ # print(f"prefill + Generate time: {generate_time:.4f} seconds")
927
+ actions_hidden_states = torch.stack([language_model_output.hidden_states[i][-1] for i in range(1, NUM_ACTIONS_CHUNK + 1)], dim=0) # (action_chunk, batch_size, sequence_length, hidden_dim)
928
+ actions_hidden_states = actions_hidden_states.transpose(0, 1).squeeze(2) # torch.Size([batch_size, action_chunk, hidden_dim])
929
+
930
+ normalized_actions = action_head.predict_action(actions_hidden_states)
931
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
932
+ normalized_actions = normalized_actions.float().cpu().detach().numpy()
933
+
934
+ return normalized_actions, actions_hidden_states
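+
+ # Note (sketch): in `preset` mode the action placeholder tokens are already part of the prompt, so a
+ # single prefill step (max_new_tokens=1) exposes hidden states for all NUM_ACTIONS_CHUNK action positions
+ # at once; otherwise the last-layer hidden state of each autoregressively generated action token is
+ # collected step by step before being passed to the action head.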
935
+
936
+
937
+
938
+ def mul_regression_or_discrete_prediction(
939
+ self,
940
+ input_embeddings: torch.FloatTensor,  # embeddings of the language instruction (plus action placeholders)
941
+ all_actions_mask: Optional[torch.BoolTensor],  # only used to extract the preceding embeddings, i.e. to strip the action-token portion
942
+ projected_patch_embeddings: torch.FloatTensor,
943
+ attention_mask: torch.BoolTensor,
944
+ labels: torch.LongTensor,
945
+ NUM_PATCHES: int,
946
+ NUM_PROMPT_TOKENS: int,
947
+ action_head: L1RegressionActionHead,
948
+ **kwargs,
949
+ ):
950
+ cfg = kwargs.get("cfg", None)
951
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
952
+ input_embeddings, projected_patch_embeddings, attention_mask
953
+ )
954
+ # Example multimodal_embeddings layout: '<s> <512 image tokens> <proprio token> In: What action should the robot take to open the middle drawer of the cabinet?\nOut:'
955
+ # language_model_output.hidden_states is a tuple of shape (1 generated token) x (33 layers), each entry torch.Size([1, 314, 4096])
956
+ if self.preset:
957
+ language_model_output = self.language_model.generate(inputs_embeds=multimodal_embeddings, max_new_tokens=1, output_hidden_states=True, return_dict_in_generate=True)
958
+ actions_hidden_states = language_model_output.hidden_states[0][-1]
959
+ actions_hidden_states = actions_hidden_states[:, -1]
960
+ else:
961
+ language_model_output = self.language_model.generate(inputs_embeds=multimodal_embeddings, max_new_tokens=2, output_hidden_states=True, return_dict_in_generate=True)
962
+ actions_hidden_states = language_model_output.hidden_states[1][-1]
963
+ actions_hidden_states = actions_hidden_states[:, -1]
964
+
965
+ normalized_actions = action_head.predict_action(actions_hidden_states)
966
+ normalized_actions = normalized_actions.reshape(cfg.num_actions_chunk, ACTION_DIM)
967
+ normalized_actions = normalized_actions.float().cpu().detach().numpy()
968
+
969
+ return normalized_actions, actions_hidden_states
970
+
971
+ def predict_action(
972
+ self,
973
+ input_ids: Optional[torch.LongTensor] = None,
974
+ unnorm_key: Optional[str] = None,
975
+ proprio=None,
976
+ proprio_projector=None,
977
+ action_head=None,
978
+ noisy_action_projector=None,
979
+ use_film: bool = False,
980
+ **kwargs: str,
981
+ ) -> np.ndarray:
982
+ """Predict actions from input sequence, with options for different prediction methods.
983
+
984
+ Args:
985
+ input_ids: Input token ids
986
+ unnorm_key: Key for unnormalization statistics
987
+ proprio: Proprioceptive features
988
+ proprio_projector: Projector for proprioceptive features
989
+ action_head: Optional head for L1 regression or diffusion-based prediction
990
+ noisy_action_projector: Projector for noisy actions in diffusion-based prediction
991
+ use_film: Whether to use FiLM conditioning
992
+ **kwargs: Additional arguments including pixel_values and attention_mask
993
+
994
+ Returns:
995
+ Tuple of (unnormalized_actions, action_hidden_states)
996
+ """
997
+ # If the special empty token ('') does not already appear after the colon (':') token in the prompt
998
+ # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
999
+ if not torch.all(input_ids[:, -1] == 29871):
1000
+ input_ids = torch.cat(
1001
+ (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
1002
+ )
1003
+
1004
+ pixel_values = kwargs["pixel_values"]
1005
+ attention_mask = kwargs["attention_mask"]
1006
+
1007
+ # Create fake labels tensor (needed for action mask)
1008
+ labels = input_ids.clone()
1009
+ labels[:] = IGNORE_INDEX  # Ignore all input positions (IGNORE_INDEX = -100)
1010
+
1011
+ # Get number of tokens in prompt (excluding the start token)
1012
+ NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token
1013
+
1014
+ # Prepare inputs by adding necessary tokens
1015
+ input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
1016
+
1017
+ # Update labels tensor for action mask computation later
1018
+ labels = self._prepare_labels_for_action_prediction(labels, input_ids)
1019
+
1020
+ # Get input embeddings and action masks
1021
+ input_embeddings = self.get_input_embeddings()(input_ids)
1022
+ all_actions_mask = self._process_action_masks(labels)
1023
+
1024
+ # Extract language embeddings
1025
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
1026
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
1027
+ )
1028
+
1029
+ # Process vision features
1030
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
1031
+
1032
+ # Add proprioceptive features if provided
1033
+ use_proprio = proprio_projector is not None and proprio is not None
1034
+ if use_proprio:
1035
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1036
+ projected_patch_embeddings = self._process_proprio_features(
1037
+ projected_patch_embeddings, proprio, proprio_projector
1038
+ )
1039
+
1040
+ # Use diffusion if provided, otherwise use regression or discrete prediction
1041
+ use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
1042
+
1043
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
1044
+ NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
1045
+ if use_proprio:
1046
+ NUM_PATCHES += 1
1047
+
1048
+ normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
1049
+ input_embeddings,
1050
+ all_actions_mask,
1051
+ projected_patch_embeddings,
1052
+ attention_mask,
1053
+ labels,
1054
+ NUM_PATCHES,
1055
+ NUM_PROMPT_TOKENS,
1056
+ action_head,
1057
+ )
1058
+
1059
+ # Unnormalize predicted actions
1060
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1061
+
1062
+ return actions, actions_hidden_states
1063
+
1064
+ @staticmethod
1065
+ def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
1066
+ """Validate and resolve the unnormalization key for action statistics"""
1067
+ if unnorm_key is None:
1068
+ assert len(norm_stats) == 1, (
1069
+ f"Your model was trained on more than one dataset, "
1070
+ f"please pass a `unnorm_key` from the following options to choose the statistics "
1071
+ f"used for un-normalizing actions: {norm_stats.keys()}"
1072
+ )
1073
+ unnorm_key = next(iter(norm_stats.keys()))
1074
+ # Note: why were the libero norm stats not loaded here?
1075
+ assert unnorm_key in norm_stats, (
1076
+ f"The `unnorm_key` you chose is not in the set of available dataset statistics, "
1077
+ f"please choose from: {norm_stats.keys()}"
1078
+ )
1079
+ return unnorm_key
1080
+
1081
+ def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
1082
+ """Get the dimensionality of the policy's action space."""
1083
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1084
+ return len(self.norm_stats[unnorm_key]["action"]["min"])
1085
+
1086
+ def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
1087
+ """Get all the logged statistics for the given dataset."""
1088
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1089
+ return self.norm_stats[unnorm_key]["action"]
1090
+
1091
+
1092
+ def lisa_forward(
1093
+ self,
1094
+ input_ids: Optional[torch.LongTensor] = None,
1095
+ attention_mask: Optional[torch.Tensor] = None,
1096
+ pixel_values: Optional[torch.FloatTensor] = None,
1097
+ labels: Optional[torch.LongTensor] = None,
1098
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1099
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1100
+ use_cache: Optional[bool] = None,
1101
+ output_attentions: Optional[bool] = None,
1102
+ output_hidden_states: Optional[bool] = None,
1103
+ output_projector_features: Optional[bool] = None,
1104
+ return_dict: Optional[bool] = None,
1105
+ proprio=None,
1106
+ proprio_projector=None,
1107
+ **kwargs
1108
+ ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
1109
+ """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
1110
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1111
+ output_hidden_states = (
1112
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1113
+ )
1114
+ output_projector_features = output_projector_features if output_projector_features is not None else False
1115
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1116
+
1117
+ # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
1118
+ use_cache = use_cache and not self.training
1119
+
1120
+ # Instantiate Placeholder for Projector Features
1121
+ projected_patch_embeddings = None
1122
+
1123
+ # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
1124
+ if input_ids.shape[1] == 1:
1125
+ assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
1126
+ assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
1127
+ assert labels is None, "Unexpected key `labels` provided during cached generation!"
1128
+
1129
+ language_model_output = self.language_model(
1130
+ input_ids=input_ids,
1131
+ attention_mask=None,
1132
+ position_ids=None,
1133
+ past_key_values=past_key_values,
1134
+ inputs_embeds=None,
1135
+ labels=None,
1136
+ use_cache=use_cache,
1137
+ output_attentions=output_attentions,
1138
+ output_hidden_states=output_hidden_states,
1139
+ return_dict=return_dict,
1140
+ )
1141
+
1142
+ # === Handle Unimodal Forward ===
1143
+ elif pixel_values is None:
1144
+ raise NotImplementedError
1145
+
1146
+ # === Handle Multimodal Forward ===
1147
+ elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
1148
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"
1149
+
1150
+ # Get input embeddings (from language model embeddings)
1151
+ input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)
1152
+ # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
1153
+ # language_embeddings = input_embeddings[~all_actions_mask].reshape(
1154
+ # input_embeddings.shape[0], -1, input_embeddings.shape[2]
1155
+ # ) # (B, lang_seq_len, llm_dim) -- this would also count the trailing stop index and padding. Is that a problem? No, because those positions are ignored; removed here since FiLM is not used.
1156
+ # Get visual features
1157
+ projected_patch_embeddings = self._process_vision_features(pixel_values)
1158
+
1159
+ # Add proprioceptive state if provided
1160
+ projected_patch_embeddings = self._process_proprio_features(
1161
+ projected_patch_embeddings, proprio, proprio_projector
1162
+ )
1163
+
1164
+ all_actions_mask = (labels == ACTION_TOKEN_IDX)  # Unlike the main forward pass, which has to count tokens manually to locate the action-token offset
1165
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
1166
+ input_embeddings = input_embeddings * ~all_actions_mask
1167
+
1168
+ # Build multimodal embeddings & attention mask
1169
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1170
+ input_embeddings, projected_patch_embeddings, attention_mask
1171
+ )
1172
+
1173
+ # Build labels for multimodal sequence if needed
1174
+ multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
1175
+
1176
+ # Dispatch to language model
1177
+ language_model_output = self.language_model(
1178
+ input_ids=None,
1179
+ attention_mask=multimodal_attention_mask,
1180
+ position_ids=None,
1181
+ past_key_values=None,
1182
+ inputs_embeds=multimodal_embeddings,
1183
+ labels=multimodal_labels,
1184
+ use_cache=use_cache,
1185
+ output_attentions=output_attentions,
1186
+ output_hidden_states=output_hidden_states,
1187
+ return_dict=return_dict,
1188
+ )
1189
+
1190
+
1191
+ # === Otherwise =>> Assume Invalid! ===
1192
+ elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
1193
+ raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
1194
+
1195
+ else:
1196
+ raise ValueError(
1197
+ "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
1198
+ f"=> `input_ids` = {input_ids is not None}\n"
1199
+ f"=> `attention_mask` = {attention_mask is not None}\n"
1200
+ f"=> `pixel_values` = {pixel_values is not None}\n"
1201
+ f"=> `labels` = {labels is not None}\n"
1202
+ f"=> `input_embeds` = {inputs_embeds is not None}\n"
1203
+ f"=> `past_key_values` = {past_key_values is not None}\n"
1204
+ f"=> `use_cache` = {use_cache}"
1205
+ )
1206
+
1207
+ # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
1208
+ if not return_dict:
1209
+ if output_projector_features and (projected_patch_embeddings is not None):
1210
+ return *language_model_output, projected_patch_embeddings
1211
+
1212
+ return language_model_output
1213
+
1214
+ return PrismaticCausalLMOutputWithPast(
1215
+ loss=language_model_output.loss,
1216
+ logits=language_model_output.logits,
1217
+ past_key_values=language_model_output.past_key_values,
1218
+ hidden_states=language_model_output.hidden_states,
1219
+ attentions=language_model_output.attentions,
1220
+ projector_features=projected_patch_embeddings,
1221
+ )
1222
+
1223
+ def lisa_predict_action(
1224
+ self,
1225
+ input_ids: Optional[torch.LongTensor] = None,  # i.e., the tokenized language instruction
1226
+ unnorm_key: Optional[str] = None,
1227
+ proprio=None,
1228
+ proprio_projector=None,
1229
+ action_head:L1RegressionActionHead=None,
1230
+ noisy_action_projector=None,
1231
+ use_film: bool = False,
1232
+ **kwargs: str,
1233
+ ) -> np.ndarray:
1234
+
1235
+ pixel_values = kwargs["pixel_values"]
1236
+ attention_mask = kwargs["attention_mask"]
1237
+
1238
+ # Get number of tokens in prompt (excluding the start token)
1239
+ # NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token
1240
+
1241
+ # input id '<s> In: What action should the robot take to open the middle drawer of the cabinet?\nOut:'
1242
+ #预测的时候labels 有啥用? 只是用来设置mask 我们自回归就不用
1243
+ cfg = kwargs.get("cfg", None) # Extract cfg from kwargs
1244
+ if cfg.preset:
1245
+ special_tensor = torch.tensor([[29871]], device=input_ids.device, dtype=input_ids.dtype)
1246
+ output_tensor = torch.tensor([[32001] * NUM_ACTIONS_CHUNK], device=input_ids.device, dtype=input_ids.dtype)
1247
+ input_ids = torch.cat([input_ids, special_tensor, output_tensor], dim=1) # preset action tokens, only forward once.
1248
+ self.preset = True
1249
+ else:
1250
+ self.preset = False
1251
+ input_embeddings = self.get_input_embeddings()(input_ids)
1252
+
1253
+ projected_patch_embeddings = self._process_vision_features(pixel_values)
1254
+
1255
+ use_proprio = proprio_projector is not None and proprio is not None
1256
+ if use_proprio:
1257
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1258
+ projected_patch_embeddings = self._process_proprio_features(
1259
+ projected_patch_embeddings, proprio, proprio_projector
1260
+ )
1261
+
1262
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
1263
+ # NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
1264
+ # if use_proprio:
1265
+ # NUM_PATCHES += 1
1266
+ normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
1267
+ input_embeddings,
1268
+ None,
1269
+ projected_patch_embeddings,
1270
+ attention_mask,
1271
+ None,  # labels are not needed at inference
1272
+ None,  # NUM_PATCHES is not needed at inference
1273
+ None,  # NUM_PROMPT_TOKENS is not needed at inference
1274
+ action_head,
1275
+ )
1276
+
1277
+ # Unnormalize predicted actions
1278
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1279
+
1280
+ return actions, actions_hidden_states
1281
+
1282
+
1283
+ def mul_predict_action(
1284
+ self,
1285
+ input_ids: Optional[torch.LongTensor] = None,  # i.e., the tokenized language instruction
1286
+ unnorm_key: Optional[str] = None,
1287
+ proprio=None,
1288
+ proprio_projector=None,
1289
+ action_head:L1RegressionActionHead=None,
1290
+ noisy_action_projector=None,
1291
+ use_film: bool = False,
1292
+ **kwargs: str,
1293
+ ) -> np.ndarray:
1294
+ cfg = kwargs.get("cfg", None) # Extract cfg from kwargs
1295
+ if cfg.preset:
1296
+ self.preset = True
1297
+ if not torch.all(input_ids[:, -1] == 29871):
1298
+ input_ids = torch.cat(
1299
+ (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
1300
+ )
1301
+ else:
1302
+ self.preset = False
1303
+
1304
+
1305
+ pixel_values = kwargs["pixel_values"]
1306
+ attention_mask = kwargs["attention_mask"]
1307
+
1308
+ # input id '<s> In: What action should the robot take to open the middle drawer of the cabinet?\nOut:'
1309
+
1310
+ input_embeddings = self.get_input_embeddings()(input_ids)
1311
+
1312
+ projected_patch_embeddings = self._process_vision_features(pixel_values)
1313
+
1314
+ use_proprio = proprio_projector is not None and proprio is not None
1315
+ if use_proprio:
1316
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1317
+ projected_patch_embeddings = self._process_proprio_features(
1318
+ projected_patch_embeddings, proprio, proprio_projector
1319
+ )
1320
+
1321
+ normalized_actions, actions_hidden_states = self.mul_regression_or_discrete_prediction(
1322
+ input_embeddings,
1323
+ None,
1324
+ projected_patch_embeddings,
1325
+ attention_mask,
1326
+ None,  # labels are not needed at inference
1327
+ None,  # NUM_PATCHES is not needed at inference
1328
+ None,  # NUM_PROMPT_TOKENS is not needed at inference
1329
+ action_head,
1330
+ cfg=cfg,
1331
+ )
1332
+
1333
+ # Unnormalize predicted actions
1334
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1335
+
1336
+ return actions, actions_hidden_states
1337
+
modeling_prismatic.py.back.20250405_141300 ADDED
@@ -0,0 +1,1338 @@
1
+ """
2
+ modeling_prismatic.py
3
+
4
+ Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.
5
+ Inherits from the default `transformers.PreTrainedModel`. Meant to be standalone and self-contained,
6
+ but exactly replicate the logic in `prismatic.models.vlms.prismatic.py`.
7
+ """
8
+
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from functools import partial
12
+ from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
13
+
14
+ import numpy as np
15
+ import timm
16
+ import tokenizers
17
+ import torch
18
+ import torch.nn as nn
19
+ import transformers
20
+ from timm.models.vision_transformer import LayerScale
21
+ from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import ModelOutput
23
+ from prismatic.models.action_heads import L1RegressionActionHead
24
+ import time
25
+ from prismatic.training.train_utils import (
26
+ get_current_action_mask,
27
+ get_next_actions_mask,
28
+ )
29
+ from prismatic.vla.constants import (
30
+ ACTION_DIM,
31
+ ACTION_PROPRIO_NORMALIZATION_TYPE,
32
+ ACTION_TOKEN_BEGIN_IDX,
33
+ IGNORE_INDEX,
34
+ NUM_ACTIONS_CHUNK,
35
+ STOP_INDEX,
36
+ ACTION_TOKEN_IDX,
37
+ NormalizationType,
38
+ )
39
+
40
+ from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
41
+
42
+ # Set up logger
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
+ # === Utility Functions for Monkey-Patching ===
47
+ def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
48
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
49
+ result = fn(*args, **kwargs)
50
+ return result[0] if isinstance(result, tuple) else result
51
+
52
+ return wrapper
53
+
54
+
55
+ # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
56
+ # =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
57
+ # =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
58
+ def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
59
+ return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
60
+
61
+
62
+ def ls_apply_patch(ls_module: LayerScale):
63
+ ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
64
+ ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
65
+ del ls_module.gamma
66
+
67
+
68
+ # === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
69
+ class PrismaticVisionBackbone(nn.Module):
70
+ """
71
+ Vision backbone for Prismatic models that handles image feature extraction.
72
+
73
+ Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.
74
+ For fused backbones, features from both models are concatenated along the feature dimension.
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ use_fused_vision_backbone: bool,
80
+ image_sizes: List[int],
81
+ timm_model_ids: List[str],
82
+ timm_override_act_layers: List[Optional[str]],
83
+ ) -> None:
84
+ """
85
+ Initialize the vision backbone.
86
+
87
+ Args:
88
+ use_fused_vision_backbone: Whether to use two backbones and fuse their features
89
+ image_sizes: List of image sizes for each backbone
90
+ timm_model_ids: List of TIMM model IDs to use for each backbone
91
+ timm_override_act_layers: List of activation layer overrides for each backbone
92
+ """
93
+ super().__init__()
94
+ self.use_fused_vision_backbone = use_fused_vision_backbone
95
+ self.num_images_in_input = 1 # Default value, can be overridden later
96
+
97
+ # Validate number of (fused) vision backbones
98
+ if len(timm_model_ids) > 2:
99
+ raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")
100
+
101
+ # Create primary featurizer
102
+ self.featurizer = self._create_featurizer(
103
+ model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]
104
+ )
105
+ self.embed_dim = self.featurizer.embed_dim
106
+
107
+ # Create secondary featurizer if using fused backbone
108
+ if self.use_fused_vision_backbone:
109
+ self.fused_featurizer = self._create_featurizer(
110
+ model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]
111
+ )
112
+ self.embed_dim += self.fused_featurizer.embed_dim
113
+
114
+ # Patch LayerScale modules for HF compatibility
115
+ self._patch_layer_scales()
116
+
117
+ def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
118
+ """
119
+ Create a TIMM-based featurizer model with appropriate configurations.
120
+
121
+ Args:
122
+ model_id: The TIMM model ID to load
123
+ img_size: Input image size for the model
124
+ act_layer: Override for the activation layer type
125
+
126
+ Returns:
127
+ A configured featurizer model
128
+ """
129
+ featurizer = timm.create_model(
130
+ model_id,
131
+ pretrained=False,
132
+ num_classes=0,
133
+ img_size=img_size,
134
+ act_layer=act_layer,
135
+ )
136
+
137
+ # Monkey-patch the forward function to extract the second-to-last layer features
138
+ num_blocks = len(featurizer.blocks)
139
+ featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))
140
+
141
+ return featurizer
142
+
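+ # Illustrative note (hedged sketch): after the monkey-patch above, calling the featurizer returns patch
+ # tokens from the second-to-last transformer block instead of a pooled head output. For a 224px backbone
+ # with patch size 14 (256 patches), one would expect roughly:
+ #   feats = featurizer(torch.randn(1, 3, 224, 224))  # -> shape (1, 256, embed_dim)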
143
+ def _patch_layer_scales(self) -> None:
144
+ """
145
+ Patch all LayerScale modules to be compatible with HF's parameter naming.
146
+
147
+ HF Transformers overwrites parameters with names containing 'gamma',
148
+ so we need to rename and modify the forward method.
149
+ """
150
+ # Patch primary featurizer
151
+ for module in self.featurizer.modules():
152
+ if isinstance(module, LayerScale):
153
+ ls_apply_patch(module)
154
+
155
+ # Patch secondary featurizer if it exists
156
+ if self.use_fused_vision_backbone:
157
+ for module in self.fused_featurizer.modules():
158
+ if isinstance(module, LayerScale):
159
+ ls_apply_patch(module)
160
+
161
+ def get_num_patches(self) -> int:
162
+ """
163
+ Returns the number of vision patches output by the vision backbone.
164
+
165
+ Returns:
166
+ Number of patches per image
167
+ """
168
+ return self.featurizer.patch_embed.num_patches
169
+
170
+ def get_num_images_in_input(self) -> int:
171
+ """
172
+ Returns the number of input images for the vision backbone.
173
+
174
+ Returns:
175
+ Number of images expected in the input
176
+ """
177
+ return self.num_images_in_input
178
+
179
+ def set_num_images_in_input(self, num_images_in_input: int) -> None:
180
+ """
181
+ Sets the number of input images for the vision backbone.
182
+
183
+ Args:
184
+ num_images_in_input: Number of images to expect in the input
185
+ """
186
+ self.num_images_in_input = num_images_in_input
187
+
188
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
189
+ """
190
+ Implements the forward pass for the vision backbone.
191
+
192
+ If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features
193
+ (otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).
194
+
195
+ Args:
196
+ pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W).
197
+ """
198
+ if self.num_images_in_input == 1:
199
+ if not self.use_fused_vision_backbone:
200
+ return self.featurizer(pixel_values)
201
+
202
+ # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
203
+ img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
204
+ patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)
205
+
206
+ return torch.cat([patches, patches_fused], dim=2)
207
+
208
+ else:
209
+ assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!"
210
+
211
+ # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2)
212
+ images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)
213
+
214
+ # Process each image and collect patches
215
+ all_patches = []
216
+ for img in images:
217
+ # Split each image further into two stacks of channels (each with 3 channels)
218
+ img_regular, img_fused = torch.split(img, [3, 3], dim=1)
219
+
220
+ # Get patches from both SigLIP and DINOv2 vision transformers
221
+ patches = self.featurizer(img_regular)
222
+ patches_fused = self.fused_featurizer(img_fused)
223
+
224
+ # Concatenate SigLIP and DINOv2 patches along the hidden dimension
225
+ combined_patches = torch.cat([patches, patches_fused], dim=2)
226
+ all_patches.append(combined_patches)
227
+
228
+ # Concatenate all patches along the patch dimension
229
+ return torch.cat(all_patches, dim=1)
230
+
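+ # Illustrative note on the fused input layout (hedged sketch): with `use_fused_vision_backbone=True`,
+ # `pixel_values` is channel-stacked as (B, num_images * 6, H, W) -- 3 SigLIP channels plus 3 DINOv2
+ # channels per image -- and the output is (B, num_images * num_patches, siglip_dim + dinov2_dim).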
231
+
232
+ # === Prismatic Projector (nn.Module) Definitions ===
233
+ class PrismaticProjector(nn.Module):
234
+ def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
235
+ super().__init__()
236
+ self.use_fused_vision_backbone = use_fused_vision_backbone
237
+ self.vision_dim, self.llm_dim = vision_dim, llm_dim
238
+
239
+ # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
240
+ if not self.use_fused_vision_backbone:
241
+ self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
242
+ self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
243
+ self.act_fn1 = nn.GELU()
244
+ else:
245
+ initial_projection_dim = 4 * vision_dim
246
+ self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
247
+ self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
248
+ self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
249
+ self.act_fn1 = nn.GELU()
250
+ self.act_fn2 = nn.GELU()
251
+
252
+ def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
253
+ if not self.use_fused_vision_backbone:
254
+ projected_features = self.fc1(img_patches)
255
+ projected_features = self.act_fn1(projected_features)
256
+ projected_features = self.fc2(projected_features)
257
+ else:
258
+ projected_features = self.fc1(img_patches)
259
+ projected_features = self.act_fn1(projected_features)
260
+ projected_features = self.fc2(projected_features)
261
+ projected_features = self.act_fn2(projected_features)
262
+ projected_features = self.fc3(projected_features)
263
+
264
+ return projected_features
265
+
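+ # Illustrative shape walkthrough (sketch): with vision_dim = D_v and llm_dim = D_l, the fused projector maps
+ # (B, P, D_v) -> fc1 -> (B, P, 4 * D_v) -> fc2 -> (B, P, D_l) -> fc3 -> (B, P, D_l), with GELUs between
+ # layers; the non-fused variant is a plain two-layer MLP D_v -> D_l -> D_l.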
266
+
267
+ # === Main HF Class Definitions ===
268
+ @dataclass
269
+ class PrismaticCausalLMOutputWithPast(ModelOutput):
270
+ """Base class for Prismatic casual (visually-conditioned) language model outputs; also exposes visual features."""
271
+
272
+ loss: Optional[torch.FloatTensor] = None
273
+ logits: torch.FloatTensor = None
274
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
275
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
276
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
277
+
278
+ # Additions for VLMs
279
+ projector_features: Optional[torch.FloatTensor] = None
280
+
281
+
282
+ class PrismaticPreTrainedModel(PreTrainedModel):
283
+ config_class: PretrainedConfig = PrismaticConfig
284
+ base_model_prefix: str = "model"
285
+ supports_gradient_checkpointing: bool = True
286
+
287
+ _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
288
+ _skip_keys_device_placement: str = "past_key_values"
289
+ _supports_flash_attn_2: bool = True
290
+
291
+ def _init_weights(self, module: nn.Module) -> None:
292
+ # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
293
+ # => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
294
+ # https://github.com/TRI-ML/prismatic-vlms
295
+ std = (
296
+ self.config.initializer_range
297
+ if hasattr(self.config, "initializer_range")
298
+ else self.config.text_config.initializer_range
299
+ )
300
+
301
+ if hasattr(module, "class_embedding"):
302
+ module.class_embedding.data.normal_(mean=0.0, std=std)
303
+
304
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
305
+ module.weight.data.normal_(mean=0.0, std=std)
306
+ if module.bias is not None:
307
+ module.bias.data.zero_()
308
+ elif isinstance(module, nn.Embedding):
309
+ module.weight.data.normal_(mean=0.0, std=std)
310
+ if module.padding_idx is not None:
311
+ module.weight.data[module.padding_idx].zero_()
312
+
313
+ @property
314
+ def _supports_sdpa(self) -> bool:
315
+ """Check LLM supports SDPA Attention"""
316
+ return self.language_model._supports_sdpa
317
+
318
+
319
+ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
320
+ def __init__(self, config: PrismaticConfig) -> None:
321
+ super().__init__(config)
322
+
323
+ # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
324
+ if config.use_fused_vision_backbone is None:
325
+ raise ValueError("Missing config field `use_fused_vision_backbone`")
326
+
327
+ if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
328
+ raise NotImplementedError(
329
+ "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
330
+ "if you urgently need support for latest TIMM versions."
331
+ )
332
+
333
+ if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
334
+ logger.warning(
335
+ f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
336
+ f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
337
+ f"there might be inference-time regressions due to dependency changes. If in doubt, please"
338
+ f"use the above versions."
339
+ )
340
+
341
+ # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
342
+ self.vision_backbone = PrismaticVisionBackbone(
343
+ config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
344
+ )
345
+
346
+ # Create Multimodal Projector
347
+ self.projector = PrismaticProjector(
348
+ config.use_fused_vision_backbone,
349
+ vision_dim=self.vision_backbone.embed_dim,
350
+ llm_dim=config.text_config.hidden_size,
351
+ )
352
+
353
+ # Instantiate LLM Backbone
354
+ self.language_model = AutoModelForCausalLM.from_config(
355
+ config.text_config, attn_implementation=config._attn_implementation
356
+ )
357
+ self.vocab_size = config.text_config.vocab_size
358
+ self.pad_token_id = config.pad_token_id
359
+ self.llm_dim = config.text_config.hidden_size
360
+
361
+ # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
362
+ self.post_init()
363
+
364
+ # === `PreTrainedModel` Boilerplate ===
365
+ def get_input_embeddings(self) -> nn.Module:
366
+ return self.language_model.get_input_embeddings()
367
+
368
+ def set_input_embeddings(self, value: nn.Module) -> None:
369
+ self.language_model.set_input_embeddings(value)
370
+
371
+ def get_output_embeddings(self) -> nn.Module:
372
+ return self.language_model.get_output_embeddings()
373
+
374
+ def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
375
+ self.language_model.set_output_embeddings(new_embeddings)
376
+
377
+ def get_decoder(self) -> nn.Module:
378
+ return self.language_model.get_decoder()
379
+
380
+ def set_decoder(self, decoder: nn.Module) -> None:
381
+ self.language_model.set_decoder(decoder)
382
+
383
+ def tie_weights(self) -> None:
384
+ self.language_model.tie_weights() # Note: `Llama-2` and `Mistral` don't tie weights (no-op)
385
+
386
+ def resize_token_embeddings(
387
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
388
+ ) -> nn.Embedding:
389
+ updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
390
+
391
+ # Update config/instance variables
392
+ self.config.text_config.vocab_size = updated_embeddings.num_embeddings
393
+ self.vocab_size = updated_embeddings.num_embeddings
394
+
395
+ return updated_embeddings
396
+
397
+ def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):
398
+ """
399
+ Replace embeddings in input_embeddings at positions where all_actions_mask is True
400
+ with embeddings from noisy_action_features, using vectorized operations.
401
+
402
+ Args:
403
+ input_embeddings: Tensor of shape (B, S, D)
404
+ all_actions_mask: Boolean tensor of shape (B, S)
405
+ noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample
406
+
407
+ Returns:
408
+ Modified input_embeddings tensor
409
+ """
410
+ # Clone input to avoid modifying the original tensor
411
+ new_input_embeddings = input_embeddings.clone()
412
+
413
+ # Create a tensor with the same shape of input_embeddings to hold the noisy action features
414
+ repositioned_noisy_action_features = torch.zeros_like(input_embeddings)
415
+
416
+ # Create batch indices for splicing
417
+ batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)
418
+ batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])
419
+
420
+ # Get indices where mask is True for each sample
421
+ masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])
422
+
423
+ # Move the noisy action features into their correct positions
424
+ repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features
425
+
426
+ # Combine original input embeddings and noisy action embeddings using the mask
427
+ new_input_embeddings = torch.where(
428
+ all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings
429
+ )
430
+
431
+ return new_input_embeddings
432
+
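+ # Illustrative example (hedged sketch): for B=1, S=4, D=2 with mask [False, True, True, False] and
+ # noisy_action_features of shape (1, 2, 2), positions 1 and 2 of the sequence are overwritten with the
+ # two action feature vectors, while positions 0 and 3 keep their original embeddings.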
433
+ def _process_action_masks(self, labels):
434
+ """Helper to get action masks from labels"""
435
+ current_action_mask = get_current_action_mask(labels) # (B, seq_len)
436
+ next_actions_mask = get_next_actions_mask(labels) # (B, seq_len)
437
+ all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len)
438
+ return all_actions_mask
439
+
440
+ def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False):
441
+ """Process vision features; `language_embeddings` and `use_film` are accepted for interface compatibility, but FiLM conditioning is not applied in this port"""
442
+ patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D)
443
+
444
+ # Project patch embeddings into language embedding space
445
+ return self.projector(patch_features)
446
+
447
+ def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):
448
+ """Process proprioceptive features and append to vision features"""
449
+ if proprio_projector is not None and proprio is not None:
450
+ # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim)
451
+ # proprio: (bsz, proprio_dim) or (proprio_dim,)
452
+ proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1) # (bsz, proprio_dim)
453
+ proprio_features = proprio_projector(proprio) # (bsz, llm_dim)
454
+ proprio_features = proprio_features.unsqueeze(dim=1) # (bsz, 1, llm_dim)
455
+ # For simplicity, just append proprio token to the end of projected vision patch tokens
456
+ return torch.cat((projected_patch_embeddings, proprio_features), dim=1)
457
+ return projected_patch_embeddings
458
+
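+ # Illustrative note (sketch): a (bsz, proprio_dim) state vector is projected to a single (bsz, 1, llm_dim)
+ # token and appended after the vision patch tokens, so downstream code counts it as one extra "patch".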
459
+ def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
460
+ """Build multimodal embeddings and attention mask"""
461
+ # juyi: Should the attention mask be made lower-triangular here? No -- generate() applies the causal mask automatically.
462
+ projected_patch_attention_mask = None
463
+ if attention_mask is not None:
464
+ projected_patch_attention_mask = torch.full(
465
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
466
+ fill_value=True,
467
+ dtype=attention_mask.dtype,
468
+ device=attention_mask.device,
469
+ )
470
+
471
+ # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
472
+ multimodal_embeddings = torch.cat(
473
+ [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
474
+ )
475
+
476
+ multimodal_attention_mask = None
477
+ if attention_mask is not None:
478
+ multimodal_attention_mask = torch.cat(
479
+ [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
480
+ )
481
+
482
+ return multimodal_embeddings, multimodal_attention_mask
483
+
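+ # Illustrative sequence layout (sketch): the multimodal sequence is assembled as
+ #   [<BOS>] + [vision patch tokens (+ optional proprio / diffusion-timestep tokens)] + [remaining text/action tokens],
+ # and the attention mask is extended with all-True entries over the inserted patch positions.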
484
+ def _build_multimodal_labels(self, labels, projected_patch_embeddings):
485
+ """Build multimodal labels with IGNORE_INDEX for patch embeddings"""
486
+ if labels is not None:
487
+ projected_patch_labels = torch.full(
488
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
489
+ fill_value=IGNORE_INDEX,  # No loss is computed at these positions
490
+ dtype=labels.dtype,
491
+ device=labels.device,
492
+ )
493
+ return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)  # The first token is <BOS>
494
+ return None
495
+
496
+ # === Core Prismatic VLM `forward()` Logic ===
497
+ def forward(
498
+ self,
499
+ input_ids: Optional[torch.LongTensor] = None,
500
+ attention_mask: Optional[torch.Tensor] = None,
501
+ pixel_values: Optional[torch.FloatTensor] = None,
502
+ labels: Optional[torch.LongTensor] = None,
503
+ inputs_embeds: Optional[torch.FloatTensor] = None,
504
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
505
+ use_cache: Optional[bool] = None,
506
+ output_attentions: Optional[bool] = None,
507
+ output_hidden_states: Optional[bool] = None,
508
+ output_projector_features: Optional[bool] = None,
509
+ return_dict: Optional[bool] = None,
510
+ proprio=None,
511
+ proprio_projector=None,
512
+ noisy_actions=None,
513
+ noisy_action_projector=None,
514
+ diffusion_timestep_embeddings=None,
515
+ use_film: bool = False,
516
+ ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
517
+ """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
518
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
519
+ output_hidden_states = (
520
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
521
+ )
522
+ output_projector_features = output_projector_features if output_projector_features is not None else False
523
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
524
+
525
+ # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
526
+ use_cache = use_cache and not self.training
527
+
528
+ # Instantiate Placeholder for Projector Features
529
+ projected_patch_embeddings = None
530
+
531
+ # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
532
+ if input_ids.shape[1] == 1:
533
+ assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
534
+ assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
535
+ assert labels is None, "Unexpected key `labels` provided during cached generation!"
536
+
537
+ language_model_output = self.language_model(
538
+ input_ids=input_ids,
539
+ attention_mask=None,
540
+ position_ids=None,
541
+ past_key_values=past_key_values,
542
+ inputs_embeds=None,
543
+ labels=None,
544
+ use_cache=use_cache,
545
+ output_attentions=output_attentions,
546
+ output_hidden_states=output_hidden_states,
547
+ return_dict=return_dict,
548
+ )
549
+
550
+ # === Handle Unimodal Forward ===
551
+ elif pixel_values is None:
552
+ assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
553
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
554
+
555
+ language_model_output = self.language_model(
556
+ input_ids=input_ids,
557
+ attention_mask=attention_mask,
558
+ position_ids=None,
559
+ past_key_values=None,
560
+ inputs_embeds=None,
561
+ labels=labels,
562
+ use_cache=use_cache,
563
+ output_attentions=output_attentions,
564
+ output_hidden_states=output_hidden_states,
565
+ return_dict=return_dict,
566
+ )
567
+
568
+ # === Handle Multimodal Forward ===
569
+ elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
570
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"
571
+
572
+ # Get input embeddings (from language model embeddings)
573
+ input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)
574
+
575
+ # Extract action masks
576
+ all_actions_mask = self._process_action_masks(labels)
577
+
578
+ # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
579
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
580
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
581
+ ) # (B, lang_seq_len, llm_dim)
582
+
583
+ # Get visual features
584
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
585
+
586
+ # Add proprioceptive state if provided
587
+ projected_patch_embeddings = self._process_proprio_features(
588
+ projected_patch_embeddings, proprio, proprio_projector
589
+ )
590
+
591
+ # [Diffusion] Add diffusion timestep embedding if provided
592
+ if diffusion_timestep_embeddings is not None:
593
+ # For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens
594
+ projected_patch_embeddings = torch.cat(
595
+ (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
596
+ )
597
+
598
+ # Process action embeddings
599
+ if noisy_actions is not None:
600
+ # Get mask corresponding to all action tokens
601
+ all_actions_mask = self._process_action_masks(labels)
602
+
603
+ # Reshape noisy actions into individual action tokens
604
+ # noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1)
605
+ B = noisy_actions.shape[0]
606
+ noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1)
607
+
608
+ # Project noisy action tokens into language model embedding space
609
+ noisy_action_features = noisy_action_projector(noisy_actions) # (B, chunk_len * action_dim, llm_dim)
610
+
611
+ # Replace embeddings of the action tokens with noisy action embeddings
612
+ input_embeddings = self._replace_input_embeddings(
613
+ input_embeddings, all_actions_mask, noisy_action_features
614
+ )
615
+ else:
616
+ # Replace the embeddings of the action tokens with zeros
617
+ # (Later on, the positional embeddings will be added to them)
618
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
619
+ input_embeddings = input_embeddings * ~all_actions_mask
620
+
621
+ # Build multimodal embeddings & attention mask
622
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
623
+ input_embeddings, projected_patch_embeddings, attention_mask
624
+ )
625
+
626
+ # Build labels for multimodal sequence if needed
627
+ multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
628
+
629
+ # Dispatch to language model
630
+ language_model_output = self.language_model(
631
+ input_ids=None,
632
+ attention_mask=multimodal_attention_mask,
633
+ position_ids=None,
634
+ past_key_values=None,
635
+ inputs_embeds=multimodal_embeddings,
636
+ labels=multimodal_labels,
637
+ use_cache=use_cache,
638
+ output_attentions=output_attentions,
639
+ output_hidden_states=output_hidden_states,
640
+ return_dict=return_dict,
641
+ )
642
+
643
+ # === Otherwise =>> Assume Invalid! ===
644
+ elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
645
+ raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
646
+
647
+ else:
648
+ raise ValueError(
649
+ "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
650
+ f"=> `input_ids` = {input_ids is not None}\n"
651
+ f"=> `attention_mask` = {attention_mask is not None}\n"
652
+ f"=> `pixel_values` = {pixel_values is not None}\n"
653
+ f"=> `labels` = {labels is not None}\n"
654
+ f"=> `input_embeds` = {inputs_embeds is not None}\n"
655
+ f"=> `past_key_values` = {past_key_values is not None}\n"
656
+ f"=> `use_cache` = {use_cache}"
657
+ )
658
+
659
+ # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
660
+ if not return_dict:
661
+ if output_projector_features and (projected_patch_embeddings is not None):
662
+ return *language_model_output, projected_patch_embeddings
663
+
664
+ return language_model_output
665
+
666
+ return PrismaticCausalLMOutputWithPast(
667
+ loss=language_model_output.loss,
668
+ logits=language_model_output.logits,
669
+ past_key_values=language_model_output.past_key_values,
670
+ hidden_states=language_model_output.hidden_states,
671
+ attentions=language_model_output.attentions,
672
+ projector_features=projected_patch_embeddings,
673
+ )
674
+
675
+ # === GenerationMixin Methods ===
676
+ def prepare_inputs_for_generation(
677
+ self,
678
+ input_ids: Optional[torch.Tensor] = None,
679
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
680
+ inputs_embeds: Optional[torch.FloatTensor] = None,
681
+ pixel_values: Optional[torch.FloatTensor] = None,
682
+ attention_mask: Optional[torch.Tensor] = None,
683
+ **kwargs: str,
684
+ ) -> Dict[str, torch.Tensor]:
685
+ """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
686
+ if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
687
+ (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
688
+ ):
689
+ raise ValueError("Generation with batch size > 1 is not currently supported!")
690
+
691
+ # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
692
+ if past_key_values is not None:
693
+ input_ids = input_ids[:, -1:]
694
+
695
+ # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
696
+ if inputs_embeds is not None and past_key_values is None:
697
+ model_inputs = {"inputs_embeds": inputs_embeds}
698
+ else:
699
+ model_inputs = {"input_ids": input_ids}
700
+
701
+ # Make sure `pixel_values` are preserved in `model_inputs`
702
+ model_inputs.update(
703
+ {
704
+ "attention_mask": attention_mask,
705
+ "pixel_values": pixel_values,
706
+ "past_key_values": past_key_values,
707
+ "use_cache": kwargs.get("use_cache"),
708
+ }
709
+ )
710
+
711
+ return model_inputs
712
+
713
+ # Defer to Language Model (all handle this differently, with different return types)
714
+ def _reorder_cache(self, *args, **kwargs) -> Any:
715
+ return self.language_model._reorder_cache(*args, **kwargs)
716
+
717
+
718
+ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
719
+ config_class: PretrainedConfig = OpenVLAConfig
720
+
721
+ def __init__(self, config: OpenVLAConfig) -> None:
722
+ super().__init__(config)
723
+ self.norm_stats = config.norm_stats
724
+
725
+ # Compute action bins
726
+ self.bins = np.linspace(-1, 1, config.n_action_bins)
727
+ self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
728
+
729
+ # Compute vocab size for de-tokenization -- revert added "multiple of"
730
+ self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
+
+ # Default decoding mode; overridden by `lisa_predict_action` / `mul_predict_action` before generation
+ self.preset = False
731
+
732
+ def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
733
+ # Used during evaluation
734
+ """Prepares input for action prediction by adding necessary tokens"""
735
+ # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens
736
+ placeholder_action_token_ids = (
737
+ torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
738
+ )
739
+ input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)  # e.g., torch.Size([1, 35 + 56 = 91])
740
+
741
+ # Extend the attention mask to fit the new shape of input
742
+ # Note: Only batch size == 1 supported right now
743
+ mask_extension = (
744
+ torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
745
+ .to(attention_mask.device)
746
+ .to(attention_mask.dtype)
747
+ )
748
+ attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
749
+
750
+ return input_ids, attention_mask
751
+
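+ # Illustrative note (sketch; ACTION_DIM = 7 and NUM_ACTIONS_CHUNK = 8 are assumptions): for a prompt of
+ # length 35, 7 * 8 = 56 placeholder action tokens are appended, giving input_ids of shape (1, 91), and the
+ # attention mask is extended with ones to match.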
752
+ def _prepare_labels_for_action_prediction(self, labels, input_ids):
753
+ """Creates labels tensor for action prediction if not provided"""
754
+ # Used during evaluation
755
+ # Extends label tensors with fake action labels
756
+ # Adds stop tokens at the end of sequences
757
+ # Handles label preparation for action prediction tasks
758
+ # Why can an arbitrary value be used here? Per discussion (xuan): pick one value and use it consistently -- could the PAD token also work?
759
+ # TODO: Does this need to change? Probably not -- the labels here only exist to produce the action mask, not to compute a loss. This function ensures the labels match the extended input length and handle sequence termination correctly.
760
+ # Extend labels tensor with fake action labels
761
+ ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_IDX  # Filled with ACTION_TOKEN_IDX so that the action mask (labels == ACTION_TOKEN_IDX) is generated correctly
762
+ labels_extension = (
763
+ torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
764
+ * ARBITRARY_ACTION_TOKEN_IDX
765
+ )  # e.g., torch.Size([1, 57]), filled entirely with ARBITRARY_ACTION_TOKEN_IDX
766
+ labels = torch.cat([labels, labels_extension], dim=-1)
767
+
768
+ return labels
769
+
770
+ def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
771
+ """Unnormalize actions using dataset statistics"""
772
+ action_norm_stats = self.get_action_stats(unnorm_key)
773
+
774
+ if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
775
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool))
776
+ action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"])
777
+ elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
778
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
779
+ action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
780
+ else:
781
+ raise ValueError("Unsupported action/proprio normalization type detected!")
782
+
783
+ actions = np.where(
784
+ mask,
785
+ 0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
786
+ normalized_actions,
787
+ )
788
+
789
+ return actions
790
+
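+ # Illustrative numeric example (hedged): with BOUNDS_Q99 stats q01 = -0.8 and q99 = 0.8, a normalized value
+ # of 0.5 maps to 0.5 * (0.5 + 1) * (0.8 - (-0.8) + 1e-8) + (-0.8) = 0.4 (up to the 1e-8 epsilon); dimensions
+ # where `mask` is False are passed through unchanged.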
791
+ def _run_diffusion_prediction(
792
+ self,
793
+ input_embeddings,
794
+ all_actions_mask,
795
+ noise,
796
+ action_head,
797
+ projected_patch_embeddings,
798
+ labels,
799
+ attention_mask,
800
+ NUM_PATCHES,
801
+ NUM_PROMPT_TOKENS,
802
+ noisy_action_projector,
803
+ ):
804
+ """Run diffusion-based action prediction"""
805
+ # Set diffusion timestep values
806
+ action_head.noise_scheduler.set_timesteps(action_head.num_diffusion_steps)
807
+ # Clone embedding for reuse in each timestep
808
+ orig_projected_patch_embeddings = projected_patch_embeddings.clone()
809
+ curr_noisy_actions = noise
810
+
811
+ # Reverse diffusion: Iteratively denoise to generate action prediction
812
+ for t in action_head.noise_scheduler.timesteps:
813
+ # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
814
+ # embedding, and diffusion timestep embedding)
815
+ timesteps = torch.Tensor([t]).to(labels.device)
816
+ diffusion_timestep_embeddings = (
817
+ action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
818
+ ) # (B, llm_dim)
819
+ diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
820
+
821
+ # [Diffusion] Replace the embeddings of the action tokens with noisy actions
822
+ # (Later on, the positional embeddings will be added to them)
823
+
824
+ # For simplicity, append diffusion timestep embedding to the end of projected vision tokens
825
+ projected_patch_embeddings = torch.cat(
826
+ (orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
827
+ )
828
+
829
+ # Reshape and project noisy actions into language embedding space
830
+ B = curr_noisy_actions.shape[0]
831
+ orig_curr_noisy_actions_shape = curr_noisy_actions.shape
832
+ curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
833
+ noisy_action_features = noisy_action_projector(curr_noisy_actions)
834
+ curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
835
+
836
+ # Replace action token embeddings with noisy action embeddings
837
+ input_embeddings = self._replace_input_embeddings(
838
+ input_embeddings.clone(), all_actions_mask, noisy_action_features
839
+ )
840
+
841
+ # Build multimodal embeddings and attention mask
842
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
843
+ input_embeddings, projected_patch_embeddings, attention_mask
844
+ )
845
+
846
+ # Forward pass through language model
847
+ language_model_output = self.language_model(
848
+ input_ids=None,
849
+ attention_mask=multimodal_attention_mask,
850
+ position_ids=None,
851
+ past_key_values=None,
852
+ inputs_embeds=multimodal_embeddings,
853
+ labels=None,
854
+ use_cache=None,
855
+ output_attentions=False,
856
+ output_hidden_states=True,
857
+ return_dict=True,
858
+ )
859
+
860
+ # Extract hidden states for action portion of response
861
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
862
+ actions_hidden_states = last_hidden_states[
863
+ :,
864
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
865
+ :,
866
+ ] # (B, act_chunk_len, D)
867
+
868
+ # Predict noise and update noisy actions: x_t -> x_{t-1}
869
+ noise_pred = action_head.predict_noise(actions_hidden_states)
870
+ curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
871
+
872
+ curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
873
+
874
+ # Return final actions
875
+ return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
876
+
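+ # Illustrative note (sketch): the loop above is a standard reverse-diffusion sampler -- starting from
+ # Gaussian noise, each step conditions the VLM on the current noisy action embedding plus a timestep
+ # embedding, predicts the noise via `action_head.predict_noise`, and calls the scheduler's `step()` to
+ # obtain x_{t-1}, finishing with the denoised action chunk.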
877
+ def _regression_or_discrete_prediction(
878
+ self,
879
+ input_embeddings: torch.FloatTensor,  # embeddings of the language instruction
880
+ all_actions_mask: Optional[torch.BoolTensor],  # only used to extract the non-action (prefix) embeddings, i.e. to drop the action portion
881
+ projected_patch_embeddings: torch.FloatTensor,
882
+ attention_mask: torch.BoolTensor,
883
+ labels: torch.LongTensor,
884
+ NUM_PATCHES: int,
885
+ NUM_PROMPT_TOKENS: int,
886
+ action_head: L1RegressionActionHead,
887
+ ):
888
+ """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
889
+ # Extract hidden states for action tokens
890
+ # last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
891
+
892
+ # from transformers import AutoProcessor
893
+ # processor = AutoProcessor.from_pretrained("/data/juyi/openvla-7b+fractal20220817_data+b32+lr-5e-05+lora-r32+dropout-0.0--image_aug--test")
894
+ # tokenizer=processor.tokenizer
895
+ # tokenizer.decode(language_model_output.sequences[0])
896
+
897
+ # actions_hidden_states = last_hidden_states[:, NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_ACTIONS_CHUNK * tokennum, :]# (B, act_chunk_len, D)
898
+ # No manual slicing is needed anymore: generate() returns the hidden state for each generated token directly, which is much more convenient.
899
+ # Why is the first entry torch.Size([1, 535, 4096])? Which one should be used? See https://discuss.huggingface.co/t/get-each-generated-token-last-layer-hidden-state/145921
900
+ # language_model_output.sequences tensor([[29871, 32001, 32001, 32001, 32001, 32001, 32001, 32001, 32001, 2]], device='cuda:0')
901
+
902
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
903
+ input_embeddings, projected_patch_embeddings, attention_mask
904
+ )
905
+ # Example of multimodal_embeddings: '<s> <512 image tokens> <proprio token> In: What action should the robot take to open the middle drawer of the cabinet?\nOut:'
906
+ if self.preset:
907
+ # start_prefill = torch.cuda.Event(enable_timing=True)
908
+ # end_prefill = torch.cuda.Event(enable_timing=True)
909
+ # start_prefill.record()
910
+ language_model_output = self.language_model.generate(inputs_embeds=multimodal_embeddings,max_new_tokens=1,output_hidden_states=True,return_dict_in_generate=True)
911
+ # is tuple (1 token, 33 layers, torch.Size([1, 314, 4096]))
912
+ hidden_states = language_model_output.hidden_states[0][-1]
913
+ actions_hidden_states = hidden_states[:, -NUM_ACTIONS_CHUNK:]
914
+ # end_prefill.record()
915
+ # torch.cuda.synchronize()
916
+ # prefill_time = start_prefill.elapsed_time(end_prefill) / 1000
917
+ # print(f"Prefill time: {prefill_time:.4f} seconds")
918
+ else:
919
+ # start_generate = torch.cuda.Event(enable_timing=True)
920
+ # end_generate = torch.cuda.Event(enable_timing=True)
921
+ # start_generate.record()
922
+ language_model_output = self.language_model.generate(inputs_embeds=multimodal_embeddings,max_new_tokens=2048,output_hidden_states=True,return_dict_in_generate=True,use_cache=True)
923
+ # end_generate.record()
924
+ # torch.cuda.synchronize()
925
+ # generate_time = start_generate.elapsed_time(end_generate) / 1000
926
+ # print(f"prefill + Generate time: {generate_time:.4f} seconds")
927
+ actions_hidden_states = torch.stack([language_model_output.hidden_states[i][-1] for i in range(1, NUM_ACTIONS_CHUNK + 1)], dim=0)  # (action_chunk, batch_size, sequence_length, hidden_dim)
928
+ actions_hidden_states = actions_hidden_states.transpose(0, 1).squeeze(2) #torch.Size([batch size, action_chunk, hidden_dim])
929
+
930
+ normalized_actions = action_head.predict_action(actions_hidden_states)
931
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
932
+ normalized_actions = normalized_actions.float().cpu().detach().numpy()
933
+
934
+ return normalized_actions, actions_hidden_states
935
+
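+ # Illustrative note on the two decoding modes above (sketch): with `self.preset = True`, the action
+ # placeholder tokens are already part of the prompt, so a single prefill step (max_new_tokens=1) yields
+ # hidden states for every action position at once; otherwise the model autoregressively generates
+ # NUM_ACTIONS_CHUNK action tokens and the last-layer hidden state of each generated token is stacked.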
936
+
937
+
938
+ def mul_regression_or_discrete_prediction(
939
+ self,
940
+ input_embeddings: torch.FloatTensor,  # embeddings of the language instruction
941
+ all_actions_mask: Optional[torch.BoolTensor],  # only used to extract the non-action (prefix) embeddings, i.e. to drop the action portion
942
+ projected_patch_embeddings: torch.FloatTensor,
943
+ attention_mask: torch.BoolTensor,
944
+ labels: torch.LongTensor,
945
+ NUM_PATCHES: int,
946
+ NUM_PROMPT_TOKENS: int,
947
+ action_head: L1RegressionActionHead,
948
+ **kwargs,
949
+ ):
950
+ cfg = kwargs.get("cfg", None)
951
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
952
+ input_embeddings, projected_patch_embeddings, attention_mask
953
+ )
954
+ # Example of multimodal_embeddings: '<s> <512 image tokens> <proprio token> In: What action should the robot take to open the middle drawer of the cabinet?\nOut:'
955
+ # The first element of language_model_output.hidden_states is a tuple over layers: (1 generated token, 33 layers, each of shape torch.Size([1, 314, 4096]))
956
+ if self.preset:
957
+ language_model_output = self.language_model.generate(inputs_embeds=multimodal_embeddings,max_new_tokens=1,output_hidden_states=True,return_dict_in_generate=True)
958
+ assert language_model_output.sequences == torch.tensor([[32001]], device=multimodal_embeddings.device)
959
+ actions_hidden_states = language_model_output.hidden_states[0][-1]
960
+ actions_hidden_states = actions_hidden_states[:, -1]
961
+ else:
962
+ language_model_output = self.language_model.generate(inputs_embeds=multimodal_embeddings,max_new_tokens=2,output_hidden_states=True,return_dict_in_generate=True)
963
+ actions_hidden_states = language_model_output.hidden_states[1][-1]
964
+ actions_hidden_states = actions_hidden_states[:, -1]
965
+
966
+ normalized_actions = action_head.predict_action(actions_hidden_states)
967
+ normalized_actions = normalized_actions.reshape(cfg.num_actions_chunk, ACTION_DIM)
968
+ normalized_actions = normalized_actions.float().cpu().detach().numpy()
969
+
970
+ return normalized_actions, actions_hidden_states
971
+
972
+ def predict_action(
973
+ self,
974
+ input_ids: Optional[torch.LongTensor] = None,
975
+ unnorm_key: Optional[str] = None,
976
+ proprio=None,
977
+ proprio_projector=None,
978
+ action_head=None,
979
+ noisy_action_projector=None,
980
+ use_film: bool = False,
981
+ **kwargs: str,
982
+ ) -> np.ndarray:
983
+ """Predict actions from input sequence, with options for different prediction methods.
984
+
985
+ Args:
986
+ input_ids: Input token ids
987
+ unnorm_key: Key for unnormalization statistics
988
+ proprio: Proprioceptive features
989
+ proprio_projector: Projector for proprioceptive features
990
+ action_head: Optional head for L1 regression or diffusion-based prediction
991
+ noisy_action_projector: Projector for noisy actions in diffusion-based prediction
992
+ use_film: Whether to use FiLM conditioning
993
+ **kwargs: Additional arguments including pixel_values and attention_mask
994
+
995
+ Returns:
996
+ Tuple of (unnormalized_actions, action_hidden_states)
997
+ """
998
+ # If the special empty token ('') does not already appear after the colon (':') token in the prompt
999
+ # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
1000
+ if not torch.all(input_ids[:, -1] == 29871):
1001
+ input_ids = torch.cat(
1002
+ (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
1003
+ )
1004
+
1005
+ pixel_values = kwargs["pixel_values"]
1006
+ attention_mask = kwargs["attention_mask"]
1007
+
1008
+ # Create fake labels tensor (needed for action mask)
1009
+ labels = input_ids.clone()
1010
+ labels[:] = IGNORE_INDEX  # Ignore all input positions; IGNORE_INDEX = -100
1011
+
1012
+ # Get number of tokens in prompt (excluding the start token)
1013
+ NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token
1014
+
1015
+ # Prepare inputs by adding necessary tokens
1016
+ input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
1017
+
1018
+ # Update labels tensor for action mask computation later
1019
+ labels = self._prepare_labels_for_action_prediction(labels, input_ids)
1020
+
1021
+ # Get input embeddings and action masks
1022
+ input_embeddings = self.get_input_embeddings()(input_ids)
1023
+ all_actions_mask = self._process_action_masks(labels)
1024
+
1025
+ # Extract language embeddings
1026
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
1027
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
1028
+ )
1029
+
1030
+ # Process vision features
1031
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
1032
+
1033
+ # Add proprioceptive features if provided
1034
+ use_proprio = proprio_projector is not None and proprio is not None
1035
+ if use_proprio:
1036
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1037
+ projected_patch_embeddings = self._process_proprio_features(
1038
+ projected_patch_embeddings, proprio, proprio_projector
1039
+ )
1040
+
1041
+ # Use diffusion if provided, otherwise use regression or discrete prediction
1042
+ use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
1043
+
1044
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
1045
+ NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
1046
+ if use_proprio:
1047
+ NUM_PATCHES += 1
1048
+
1049
+ normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
1050
+ input_embeddings,
1051
+ all_actions_mask,
1052
+ projected_patch_embeddings,
1053
+ attention_mask,
1054
+ labels,
1055
+ NUM_PATCHES,
1056
+ NUM_PROMPT_TOKENS,
1057
+ action_head,
1058
+ )
1059
+
1060
+ # Unnormalize predicted actions
1061
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1062
+
1063
+ return actions, actions_hidden_states
1064
+
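+ # Illustrative usage (hedged sketch; the processor variable, prompt format, and device handling below are
+ # assumptions, not part of this file):
+ #   inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)
+ #   actions, _ = vla.predict_action(**inputs, unnorm_key="<your_dataset_key>", action_head=action_head)
+ #   # `actions` is a (NUM_ACTIONS_CHUNK, ACTION_DIM) numpy array of unnormalized actions.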
1065
+ @staticmethod
1066
+ def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
1067
+ """Validate and resolve the unnormalization key for action statistics"""
1068
+ if unnorm_key is None:
1069
+ assert len(norm_stats) == 1, (
1070
+ f"Your model was trained on more than one dataset, "
1071
+ f"please pass a `unnorm_key` from the following options to choose the statistics "
1072
+ f"used for un-normalizing actions: {norm_stats.keys()}"
1073
+ )
1074
+ unnorm_key = next(iter(norm_stats.keys()))
1075
+ # Note: why were the LIBERO norm stats not loaded here?
1076
+ assert unnorm_key in norm_stats, (
1077
+ f"The `unnorm_key` you chose is not in the set of available dataset statistics, "
1078
+ f"please choose from: {norm_stats.keys()}"
1079
+ )
1080
+ return unnorm_key
1081
+
1082
+ def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
1083
+ """Get the dimensionality of the policy's action space."""
1084
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1085
+ return len(self.norm_stats[unnorm_key]["action"]["min"])
1086
+
1087
+ def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
1088
+ """Get all the logged statistics for the given dataset."""
1089
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1090
+ return self.norm_stats[unnorm_key]["action"]
1091
+
1092
+
1093
+ def lisa_forward(
1094
+ self,
1095
+ input_ids: Optional[torch.LongTensor] = None,
1096
+ attention_mask: Optional[torch.Tensor] = None,
1097
+ pixel_values: Optional[torch.FloatTensor] = None,
1098
+ labels: Optional[torch.LongTensor] = None,
1099
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1100
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1101
+ use_cache: Optional[bool] = None,
1102
+ output_attentions: Optional[bool] = None,
1103
+ output_hidden_states: Optional[bool] = None,
1104
+ output_projector_features: Optional[bool] = None,
1105
+ return_dict: Optional[bool] = None,
1106
+ proprio=None,
1107
+ proprio_projector=None,
1108
+ **kwargs
1109
+ ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
1110
+ """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
1111
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1112
+ output_hidden_states = (
1113
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1114
+ )
1115
+ output_projector_features = output_projector_features if output_projector_features is not None else False
1116
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1117
+
1118
+ # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
1119
+ use_cache = use_cache and not self.training
1120
+
1121
+ # Instantiate Placeholder for Projector Features
1122
+ projected_patch_embeddings = None
1123
+
1124
+ # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
1125
+ if input_ids.shape[1] == 1:
1126
+ assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
1127
+ assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
1128
+ assert labels is None, "Unexpected key `labels` provided during cached generation!"
1129
+
1130
+ language_model_output = self.language_model(
1131
+ input_ids=input_ids,
1132
+ attention_mask=None,
1133
+ position_ids=None,
1134
+ past_key_values=past_key_values,
1135
+ inputs_embeds=None,
1136
+ labels=None,
1137
+ use_cache=use_cache,
1138
+ output_attentions=output_attentions,
1139
+ output_hidden_states=output_hidden_states,
1140
+ return_dict=return_dict,
1141
+ )
1142
+
1143
+ # === Handle Unimodal Forward ===
1144
+ elif pixel_values is None:
1145
+ raise NotImplementedError
1146
+
1147
+ # === Handle Multimodal Forward ===
1148
+ elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
1149
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"
1150
+
1151
+ # Get input embeddings (from language model embeddings)
1152
+ input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)
1153
+ # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
1154
+ # language_embeddings = input_embeddings[~all_actions_mask].reshape(
1155
+ # input_embeddings.shape[0], -1, input_embeddings.shape[2]
1156
+ # ) # (B, lang_seq_len, llm_dim) -- this would also include the trailing stop token and padding. Is that a problem? No, because those positions are ignored; removed here since FiLM is not used.
1157
+ # Get visual features
1158
+ projected_patch_embeddings = self._process_vision_features(pixel_values)
1159
+
1160
+ # Add proprioceptive state if provided
1161
+ projected_patch_embeddings = self._process_proprio_features(
1162
+ projected_patch_embeddings, proprio, proprio_projector
1163
+ )
1164
+
1165
+ all_actions_mask = (labels == ACTION_TOKEN_IDX)  # Unlike the standard forward pass, which computes token offsets manually to find the action positions, here the mask comes directly from the action token id.
1166
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
1167
+ input_embeddings = input_embeddings * ~all_actions_mask
1168
+
1169
+ # Build multimodal embeddings & attention mask
1170
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1171
+ input_embeddings, projected_patch_embeddings, attention_mask
1172
+ )
1173
+
1174
+ # Build labels for multimodal sequence if needed
1175
+ multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
1176
+
1177
+ # Dispatch to language model
1178
+ language_model_output = self.language_model(
1179
+ input_ids=None,
1180
+ attention_mask=multimodal_attention_mask,
1181
+ position_ids=None,
1182
+ past_key_values=None,
1183
+ inputs_embeds=multimodal_embeddings,
1184
+ labels=multimodal_labels,
1185
+ use_cache=use_cache,
1186
+ output_attentions=output_attentions,
1187
+ output_hidden_states=output_hidden_states,
1188
+ return_dict=return_dict,
1189
+ )
1190
+
1191
+
1192
+ # === Otherwise =>> Assume Invalid! ===
1193
+ elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
1194
+ raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
1195
+
1196
+ else:
1197
+ raise ValueError(
1198
+ "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
1199
+ f"=> `input_ids` = {input_ids is not None}\n"
1200
+ f"=> `attention_mask` = {attention_mask is not None}\n"
1201
+ f"=> `pixel_values` = {pixel_values is not None}\n"
1202
+ f"=> `labels` = {labels is not None}\n"
1203
+ f"=> `input_embeds` = {inputs_embeds is not None}\n"
1204
+ f"=> `past_key_values` = {past_key_values is not None}\n"
1205
+ f"=> `use_cache` = {use_cache}"
1206
+ )
1207
+
1208
+ # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
1209
+ if not return_dict:
1210
+ if output_projector_features and (projected_patch_embeddings is not None):
1211
+ return *language_model_output, projected_patch_embeddings
1212
+
1213
+ return language_model_output
1214
+
1215
+ return PrismaticCausalLMOutputWithPast(
1216
+ loss=language_model_output.loss,
1217
+ logits=language_model_output.logits,
1218
+ past_key_values=language_model_output.past_key_values,
1219
+ hidden_states=language_model_output.hidden_states,
1220
+ attentions=language_model_output.attentions,
1221
+ projector_features=projected_patch_embeddings,
1222
+ )
1223
+
1224
+ def lisa_predict_action(
1225
+ self,
1226
+ input_ids: Optional[torch.LongTensor] = None,  # i.e., the tokenized language instruction
1227
+ unnorm_key: Optional[str] = None,
1228
+ proprio=None,
1229
+ proprio_projector=None,
1230
+ action_head:L1RegressionActionHead=None,
1231
+ noisy_action_projector=None,
1232
+ use_film: bool = False,
1233
+ **kwargs: str,
1234
+ ) -> np.ndarray:
1235
+
1236
+ pixel_values = kwargs["pixel_values"]
1237
+ attention_mask = kwargs["attention_mask"]
1238
+
1239
+ # Get number of tokens in prompt (excluding the start token)
1240
+ # NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token
1241
+
1242
+ # Example input_ids: '<s> In: What action should the robot take to open the middle drawer of the cabinet?\nOut:'
1243
+ # What are labels for at prediction time? Only to build the action mask -- they are not needed for autoregressive decoding.
1244
+ cfg = kwargs.get("cfg", None) # Extract cfg from kwargs
1245
+ if cfg.preset:
1246
+ special_tensor = torch.tensor([[29871]], device=input_ids.device, dtype=input_ids.dtype)
1247
+ output_tensor = torch.tensor([[32001] * NUM_ACTIONS_CHUNK], device=input_ids.device, dtype=input_ids.dtype)
1248
+ input_ids = torch.cat([input_ids, special_tensor, output_tensor], dim=1) # preset action tokens, only forward once.
1249
+ self.preset = True
1250
+ else:
1251
+ self.preset = False
1252
+ input_embeddings = self.get_input_embeddings()(input_ids)
1253
+
1254
+ projected_patch_embeddings = self._process_vision_features(pixel_values)
1255
+
1256
+ use_proprio = proprio_projector is not None and proprio is not None
1257
+ if use_proprio:
1258
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1259
+ projected_patch_embeddings = self._process_proprio_features(
1260
+ projected_patch_embeddings, proprio, proprio_projector
1261
+ )
1262
+
1263
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
1264
+ # NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
1265
+ # if use_proprio:
1266
+ # NUM_PATCHES += 1
1267
+ normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
1268
+ input_embeddings,
1269
+ None,
1270
+ projected_patch_embeddings,
1271
+ attention_mask,
1272
+ None,  # labels are not needed at inference
1273
+ None,  # NUM_PATCHES is not needed at inference
1274
+ None,  # NUM_PROMPT_TOKENS is not needed at inference
1275
+ action_head,
1276
+ )
1277
+
1278
+ # Unnormalize predicted actions
1279
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1280
+
1281
+ return actions, actions_hidden_states
1282
+
1283
+
1284
+ def mul_predict_action(
1285
+ self,
1286
+ input_ids: Optional[torch.LongTensor] = None,  # i.e., the tokenized language instruction
1287
+ unnorm_key: Optional[str] = None,
1288
+ proprio=None,
1289
+ proprio_projector=None,
1290
+ action_head:L1RegressionActionHead=None,
1291
+ noisy_action_projector=None,
1292
+ use_film: bool = False,
1293
+ **kwargs: str,
1294
+ ) -> np.ndarray:
1295
+ cfg = kwargs.get("cfg", None) # Extract cfg from kwargs
1296
+ if cfg.preset:
1297
+ self.preset = True
1298
+ if not torch.all(input_ids[:, -1] == 29871):
1299
+ input_ids = torch.cat(
1300
+ (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
1301
+ )
1302
+ else:
1303
+ self.preset = False
1304
+
1305
+
1306
+ pixel_values = kwargs["pixel_values"]
1307
+ attention_mask = kwargs["attention_mask"]
1308
+
1309
+ # Example input_ids: '<s> In: What action should the robot take to open the middle drawer of the cabinet?\nOut:'
1310
+
1311
+ input_embeddings = self.get_input_embeddings()(input_ids)
1312
+
1313
+ projected_patch_embeddings = self._process_vision_features(pixel_values)
1314
+
1315
+ use_proprio = proprio_projector is not None and proprio is not None
1316
+ if use_proprio:
1317
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1318
+ projected_patch_embeddings = self._process_proprio_features(
1319
+ projected_patch_embeddings, proprio, proprio_projector
1320
+ )
1321
+
1322
+ normalized_actions, actions_hidden_states = self.mul_regression_or_discrete_prediction(
1323
+ input_embeddings,
1324
+ None,
1325
+ projected_patch_embeddings,
1326
+ attention_mask,
1327
+ None,  # labels are not needed at inference
1328
+ None,  # NUM_PATCHES is not needed at inference
1329
+ None,  # NUM_PROMPT_TOKENS is not needed at inference
1330
+ action_head,
1331
+ cfg=cfg,
1332
+ )
1333
+
1334
+ # Unnormalize predicted actions
1335
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1336
+
1337
+ return actions, actions_hidden_states
1338
+
preprocessor_config.json ADDED
@@ -0,0 +1,114 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "processing_prismatic.PrismaticImageProcessor",
4
+ "AutoProcessor": "processing_prismatic.PrismaticProcessor"
5
+ },
6
+ "image_processor_type": "PrismaticImageProcessor",
7
+ "image_resize_strategy": "resize-naive",
8
+ "input_sizes": [
9
+ [
10
+ 3,
11
+ 224,
12
+ 224
13
+ ],
14
+ [
15
+ 3,
16
+ 224,
17
+ 224
18
+ ]
19
+ ],
20
+ "interpolations": [
21
+ "bicubic",
22
+ "bicubic"
23
+ ],
24
+ "means": [
25
+ [
26
+ 0.485,
27
+ 0.456,
28
+ 0.406
29
+ ],
30
+ [
31
+ 0.5,
32
+ 0.5,
33
+ 0.5
34
+ ]
35
+ ],
36
+ "processor_class": "PrismaticProcessor",
37
+ "stds": [
38
+ [
39
+ 0.229,
40
+ 0.224,
41
+ 0.225
42
+ ],
43
+ [
44
+ 0.5,
45
+ 0.5,
46
+ 0.5
47
+ ]
48
+ ],
49
+ "tvf_crop_params": [
50
+ {
51
+ "output_size": [
52
+ 224,
53
+ 224
54
+ ]
55
+ },
56
+ {
57
+ "output_size": [
58
+ 224,
59
+ 224
60
+ ]
61
+ }
62
+ ],
63
+ "tvf_do_letterbox": false,
64
+ "tvf_letterbox_fill": null,
65
+ "tvf_normalize_params": [
66
+ {
67
+ "inplace": false,
68
+ "mean": [
69
+ 0.484375,
70
+ 0.455078125,
71
+ 0.40625
72
+ ],
73
+ "std": [
74
+ 0.228515625,
75
+ 0.2236328125,
76
+ 0.224609375
77
+ ]
78
+ },
79
+ {
80
+ "inplace": false,
81
+ "mean": [
82
+ 0.5,
83
+ 0.5,
84
+ 0.5
85
+ ],
86
+ "std": [
87
+ 0.5,
88
+ 0.5,
89
+ 0.5
90
+ ]
91
+ }
92
+ ],
93
+ "tvf_resize_params": [
94
+ {
95
+ "antialias": true,
96
+ "interpolation": 3,
97
+ "max_size": null,
98
+ "size": [
99
+ 224,
100
+ 224
101
+ ]
102
+ },
103
+ {
104
+ "antialias": true,
105
+ "interpolation": 3,
106
+ "max_size": null,
107
+ "size": [
108
+ 224,
109
+ 224
110
+ ]
111
+ }
112
+ ],
113
+ "use_fused_vision_backbone": true
114
+ }
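preprocessor_config.json declares a fused (dual) vision backbone: two 224×224 streams, bicubic "resize-naive" resizing, one stream normalized with ImageNet statistics and the other with 0.5 means/stds, so each image becomes a channel-stacked 6×224×224 tensor. A small sanity-check sketch (the repo path is a placeholder):

# Sketch: load this repo's image processor and confirm the fused-backbone output shape.
from PIL import Image
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained("path/to/this/repo", trust_remote_code=True)
pixel_values = image_processor(Image.new("RGB", (640, 480)), return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # expected: torch.Size([1, 6, 224, 224]) since use_fused_vision_backbone is true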
processing_prismatic.py ADDED
@@ -0,0 +1,257 @@
1
+ """
2
+ processing_prismatic.py
3
+
4
+ HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration
5
+ specifies `siglip-224px+7b`.
6
+ """
7
+
8
+ from typing import Any, ClassVar, List, Optional, Tuple, Union
9
+
10
+ import timm.data
11
+ import torch
12
+ import torchvision.transforms.functional as TVF
13
+ from PIL import Image
14
+ from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
15
+ from transformers import PreTrainedTokenizerBase
16
+ from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin
17
+ from transformers.processing_utils import ProcessorMixin
18
+ from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
19
+ from transformers.utils import TensorType
20
+
21
+
22
+ # === Image Processing ===
23
+ def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
24
+ """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
25
+ (w, h), max_wh = image.size, max(image.size)
26
+ horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
27
+ padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
28
+
29
+ return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant")
30
+
31
+
32
+ class PrismaticImageProcessor(ImageProcessingMixin):
33
+ model_input_names: ClassVar[List[str]] = ["pixel_values"]
34
+
35
+ def __init__(
36
+ self,
37
+ use_fused_vision_backbone: bool = False,
38
+ image_resize_strategy: str = "letterbox",
39
+ input_sizes: Optional[List[Tuple[int, int, int]]] = None,
40
+ interpolations: Optional[List[str]] = None,
41
+ means: Optional[List[Tuple[float, float, float]]] = None,
42
+ stds: Optional[List[Tuple[float, float, float]]] = None,
43
+ **kwargs: str,
44
+ ) -> None:
45
+ """
46
+ Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be
47
+ created by TIMM, and edited to follow our custom `image_resize_strategy` logic.
48
+
49
+ @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone
50
+ @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >
51
+ @param input_sizes: [TIMM :: `data_cfg`] Input image sizes as (channels, width, height) tuples, one per backbone
52
+ @param interpolations: [TIMM :: `data_cfg`] Interpolation modes as strings (default: "bicubic"), one per backbone
53
+ @param means: [TIMM :: `data_cfg`] Normalization means as float tuples, one per backbone
54
+ @param stds: [TIMM :: `data_cfg`] Normalization stds as float tuples, one per backbone
55
+ """
56
+ self.use_fused_vision_backbone = use_fused_vision_backbone
57
+ self.image_resize_strategy = image_resize_strategy
58
+
59
+ # Handle `None` default values
60
+ input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes
61
+ means = [(0.5, 0.5, 0.5)] if means is None else means
62
+ stds = [(0.5, 0.5, 0.5)] if stds is None else stds
63
+
64
+ # TIMM `data_cfg` Parameters
65
+ self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds
66
+
67
+ # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
68
+ self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
69
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
70
+
71
+ for idx in range(len(input_sizes)):
72
+ transform = timm.data.create_transform(
73
+ input_size=self.input_sizes[idx],
74
+ interpolation=self.interpolations[idx],
75
+ mean=self.means[idx],
76
+ std=self.stds[idx],
77
+ crop_pct=1.0, # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)
78
+ crop_mode="center", # Default crop mode -- no-op when `crop_pct == 1.0`
79
+ is_training=False, # No image augmentations when loading the transform!
80
+ )
81
+
82
+ # [Validation] Ensure appropriate transform structure, expected sizes
83
+ if not (
84
+ isinstance(transform, Compose)
85
+ and (len(transform.transforms) == 4)
86
+ and isinstance(transform.transforms[0], Resize)
87
+ and isinstance(transform.transforms[1], CenterCrop)
88
+ and isinstance(transform.transforms[2], ToTensor)
89
+ and isinstance(transform.transforms[3], Normalize)
90
+ and (transform.transforms[0].size == self.input_sizes[idx][-1])
91
+ and (transform.transforms[1].size == self.input_sizes[idx][-2:])
92
+ ):
93
+ raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")
94
+
95
+ # HF Image Processors *must* be JSON-serializable; as such, we cannot store the torchvision transform itself as an attribute.
96
+ # => Instead, we're going to parse the transform and call "torchvision.transforms.functional" (`tvf`)
97
+ resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
98
+ self.tvf_resize_params.append(
99
+ {
100
+ "size": resize_t.size,
101
+ "interpolation": TVF.pil_modes_mapping[resize_t.interpolation],
102
+ "max_size": None,
103
+ "antialias": True,
104
+ }
105
+ )
106
+ self.tvf_crop_params.append({"output_size": crop_t.size})
107
+ self.tvf_normalize_params.append(
108
+ {
109
+ "mean": norm_t.mean.float().numpy().tolist(),
110
+ "std": norm_t.std.float().numpy().tolist(),
111
+ "inplace": False,
112
+ }
113
+ )
114
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
115
+
116
+ # Handle Prismatic `image_resize_strategy`
117
+ if self.image_resize_strategy == "resize-naive":
118
+ self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
119
+ elif self.image_resize_strategy == "letterbox":
120
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])
121
+ elif self.image_resize_strategy == "resize-crop":
122
+ pass
123
+ else:
124
+ raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")
125
+
126
+ # Dispatch **kwargs to super()
127
+ super().__init__(**kwargs)
128
+
129
+ def apply_transform(self, img: Image.Image) -> torch.Tensor:
130
+ """Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])"""
131
+ if self.tvf_do_letterbox:
132
+ img = letterbox_pad_transform(img, self.tvf_letterbox_fill)
133
+
134
+ # [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side!
135
+ imgs_t = []
136
+ for idx in range(len(self.input_sizes)):
137
+ img_idx = TVF.resize(img, **self.tvf_resize_params[idx])
138
+ img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])
139
+ img_idx_t = TVF.to_tensor(img_idx)
140
+ img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])
141
+ imgs_t.append(img_idx_t)
142
+
143
+ # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0
144
+ img_t = torch.vstack(imgs_t)
145
+
146
+ return img_t
147
+
148
+ def preprocess(
149
+ self,
150
+ images: Union[Image.Image, List[Image.Image]],
151
+ return_tensors: Optional[Union[str, TensorType]] = None,
152
+ **_: str,
153
+ ) -> BatchFeature:
154
+ """
155
+ Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor` we
156
+ explicitly only handle PIL.Image.Image instances for simplicity.
157
+
158
+ @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
159
+ @param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray
160
+
161
+ @return: Instance of `transformers :: BatchFeature` with a single key "pixel_values"
162
+ """
163
+ if not isinstance(images, list):
164
+ images = [images]
165
+
166
+ # Apply `self.img_transform` to each image (will return list of torch.Tensors); stack into "batched" Tensor
167
+ pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])
168
+
169
+ # Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert
170
+ return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)
171
+
172
+ def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
173
+ return self.preprocess(images, **kwargs)
174
+
175
+
176
+ # === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===
177
+ # =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
178
+ class PrismaticProcessor(ProcessorMixin):
179
+ attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"]
180
+ image_processor_class: str = "AutoImageProcessor"
181
+ tokenizer_class: str = "AutoTokenizer"
182
+
183
+ def __init__(
184
+ self,
185
+ image_processor: Optional[ImageProcessingMixin] = None,
186
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
187
+ ) -> None:
188
+ super().__init__(image_processor, tokenizer)
189
+
190
+ def __call__(
191
+ self,
192
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
193
+ images: Union[Image.Image, List[Image.Image]],
194
+ padding: Union[bool, str, PaddingStrategy] = False,
195
+ truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
196
+ max_length: Optional[int] = None,
197
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
198
+ ) -> BatchFeature:
199
+ """
200
+ Preprocess a given (batch) of text/images for a Prismatic VLM; forwards text to the underlying LLM's tokenizer,
201
+ forwards images to PrismaticImageProcessor.
202
+
203
+ @param text: The (batch) of text to encode; must be a string or list of strings.
204
+ @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
205
+ @param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False >
206
+ @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified
207
+ @param max_length: Maximum length (in tokens) to truncate
208
+ @param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH)
209
+
210
+ @return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`.
211
+ """
212
+ pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
213
+ text_inputs = self.tokenizer(
214
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
215
+ )
216
+
217
+ # [Validate] Need same number of images and text inputs!
218
+ if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:
219
+ raise ValueError("Batch is malformed; expected same number of images and text inputs!")
220
+
221
+ return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
222
+
223
+ # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
224
+ def batch_decode(
225
+ self,
226
+ sequences: Union[List[int], List[List[int]], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
227
+ skip_special_tokens: bool = False,
228
+ clean_up_tokenization_spaces: Optional[bool] = None,
229
+ **kwargs: str,
230
+ ) -> List[str]:
231
+ return self.tokenizer.batch_decode(
232
+ sequences=sequences,
233
+ skip_special_tokens=skip_special_tokens,
234
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
235
+ **kwargs,
236
+ )
237
+
238
+ def decode(
239
+ self,
240
+ token_ids: Union[int, List[int], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
241
+ skip_special_tokens: bool = False,
242
+ clean_up_tokenization_spaces: Optional[bool] = None,
243
+ **kwargs: str,
244
+ ) -> str:
245
+ return self.tokenizer.decode(
246
+ token_ids=token_ids,
247
+ skip_special_tokens=skip_special_tokens,
248
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
249
+ **kwargs,
250
+ )
251
+
252
+ @property
253
+ def model_input_names(self) -> List[str]:
254
+ tokenizer_input_names = self.tokenizer.model_input_names
255
+ image_processor_input_names = self.image_processor.model_input_names
256
+
257
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
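processing_prismatic.py thus ships the full processor: the image processor replays TIMM's Resize → CenterCrop → ToTensor → Normalize pipeline through torchvision.transforms.functional (so it stays JSON-serializable), and PrismaticProcessor pairs it with the underlying Llama tokenizer. A hedged usage sketch (the repo path is a placeholder):

# Sketch: call PrismaticProcessor directly; it returns input_ids, attention_mask and pixel_values.
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("path/to/this/repo", trust_remote_code=True)
batch = processor(
    text="In: What action should the robot take to pick up the bowl?\nOut:",
    images=Image.open("frame.png").convert("RGB"),
    return_tensors="pt",
)
print(sorted(batch.keys()))         # ['attention_mask', 'input_ids', 'pixel_values']
print(batch["pixel_values"].shape)  # [1, 6, 224, 224] for the fused dual backbone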
processor_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_prismatic.PrismaticProcessor"
4
+ },
5
+ "processor_class": "PrismaticProcessor"
6
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,39 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ {
4
+ "content": "<ACT>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ }
10
+ ],
11
+ "bos_token": {
12
+ "content": "<s>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false
17
+ },
18
+ "eos_token": {
19
+ "content": "</s>",
20
+ "lstrip": false,
21
+ "normalized": false,
22
+ "rstrip": false,
23
+ "single_word": false
24
+ },
25
+ "pad_token": {
26
+ "content": "<PAD>",
27
+ "lstrip": false,
28
+ "normalized": false,
29
+ "rstrip": false,
30
+ "single_word": false
31
+ },
32
+ "unk_token": {
33
+ "content": "<unk>",
34
+ "lstrip": false,
35
+ "normalized": false,
36
+ "rstrip": false,
37
+ "single_word": false
38
+ }
39
+ }
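special_tokens_map.json registers <PAD> as the pad token and <ACT> as an additional special token (the action placeholder); their ids should match the added_tokens_decoder entries in tokenizer_config.json below. A quick, hedged check (the repo path is a placeholder):

# Sketch: verify the extra tokens resolve to the ids declared in tokenizer_config.json.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/this/repo")
print(tokenizer.convert_tokens_to_ids(["<PAD>", "<ACT>"]))       # expected: [32000, 32001]
print(tokenizer.pad_token, tokenizer.additional_special_tokens)  # '<PAD>' ['<ACT>']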
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,64 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "32000": {
30
+ "content": "<PAD>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "32001": {
38
+ "content": "<ACT>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ }
45
+ },
46
+ "additional_special_tokens": [
47
+ "<ACT>"
48
+ ],
49
+ "auto_map": {
50
+ "AutoProcessor": "processing_prismatic.PrismaticProcessor"
51
+ },
52
+ "bos_token": "<s>",
53
+ "clean_up_tokenization_spaces": false,
54
+ "eos_token": "</s>",
55
+ "legacy": false,
56
+ "model_max_length": 2048,
57
+ "pad_token": "<PAD>",
58
+ "padding_side": "right",
59
+ "processor_class": "PrismaticProcessor",
60
+ "sp_model_kwargs": {},
61
+ "tokenizer_class": "LlamaTokenizer",
62
+ "unk_token": "<unk>",
63
+ "use_default_system_prompt": false
64
+ }
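tokenizer_config.json pins the behaviour the model code relies on: a LlamaTokenizer with add_bos_token=true and add_eos_token=false, right-side <PAD> padding, and model_max_length 2048. Because no EOS is appended, predict_action (above) adds the trailing blank token 29871 itself when it is missing. A final hedged sketch (the repo path is a placeholder):

# Sketch: encoding behaviour implied by tokenizer_config.json.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/this/repo")
ids = tokenizer("In: What action should the robot take?\nOut:", return_tensors="pt").input_ids
print(ids[0, 0].item())   # 1 -> "<s>" is prepended (add_bos_token=true)
print(ids[0, -1].item())  # if this is not 29871, predict_action appends that blank token itself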