agopalkr commited on
Commit
5331ccc
·
verified ·
1 Parent(s): 65d2b0f

Upload PrismaticForConditionalGeneration

Browse files
Files changed (4) hide show
  1. config.json +2 -1
  2. generation_config.json +7 -0
  3. model.safetensors +3 -0
  4. modelling_pi.py +566 -0
config.json CHANGED
@@ -4,7 +4,8 @@
4
  "PrismaticForConditionalGeneration"
5
  ],
6
  "auto_map": {
7
- "AutoConfig": "configuration_prismatic.PrismaticConfig"
 
8
  },
9
  "hf_llm_id": "agopalkr/gemma-2b",
10
  "image_resize_strategy": "resize-naive",
 
4
  "PrismaticForConditionalGeneration"
5
  ],
6
  "auto_map": {
7
+ "AutoConfig": "configuration_prismatic.PrismaticConfig",
8
+ "AutoModelForVision2Seq": "modelling_pi.PrismaticForConditionalGeneration"
9
  },
10
  "hf_llm_id": "agopalkr/gemma-2b",
11
  "image_resize_strategy": "resize-naive",
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 2,
4
+ "eos_token_id": 1,
5
+ "pad_token_id": 257153,
6
+ "transformers_version": "4.44.2"
7
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4a537ee0679c94c6aaee04d70782af5c46a36314d322f1b97575c17c8d8d261
3
+ size 6931030928
modelling_pi.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ modeling_prismatic.py
3
+
4
+ Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions, inheriting
5
+ from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained, but exactly replicate the
6
+ logic in `prismatic.models.vlms.prismatic.py`.
7
+
8
+ Note =>> for the time being, not adding the custom HF "docstring" formatting.
9
+
10
+ References [LLaVa, IDEFICS-2]:
11
+ => https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py
12
+ => https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py
13
+ """
14
+
15
+ import logging
16
+ from dataclasses import dataclass
17
+ from functools import partial
18
+ from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import timm
22
+ import tokenizers
23
+ import torch
24
+ import torch.nn as nn
25
+ import transformers
26
+ from timm.models.vision_transformer import LayerScale
27
+ from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel,GemmaForCausalLM
28
+ from transformers.modeling_outputs import ModelOutput
29
+ from prismatic.util.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector
30
+
31
+ from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
32
+
33
+ # Get Logger
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ # === PyTorch/HuggingFace Default IGNORE_INDEX (for CrossEntropyLoss labels)
38
+ IGNORE_INDEX = -100
39
+
40
+
41
+ # === Utility Functions for Monkey-Patching ===
42
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """Wrap `fn` so a tuple result is collapsed to its first element; non-tuple results pass through unchanged."""

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        out = fn(*args, **kwargs)
        if isinstance(out, tuple):
            return out[0]
        return out

    return wrapper
48
+
49
+
50
+ # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
51
+ # =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
52
+ # =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
53
+ def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
54
+ return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
55
+
56
+
57
def ls_apply_patch(ls_module: LayerScale):
    """In-place patch of a TIMM `LayerScale`: copy `gamma` into a new `scale_factor` parameter, rebind
    `forward` to `_ls_new_forward`, and delete `gamma` so HF checkpoint loading never sees that name.
    (HF Transformers rewrites parameters whose names contain `gamma` — see module-level references above.)
    """
    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
    # Bind the replacement forward as an instance method; other LayerScale instances are unaffected.
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    # Must come last: `_ls_new_forward` no longer reads `gamma`, and the clone above already captured its value.
    del ls_module.gamma
61
+
62
+
63
+ # === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
64
class PrismaticVisionBackbone(nn.Module):
    """TIMM-based vision featurizer with optional fused ("alpha" + "beta") backbones whose features are
    concatenated along the feature dimension."""

    def __init__(
        self,
        use_fused_vision_backbone: bool,
        image_sizes: List[int],
        timm_model_ids: List[str],
        timm_override_act_layers: List[Optional[str]],
    ) -> None:
        """
        Args:
            use_fused_vision_backbone: If True, instantiate a second ("beta") featurizer as well.
            image_sizes: Input resolution per backbone, indexed in lockstep with `timm_model_ids`.
            timm_model_ids: TIMM model identifiers (at most two supported).
            timm_override_act_layers: Per-backbone activation-layer overrides passed to `timm.create_model`.
        """
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone

        # [Contract] Validate number of (fused) vision backbones, create "alpha" featurizer and Instantiate
        #   =>> Note :: Monkey-Patch the `forward()` function of the backbone to ensure FSDP-compatibility
        #               Hardcodes `get_intermediate_layers` to return the **SECOND-TO-LAST** layer patches!
        assert len(timm_model_ids) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!"
        self.featurizer = timm.create_model(
            timm_model_ids[0],
            pretrained=False,
            num_classes=0,
            img_size=image_sizes[0],
            act_layer=timm_override_act_layers[0],
        )
        # `get_intermediate_layers` returns a tuple; `unpack_tuple` collapses it to the single requested layer.
        self.featurizer.forward = unpack_tuple(
            partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})
        )
        self.embed_dim = self.featurizer.embed_dim

        # If `use_fused_vision_backbone` =>> create "beta" featurizer
        if self.use_fused_vision_backbone:
            self.fused_featurizer = timm.create_model(
                timm_model_ids[1],
                pretrained=False,
                num_classes=0,
                img_size=image_sizes[1],
                act_layer=timm_override_act_layers[1],
            )
            self.fused_featurizer.forward = unpack_tuple(
                partial(self.fused_featurizer.get_intermediate_layers, n={len(self.fused_featurizer.blocks) - 2})
            )
            # Features are concatenated on the feature dim in `forward`, so the combined dim is the sum.
            self.embed_dim += self.fused_featurizer.embed_dim

        # Patch `vision_backbone.featurizer` and `vision_backbone.fused_featurizer` with HF-Compatible LayerScale
        for module in self.featurizer.modules():
            if isinstance(module, LayerScale):
                ls_apply_patch(module)

        if self.use_fused_vision_backbone:
            for module in self.fused_featurizer.modules():
                if isinstance(module, LayerScale):
                    ls_apply_patch(module)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run image (`pixel_values`) through featurizer; if channel-stacked, then dispatch and sequence stack."""
        if not self.use_fused_vision_backbone:
            return self.featurizer(pixel_values)

        # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
        img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
        patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)

        return torch.cat([patches, patches_fused], dim=2)
125
+
126
+
127
+ # === Prismatic Projector (nn.Module) Definitions ===
128
class PrismaticProjector(nn.Module):
    """GELU-MLP projector mapping vision patch features (`vision_dim`) into the LLM embedding space (`llm_dim`).

    Submodule names (`fc1`, `fc2`, `fc3`, `act_fn1`, `act_fn2`) are part of the checkpoint contract and
    must not be renamed.
    """

    def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.vision_dim = vision_dim
        self.llm_dim = llm_dim

        # Switch on `use_fused_vision_backbone` =>> fused backbones get a deeper MLP with a 4x-widened
        # first projection; the single-backbone path is a plain two-layer MLP.
        if self.use_fused_vision_backbone:
            widened_dim = 4 * vision_dim
            self.fc1 = nn.Linear(self.vision_dim, widened_dim, bias=True)
            self.fc2 = nn.Linear(widened_dim, self.llm_dim, bias=True)
            self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()
            self.act_fn2 = nn.GELU()
        else:
            self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
            self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        """Project patch features; returns a tensor with trailing dimension `llm_dim`."""
        hidden = self.act_fn1(self.fc1(img_patches))
        if self.use_fused_vision_backbone:
            return self.fc3(self.act_fn2(self.fc2(hidden)))
        return self.fc2(hidden)
160
+
161
+
162
+ # === Main HF Class Definitions ===
163
@dataclass
class PrismaticCausalLMOutputWithPast(ModelOutput):
    """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features."""

    # Standard CausalLMOutputWithPast fields (mirrors `transformers` conventions)
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Additions for VLMs
    # Projected patch embeddings (populated on multimodal forwards when `output_projector_features=True`)
    projector_features: Optional[torch.FloatTensor] = None
175
+
176
+
177
class PrismaticPreTrainedModel(PreTrainedModel):
    """HF `PreTrainedModel` shim for Prismatic VLMs: wires the config class, sharding hints, and weight init."""

    config_class: PretrainedConfig = PrismaticConfig
    base_model_prefix: str = "model"
    supports_gradient_checkpointing: bool = True

    # Keep the projector on a single device when sharding with `accelerate`/`device_map`
    _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
    _skip_keys_device_placement: str = "past_key_values"
    _supports_flash_attn_2: bool = True

    def _init_weights(self, module: nn.Module) -> None:
        # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
        #   => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
        #      https://github.com/TRI-ML/prismatic-vlms
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self) -> bool:
        """Check LLM supports SDPA Attention"""
        return self.language_model._supports_sdpa
212
+
213
+
214
class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
    """Prismatic VLM: TIMM vision backbone + multimodal projector + causal LM, exposed via the HF generation API.

    Patch embeddings are projected into LLM space and spliced into the token sequence after the <BOS> token.
    """

    def __init__(self, config: PrismaticConfig) -> None:
        super().__init__(config)

        # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
        if config.use_fused_vision_backbone is None:
            raise ValueError("Missing config field `use_fused_vision_backbone`")

        if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
            raise NotImplementedError(
                "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
                "if you urgently need support for latest TIMM versions."
            )

        if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
            # Fix: the original message concatenated "please" + "use" without a space ("pleaseuse").
            logger.warning(
                f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
                f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
                f"there might be inference-time regressions due to dependency changes. If in doubt, please "
                f"use the above versions."
            )

        # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
        self.vision_backbone = PrismaticVisionBackbone(
            config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
        )

        # Create Multimodal Projector (vision embed dim -> LLM hidden size)
        self.projector = LinearProjector(
            self.vision_backbone.embed_dim,
            config.text_config.hidden_size,
        )

        # Instantiate LLM Backbone
        # NOTE(review): hub id is hardcoded — presumably this should be driven by `config.hf_llm_id`
        # ("agopalkr/gemma-2b" per config.json); confirm before generalizing to other checkpoints.
        self.language_model = AutoModelForCausalLM.from_pretrained("agopalkr/gemma-2b")
        self.vocab_size = config.text_config.vocab_size
        self.pad_token_id = config.pad_token_id

        # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
        self.post_init()

    # === `PreTrainedModel` Boilerplate (delegates to the wrapped language model) ===
    def get_input_embeddings(self) -> nn.Module:
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.language_model.set_output_embeddings(new_embeddings)

    def get_decoder(self) -> nn.Module:
        return self.language_model.get_decoder()

    def set_decoder(self, decoder: nn.Module) -> None:
        self.language_model.set_decoder(decoder)

    def tie_weights(self) -> None:
        self.language_model.tie_weights()  # Note: `Llama-2` and `Mistral` don't tie weights (no-op)

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        """Resize the LLM token embeddings and keep config/instance vocab-size bookkeeping in sync."""
        updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

        # Update config/instance variables
        self.config.text_config.vocab_size = updated_embeddings.num_embeddings
        self.vocab_size = updated_embeddings.num_embeddings

        return updated_embeddings

    # === Core Prismatic VLM `forward()` Logic ===
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_projector_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
        """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_projector_features = output_projector_features if output_projector_features is not None else False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
        use_cache = use_cache and not self.training

        # Instantiate Placeholder for Projector Features
        projected_patch_embeddings = None

        # Note :: We only support forward passes with the following cases:
        #   => Cached Generation :: (input_ids.shape[1] == 1) and (past_key_values is not None)
        #   => Unimodal Forward :: (pixel_values is None)
        #   => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])

        # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
        if input_ids.shape[1] == 1:
            assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
            assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
            assert labels is None, "Unexpected key `labels` provided during cached generation!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=None,
                position_ids=None,
                past_key_values=past_key_values,
                inputs_embeds=None,
                labels=None,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Unimodal Forward ===
        elif pixel_values is None:
            assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
            assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Multimodal Forward ===
        elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
            assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"

            # Visual Feature Extraction
            patch_features = self.vision_backbone(pixel_values)

            # Projection Logic =>> Update Attention Mask
            projected_patch_embeddings = self.projector(patch_features)
            projected_patch_attention_mask = None
            if attention_mask is not None:
                # All patch positions are attended to (fill with True, cast to the mask's dtype)
                projected_patch_attention_mask = torch.full(
                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                    fill_value=True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

            # Get Input Embeddings (from Language Model Embeddings)
            input_embeddings = self.get_input_embeddings()(input_ids)

            # Build Multimodal Embeddings & Attention Mask =>> Prismatic defaults to inserting after <BOS> token (1:)
            multimodal_embeddings = torch.cat(
                [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
            )
            multimodal_attention_mask = None
            if attention_mask is not None:
                multimodal_attention_mask = torch.cat(
                    [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
                )

            # Build Labels (if specified) =>> Ignore Labels for Patch Embeddings
            multimodal_labels = None
            if labels is not None:
                projected_patch_labels = torch.full(
                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                    fill_value=IGNORE_INDEX,
                    dtype=labels.dtype,
                    device=labels.device,
                )
                multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)

            # Dispatch to Language Model
            language_model_output = self.language_model(
                input_ids=None,
                attention_mask=multimodal_attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=multimodal_embeddings,
                labels=multimodal_labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Otherwise =>> Assume Invalid! ===
        elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
            raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")

        else:
            raise ValueError(
                "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
                f"=> `input_ids` = {input_ids is not None}\n"
                f"=> `attention_mask` = {attention_mask is not None}\n"
                f"=> `pixel_values` = {pixel_values is not None}\n"
                f"=> `labels` = {labels is not None}\n"
                f"=> `input_embeds` = {inputs_embeds is not None}\n"
                f"=> `past_key_values` = {past_key_values is not None}\n"
                f"=> `use_cache` = {use_cache}"
            )

        # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
        if not return_dict:
            if output_projector_features and (projected_patch_embeddings is not None):
                return *language_model_output, projected_patch_embeddings

            return language_model_output

        return PrismaticCausalLMOutputWithPast(
            loss=language_model_output.loss,
            logits=language_model_output.logits,
            past_key_values=language_model_output.past_key_values,
            hidden_states=language_model_output.hidden_states,
            attentions=language_model_output.attentions,
            projector_features=projected_patch_embeddings,
        )

    # === GenerationMixin Methods ===
    def prepare_inputs_for_generation(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: str,
    ) -> Dict[str, torch.Tensor]:
        """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
        if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
            (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
        ):
            raise ValueError("Generation with batch size > 1 is not currently supported!")

        # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            # Fix: key must match the `forward()` kwarg name (`inputs_embeds`); the original "input_embeds"
            # key was silently ignored, dropping the embeddings.
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # Make sure `pixel_values` are preserved in `model_inputs`
        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
            }
        )

        return model_inputs

    # Defer to Language Model (all handle this differently, with different return types)
    def _reorder_cache(self, *args, **kwargs) -> Any:
        return self.language_model._reorder_cache(*args, **kwargs)
491
+
492
+
493
class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
    """OpenVLA head: decodes generated token ids into (de-normalized) continuous robot actions."""

    config_class: PretrainedConfig = OpenVLAConfig

    def __init__(self, config: OpenVLAConfig) -> None:
        super().__init__(config)
        # Per-dataset normalization statistics (q01/q99/mask), keyed by dataset name
        self.norm_stats = config.norm_stats

        # Compute action bins: `n_action_bins` edges over [-1, 1]; bin centers are midpoints between edges
        self.bins = np.linspace(-1, 1, config.n_action_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0

        # Compute vocab size for de-tokenization -- revert added "multiple of"
        self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of

    def predict_action(
        self, input_ids: Optional[torch.LongTensor] = None, unnorm_key: Optional[str] = None, **kwargs: str
    ) -> np.ndarray:
        """Thin wrapper around super().generate() that decodes predicted actions and de-normalizes them.

        Args:
            input_ids: Tokenized prompt, batch size 1 (generation only supports bsz=1 upstream).
            unnorm_key: Which dataset's normalization statistics to use; may be omitted iff only one exists.

        Returns:
            np.ndarray of de-normalized actions, one value per action dimension.
        """

        # We need to add this special empty token ('') after the colon (':') token in "ASSISTANT:"
        # in order for the predictions to match the training configuration and be accurate.
        # NOTE(review): 29871 is the Llama-2 tokenizer's empty/SPIECE token; this checkpoint pairs with a
        # Gemma LLM per the surrounding config — confirm this id is correct for the actual tokenizer.
        input_ids = torch.cat(
            (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
        )

        # Run VLA inference: exactly one new token per action dimension
        generated_ids = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)

        # Extract predicted action tokens and translate into (normalized) continuous actions
        predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :].cpu().numpy()
        # Action tokens occupy the *end* of the vocab, so the bin index counts down from `vocab_size`
        discretized_actions = self.vocab_size - predicted_action_token_ids
        discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
        normalized_actions = self.bin_centers[discretized_actions]

        # Unnormalize actions: map [-1, 1] back to the dataset's [q01, q99] range, per masked dimension
        action_norm_stats = self.get_action_stats(unnorm_key)
        mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
        action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
        actions = np.where(
            mask,
            0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
            normalized_actions,
        )

        return actions

    @staticmethod
    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
        """Resolve (and validate) the dataset key used for de-normalization statistics."""
        if unnorm_key is None and len(norm_stats) != 1:
            raise ValueError(
                f"Your model was trained on more than one dataset. "
                f"Please pass a `unnorm_key` from the following options to choose the statistics used for "
                f"de-normalizing actions: {norm_stats.keys()}"
            )

        # If None, grab the (singular) dataset in `norm_stats` to use as `unnorm_key`
        unnorm_key = unnorm_key if unnorm_key is not None else next(iter(norm_stats.keys()))
        if unnorm_key not in norm_stats:
            raise ValueError(
                f"The `unnorm_key` you chose ({unnorm_key = }) is not in the available statistics. "
                f"Please choose from: {norm_stats.keys()}"
            )

        return unnorm_key

    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
        """Get the dimensionality of the policy's action space."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return len(self.norm_stats[unnorm_key]["action"]["q01"])

    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
        """Get all the logged statistics for the given dataset."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return self.norm_stats[unnorm_key]["action"]