JT-L commited on
Commit
ce6dcb4
·
verified ·
1 Parent(s): d88d80a

Upload 4 files

Browse files
modeling_prismatic.py ADDED
@@ -0,0 +1,1499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ modeling_prismatic.py
3
+
4
+ Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.
5
+ Inherits from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained,
6
+ but exactly replicate the logic in `prismatic.models.vlms.prismatic.py`.
7
+ """
8
+
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from functools import partial
12
+ from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
13
+
14
+ import numpy as np
15
+ import timm
16
+ import tokenizers
17
+ import torch
18
+ import torch.nn as nn
19
+ import transformers
20
+ from timm.models.vision_transformer import LayerScale
21
+ from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import ModelOutput
23
+
24
+ from prismatic.training.train_utils import (
25
+ get_current_action_mask,
26
+ get_next_actions_mask,
27
+ )
28
+ from prismatic.vla.constants import (
29
+ ACTION_DIM,
30
+ ACTION_PROPRIO_NORMALIZATION_TYPE,
31
+ ACTION_TOKEN_BEGIN_IDX,
32
+ IGNORE_INDEX,
33
+ NUM_ACTIONS_CHUNK,
34
+ STOP_INDEX,
35
+ NormalizationType,
36
+ NUM_TOKENS
37
+ )
38
+
39
+ from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
40
+
41
+
42
+
43
+ # Set up logger
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
+ # === Utility Functions for Monkey-Patching ===
48
+ def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
49
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
50
+ result = fn(*args, **kwargs)
51
+ return result[0] if isinstance(result, tuple) else result
52
+
53
+ return wrapper
54
+
55
+
56
+ # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
57
+ # =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
58
+ # =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
59
+ def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
60
+ return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
61
+
62
+
63
+ def ls_apply_patch(ls_module: LayerScale):
64
+ ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
65
+ ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
66
+ del ls_module.gamma
67
+
68
+
69
+ # === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
70
+ class PrismaticVisionBackbone(nn.Module):
71
+ """
72
+ Vision backbone for Prismatic models that handles image feature extraction.
73
+
74
+ Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.
75
+ For fused backbones, features from both models are concatenated along the feature dimension.
76
+ """
77
+
78
+ def __init__(
79
+ self,
80
+ use_fused_vision_backbone: bool,
81
+ image_sizes: List[int],
82
+ timm_model_ids: List[str],
83
+ timm_override_act_layers: List[Optional[str]],
84
+ ) -> None:
85
+ """
86
+ Initialize the vision backbone.
87
+
88
+ Args:
89
+ use_fused_vision_backbone: Whether to use two backbones and fuse their features
90
+ image_sizes: List of image sizes for each backbone
91
+ timm_model_ids: List of TIMM model IDs to use for each backbone
92
+ timm_override_act_layers: List of activation layer overrides for each backbone
93
+ """
94
+ super().__init__()
95
+ self.use_fused_vision_backbone = use_fused_vision_backbone
96
+ self.num_images_in_input = 1 # Default value, can be overridden later
97
+
98
+ # Validate number of (fused) vision backbones
99
+ if len(timm_model_ids) > 2:
100
+ raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")
101
+
102
+ # Create primary featurizer
103
+ self.featurizer = self._create_featurizer(
104
+ model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]
105
+ )
106
+ self.embed_dim = self.featurizer.embed_dim
107
+
108
+ # Create secondary featurizer if using fused backbone
109
+ if self.use_fused_vision_backbone:
110
+ self.fused_featurizer = self._create_featurizer(
111
+ model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]
112
+ )
113
+ self.embed_dim += self.fused_featurizer.embed_dim
114
+
115
+ # Patch LayerScale modules for HF compatibility
116
+ self._patch_layer_scales()
117
+
118
+ def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
119
+ """
120
+ Create a TIMM-based featurizer model with appropriate configurations.
121
+
122
+ Args:
123
+ model_id: The TIMM model ID to load
124
+ img_size: Input image size for the model
125
+ act_layer: Override for the activation layer type
126
+
127
+ Returns:
128
+ A configured featurizer model
129
+ """
130
+ featurizer = timm.create_model(
131
+ model_id,
132
+ pretrained=False,
133
+ num_classes=0,
134
+ img_size=img_size,
135
+ act_layer=act_layer,
136
+ )
137
+
138
+ # Monkey-patch the forward function to extract the second-to-last layer features
139
+ num_blocks = len(featurizer.blocks)
140
+ featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))
141
+
142
+ return featurizer
143
+
144
+ def _patch_layer_scales(self) -> None:
145
+ """
146
+ Patch all LayerScale modules to be compatible with HF's parameter naming.
147
+
148
+ HF Transformers overwrites parameters with names containing 'gamma',
149
+ so we need to rename and modify the forward method.
150
+ """
151
+ # Patch primary featurizer
152
+ for module in self.featurizer.modules():
153
+ if isinstance(module, LayerScale):
154
+ ls_apply_patch(module)
155
+
156
+ # Patch secondary featurizer if it exists
157
+ if self.use_fused_vision_backbone:
158
+ for module in self.fused_featurizer.modules():
159
+ if isinstance(module, LayerScale):
160
+ ls_apply_patch(module)
161
+
162
+ def get_num_patches(self) -> int:
163
+ """
164
+ Returns the number of vision patches output by the vision backbone.
165
+
166
+ Returns:
167
+ Number of patches per image
168
+ """
169
+ return self.featurizer.patch_embed.num_patches
170
+
171
+ def get_num_images_in_input(self) -> int:
172
+ """
173
+ Returns the number of input images for the vision backbone.
174
+
175
+ Returns:
176
+ Number of images expected in the input
177
+ """
178
+ return self.num_images_in_input
179
+
180
+ def set_num_images_in_input(self, num_images_in_input: int) -> None:
181
+ """
182
+ Sets the number of input images for the vision backbone.
183
+
184
+ Args:
185
+ num_images_in_input: Number of images to expect in the input
186
+ """
187
+ self.num_images_in_input = num_images_in_input
188
+
189
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
190
+ """
191
+ Implements the forward pass for the vision backbone.
192
+
193
+ If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features
194
+ (otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).
195
+
196
+ Args:
197
+ pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W).
198
+ """
199
+ if self.num_images_in_input == 1:
200
+ if not self.use_fused_vision_backbone:
201
+ return self.featurizer(pixel_values)
202
+
203
+ # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
204
+ img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
205
+ patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)
206
+
207
+ return torch.cat([patches, patches_fused], dim=2)
208
+
209
+ else:
210
+ assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!"
211
+
212
+ # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2)
213
+ images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)
214
+
215
+ # Process each image and collect patches
216
+ all_patches = []
217
+ for img in images:
218
+ # Split each image further into two stacks of channels (each with 3 channels)
219
+ img_regular, img_fused = torch.split(img, [3, 3], dim=1)
220
+
221
+ # Get patches from both SigLIP and DINOv2 vision transformers
222
+ patches = self.featurizer(img_regular)
223
+ patches_fused = self.fused_featurizer(img_fused)
224
+
225
+ # Concatenate SigLIP and DINOv2 patches along the hidden dimension
226
+ combined_patches = torch.cat([patches, patches_fused], dim=2)
227
+ all_patches.append(combined_patches)
228
+
229
+ # Concatenate all patches along the patch dimension
230
+ return torch.cat(all_patches, dim=1)
231
+
232
+
233
+ # === Prismatic Projector (nn.Module) Definitions ===
234
+ class PrismaticProjector(nn.Module):
235
+ def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
236
+ super().__init__()
237
+ self.use_fused_vision_backbone = use_fused_vision_backbone
238
+ self.vision_dim, self.llm_dim = vision_dim, llm_dim
239
+
240
+ # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
241
+ if not self.use_fused_vision_backbone:
242
+ self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
243
+ self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
244
+ self.act_fn1 = nn.GELU()
245
+ else:
246
+ initial_projection_dim = 4 * vision_dim
247
+ self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
248
+ self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
249
+ self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
250
+ self.act_fn1 = nn.GELU()
251
+ self.act_fn2 = nn.GELU()
252
+
253
+ def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
254
+ if not self.use_fused_vision_backbone:
255
+ projected_features = self.fc1(img_patches)
256
+ projected_features = self.act_fn1(projected_features)
257
+ projected_features = self.fc2(projected_features)
258
+ else:
259
+ projected_features = self.fc1(img_patches)
260
+ projected_features = self.act_fn1(projected_features)
261
+ projected_features = self.fc2(projected_features)
262
+ projected_features = self.act_fn2(projected_features)
263
+ projected_features = self.fc3(projected_features)
264
+
265
+ return projected_features
266
+
267
+
268
+ # === Main HF Class Definitions ===
269
+ @dataclass
270
+ class PrismaticCausalLMOutputWithPast(ModelOutput):
271
+ """Base class for Prismatic casual (visually-conditioned) language model outputs; also exposes visual features."""
272
+
273
+ loss: Optional[torch.FloatTensor] = None
274
+ logits: torch.FloatTensor = None
275
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
276
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
277
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
278
+
279
+ # Additions for VLMs
280
+ projector_features: Optional[torch.FloatTensor] = None
281
+
282
+
283
+ class PrismaticPreTrainedModel(PreTrainedModel):
284
+ config_class: PretrainedConfig = PrismaticConfig
285
+ base_model_prefix: str = "model"
286
+ supports_gradient_checkpointing: bool = True
287
+
288
+ _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
289
+ _skip_keys_device_placement: str = "past_key_values"
290
+ _supports_flash_attn_2: bool = True
291
+
292
+ def _init_weights(self, module: nn.Module) -> None:
293
+ # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
294
+ # => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
295
+ # https://github.com/TRI-ML/prismatic-vlms
296
+ std = (
297
+ self.config.initializer_range
298
+ if hasattr(self.config, "initializer_range")
299
+ else self.config.text_config.initializer_range
300
+ )
301
+
302
+ if hasattr(module, "class_embedding"):
303
+ module.class_embedding.data.normal_(mean=0.0, std=std)
304
+
305
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
306
+ module.weight.data.normal_(mean=0.0, std=std)
307
+ if module.bias is not None:
308
+ module.bias.data.zero_()
309
+ elif isinstance(module, nn.Embedding):
310
+ module.weight.data.normal_(mean=0.0, std=std)
311
+ if module.padding_idx is not None:
312
+ module.weight.data[module.padding_idx].zero_()
313
+
314
+ @property
315
+ def _supports_sdpa(self) -> bool:
316
+ """Check LLM supports SDPA Attention"""
317
+ return self.language_model._supports_sdpa
318
+
319
+
320
+ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
321
+ def __init__(self, config: PrismaticConfig) -> None:
322
+ super().__init__(config)
323
+
324
+ # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
325
+ if config.use_fused_vision_backbone is None:
326
+ raise ValueError("Missing config field `use_fused_vision_backbone`")
327
+
328
+ if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
329
+ raise NotImplementedError(
330
+ "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
331
+ "if you urgently need support for latest TIMM versions."
332
+ )
333
+
334
+ if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
335
+ logger.warning(
336
+ f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
337
+ f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
338
+ f"there might be inference-time regressions due to dependency changes. If in doubt, please"
339
+ f"use the above versions."
340
+ )
341
+ # import pdb; pdb.set_trace()
342
+ # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
343
+ self.vision_backbone = PrismaticVisionBackbone(
344
+ config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
345
+ )
346
+
347
+ # Create Multimodal Projector
348
+ self.projector = PrismaticProjector(
349
+ config.use_fused_vision_backbone,
350
+ vision_dim=self.vision_backbone.embed_dim,
351
+ llm_dim=config.text_config.hidden_size,
352
+ )
353
+
354
+ # Instantiate LLM Backbone
355
+ self.language_model = AutoModelForCausalLM.from_config(
356
+ config.text_config, attn_implementation=config._attn_implementation
357
+ )
358
+
359
+ self.vocab_size = config.text_config.vocab_size
360
+ self.pad_token_id = config.pad_token_id
361
+ self.llm_dim = config.text_config.hidden_size
362
+
363
+ #Action query token
364
+ self.action_queries = nn.Embedding(NUM_TOKENS, self.llm_dim)
365
+ self.action_queries.weight.data.zero_()
366
+
367
+ # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
368
+ self.post_init()
369
+
370
+ # === `PreTrainedModel` Boilerplate ===
371
+ def get_input_embeddings(self) -> nn.Module:
372
+ return self.language_model.get_input_embeddings()
373
+ def set_version(self, version: str):
374
+ self.version = version
375
+ return self.version
376
+
377
+
378
+ def set_input_embeddings(self, value: nn.Module) -> None:
379
+ self.language_model.set_input_embeddings(value)
380
+
381
+ def get_output_embeddings(self) -> nn.Module:
382
+ return self.language_model.get_output_embeddings()
383
+
384
+ def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
385
+ self.language_model.set_output_embeddings(new_embeddings)
386
+
387
+ def get_decoder(self) -> nn.Module:
388
+ return self.language_model.get_decoder()
389
+
390
+ def set_decoder(self, decoder: nn.Module) -> None:
391
+ self.language_model.set_decoder(decoder)
392
+
393
+ def tie_weights(self) -> None:
394
+ self.language_model.tie_weights() # Note: `Llama-2` and `Mistral` don't tie weights (no-op)
395
+
396
+ def resize_token_embeddings(
397
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
398
+ ) -> nn.Embedding:
399
+ updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
400
+
401
+ # Update config/instance variables
402
+ self.config.text_config.vocab_size = updated_embeddings.num_embeddings
403
+ self.vocab_size = updated_embeddings.num_embeddings
404
+
405
+ return updated_embeddings
406
+
407
+ def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):
408
+ """
409
+ Replace embeddings in input_embeddings at positions where all_actions_mask is True
410
+ with embeddings from noisy_action_features, using vectorized operations.
411
+
412
+ Args:
413
+ input_embeddings: Tensor of shape (B, S, D)
414
+ all_actions_mask: Boolean tensor of shape (B, S)
415
+ noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample
416
+
417
+ Returns:
418
+ Modified input_embeddings tensor
419
+ """
420
+ # Clone input to avoid modifying the original tensor
421
+ new_input_embeddings = input_embeddings.clone()
422
+
423
+ # Create a tensor with the same shape of input_embeddings to hold the noisy action features
424
+ repositioned_noisy_action_features = torch.zeros_like(input_embeddings)
425
+
426
+ # Create batch indices for splicing
427
+ batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)
428
+ batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])
429
+
430
+ # Get indices where mask is True for each sample
431
+ masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])
432
+
433
+ # Move the noisy action features into their correct positions
434
+ # print(noisy_action_features.size())
435
+ # import pdb; pdb.set_trace()
436
+ repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features
437
+
438
+ # Combine original input embeddings and noisy action embeddings using the mask
439
+ new_input_embeddings = torch.where(
440
+ all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings
441
+ )
442
+
443
+ return new_input_embeddings
444
+
445
+ def _process_action_masks(self, labels):
446
+ """Helper to get action masks from labels"""
447
+ current_action_mask = get_current_action_mask(labels)
448
+ next_actions_mask = get_next_actions_mask(labels)
449
+ all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len)
450
+ return all_actions_mask
451
+
452
+ def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False):
453
+ """Process vision features with optional FiLM conditioning"""
454
+ if use_film:
455
+ # FiLM: Infuse language inputs into visual features
456
+ patch_features = self.vision_backbone(pixel_values, language_embeddings) # (bsz, 256 * num_images, D)
457
+ else:
458
+ patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D)
459
+
460
+ # Project patch embeddings into language embedding space
461
+ return self.projector(patch_features)
462
+
463
+ def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):
464
+ """Process proprioceptive features and append to vision features"""
465
+ if proprio_projector is not None and proprio is not None:
466
+ # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim)
467
+ # proprio: (bsz, proprio_dim) or (propro_dim,)
468
+ proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1) # (bsz, proprio_dim)
469
+ proprio_features = proprio_projector(proprio) # (bsz, llm_dim)
470
+ proprio_features = proprio_features.unsqueeze(dim=1) # (bsz, 1, llm_dim)
471
+ # For simplicity, just append proprio token to the end of projected vision patch tokens
472
+ return torch.cat((projected_patch_embeddings, proprio_features), dim=1)
473
+ return projected_patch_embeddings
474
+
475
+ def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
476
+ """Build multimodal embeddings and attention mask"""
477
+ # Update attention mask
478
+ # import pdb; pdb.set_trace()
479
+ projected_patch_attention_mask = None
480
+ if attention_mask is not None:
481
+ projected_patch_attention_mask = torch.full(
482
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
483
+ fill_value=True,
484
+ dtype=attention_mask.dtype,
485
+ device=attention_mask.device,
486
+ )
487
+
488
+ # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
489
+ multimodal_embeddings = torch.cat(
490
+ [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
491
+ )
492
+
493
+ multimodal_attention_mask = None
494
+ if attention_mask is not None:
495
+ multimodal_attention_mask = torch.cat(
496
+ [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
497
+ )
498
+
499
+ return multimodal_embeddings, multimodal_attention_mask
500
+
501
+ def _build_multimodal_labels(self, labels, projected_patch_embeddings):
502
+ """Build multimodal labels with IGNORE_INDEX for patch embeddings"""
503
+ if labels is not None:
504
+ projected_patch_labels = torch.full(
505
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
506
+ fill_value=IGNORE_INDEX,
507
+ dtype=labels.dtype,
508
+ device=labels.device,
509
+ )
510
+ return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
511
+ return None
512
+
513
+ # === Core Prismatic VLM `forward()` Logic ===
514
+ def forward(
515
+ self,
516
+ input_ids: Optional[torch.LongTensor] = None,
517
+ attention_mask: Optional[torch.Tensor] = None,
518
+ pixel_values: Optional[torch.FloatTensor] = None,
519
+ labels: Optional[torch.LongTensor] = None,
520
+ inputs_embeds: Optional[torch.FloatTensor] = None,
521
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
522
+ use_cache: Optional[bool] = None,
523
+ output_attentions: Optional[bool] = None,
524
+ output_hidden_states: Optional[bool] = None,
525
+ output_projector_features: Optional[bool] = None,
526
+ return_dict: Optional[bool] = None,
527
+ proprio=None,
528
+ proprio_projector=None,
529
+ noisy_actions=None,
530
+ noisy_action_projector=None,
531
+ diffusion_timestep_embeddings=None,
532
+ use_film: bool = False,
533
+ ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
534
+ """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
535
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
536
+ output_hidden_states = (
537
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
538
+ )
539
+ output_projector_features = output_projector_features if output_projector_features is not None else False
540
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
541
+
542
+ # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
543
+ use_cache = use_cache and not self.training
544
+
545
+ # Instantiate Placeholder for Projector Features
546
+ projected_patch_embeddings = None
547
+
548
+ # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
549
+ if input_ids.shape[1] == 1:
550
+ assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
551
+ assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
552
+ assert labels is None, "Unexpected key `labels` provided during cached generation!"
553
+
554
+ language_model_output = self.language_model(
555
+ input_ids=input_ids,
556
+ attention_mask=None,
557
+ position_ids=None,
558
+ past_key_values=past_key_values,
559
+ inputs_embeds=None,
560
+ labels=None,
561
+ use_cache=use_cache,
562
+ output_attentions=output_attentions,
563
+ output_hidden_states=output_hidden_states,
564
+ return_dict=return_dict,
565
+ )
566
+
567
+ # === Handle Unimodal Forward ===
568
+ elif pixel_values is None:
569
+ assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
570
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
571
+
572
+ language_model_output = self.language_model(
573
+ input_ids=input_ids,
574
+ attention_mask=attention_mask,
575
+ position_ids=None,
576
+ past_key_values=None,
577
+ inputs_embeds=None,
578
+ labels=labels,
579
+ use_cache=use_cache,
580
+ output_attentions=output_attentions,
581
+ output_hidden_states=output_hidden_states,
582
+ return_dict=return_dict,
583
+ )
584
+
585
+ # === Handle Multimodal Forward ===
586
+ elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
587
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"
588
+
589
+ # Get input embeddings (from language model embeddings)
590
+ input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)
591
+
592
+ # import pdb; pdb.set_trace()
593
+ # Extract action masks
594
+ all_actions_mask = self._process_action_masks(labels)
595
+
596
+ # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
597
+ # import pdb; pdb.set_trace()
598
+ # print(input_embeddings[~all_actions_mask].size())
599
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
600
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
601
+ ) # (B, lang_seq_len, llm_dim)
602
+
603
+ # Get visual features
604
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
605
+
606
+ # Add proprioceptive state if provided
607
+ if self.version == 'v1':
608
+ pass
609
+ else:
610
+ projected_patch_embeddings = self._process_proprio_features(
611
+ projected_patch_embeddings, proprio, proprio_projector
612
+ )
613
+
614
+ # [Diffusion] Add diffusion timestep embedding if provided
615
+ if diffusion_timestep_embeddings is not None:
616
+ if self.version == 'v1':
617
+ pass
618
+ else:
619
+ # For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens
620
+ projected_patch_embeddings = torch.cat(
621
+ (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
622
+ )
623
+
624
+
625
+ # Process action embeddings
626
+ if noisy_actions is not None:
627
+ # import pdb; pdb.set_trace()
628
+ if self.version == 'v1':
629
+ # action_queries = self.action_queries.weight # (1, h)
630
+ # action_queries = action_queries.view(1, 1, action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
631
+ # input_embeddings = torch.cat((input_embeddings, action_queries), dim=1) # (b, n_tokens+chunk_size, h)
632
+ # action_attention_mask = None
633
+ # action_attention_mask = torch.full(
634
+ # (action_queries.shape[0], action_queries.shape[1]),
635
+ # fill_value=True,
636
+ # dtype=attention_mask.dtype,
637
+ # device=attention_mask.device,)
638
+ # attention_mask = torch.cat([attention_mask, action_attention_mask], dim=1)
639
+
640
+ action_queries = self.action_queries.weight # (1, h)
641
+ action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
642
+ all_actions_mask = self._process_action_masks(labels)
643
+ input_embeddings = self._replace_input_embeddings(
644
+ input_embeddings, all_actions_mask, action_queries)
645
+ # import pdb; pdb.set_trace()
646
+
647
+ else:
648
+ # Get mask corresponding to all action tokens
649
+ all_actions_mask = self._process_action_masks(labels)
650
+
651
+ # Reshape noisy actions into individual action tokens
652
+ # noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1)
653
+ B = noisy_actions.shape[0]
654
+ noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1)
655
+ # Project noisy action tokens into language model embedding space
656
+ noisy_action_features = noisy_action_projector(noisy_actions) # (B, chunk_len * action_dim, llm_dim)
657
+ # Replace embeddings of the action tokens with noisy action embeddings
658
+ input_embeddings = self._replace_input_embeddings(
659
+ input_embeddings, all_actions_mask, noisy_action_features)
660
+
661
+ else:
662
+ if self.version == 'v1':
663
+ action_queries = self.action_queries.weight # (1, h)
664
+ action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
665
+ all_actions_mask = self._process_action_masks(labels)
666
+ input_embeddings = self._replace_input_embeddings(
667
+ input_embeddings, all_actions_mask, action_queries)
668
+ # import pdb; pdb.set_trace()
669
+ else:
670
+ # Replace the embeddings of the action tokens with zeros
671
+ # (Later on, the positional embeddings will be added to them)
672
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
673
+ input_embeddings = input_embeddings * ~all_actions_mask
674
+
675
+
676
+ # Build multimodal embeddings & attention mask
677
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
678
+ input_embeddings, projected_patch_embeddings, attention_mask
679
+ )
680
+ # import pdb; pdb.set_trace()
681
+ # Build labels for multimodal sequence if needed
682
+ multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
683
+
684
+ # import pdb; pdb.set_trace()
685
+ # Dispatch to language model
686
+ if self.version == 'v1':
687
+ # import pdb; pdb.set_trace()
688
+ language_model_output = self.language_model(
689
+ input_ids=None,
690
+ attention_mask=multimodal_attention_mask,
691
+ position_ids=None,
692
+ past_key_values=None,
693
+ inputs_embeds=multimodal_embeddings,
694
+ labels=None,
695
+ use_cache=use_cache,
696
+ output_attentions=output_attentions,
697
+ output_hidden_states=output_hidden_states,
698
+ return_dict=return_dict,
699
+ )
700
+ # import pdb; pdb.set_trace()
701
+ else:
702
+ language_model_output = self.language_model(
703
+ input_ids=None,
704
+ attention_mask=multimodal_attention_mask,
705
+ position_ids=None,
706
+ past_key_values=None,
707
+ inputs_embeds=multimodal_embeddings,
708
+ labels=multimodal_labels,
709
+ use_cache=use_cache,
710
+ output_attentions=output_attentions,
711
+ output_hidden_states=output_hidden_states,
712
+ return_dict=return_dict,
713
+ )
714
+
715
+ # === Otherwise =>> Assume Invalid! ===
716
+ elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
717
+ raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
718
+
719
+ else:
720
+ raise ValueError(
721
+ "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
722
+ f"=> `input_ids` = {input_ids is not None}\n"
723
+ f"=> `attention_mask` = {attention_mask is not None}\n"
724
+ f"=> `pixel_values` = {pixel_values is not None}\n"
725
+ f"=> `labels` = {labels is not None}\n"
726
+ f"=> `input_embeds` = {inputs_embeds is not None}\n"
727
+ f"=> `past_key_values` = {past_key_values is not None}\n"
728
+ f"=> `use_cache` = {use_cache}"
729
+ )
730
+
731
+ # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
732
+ if not return_dict:
733
+ if output_projector_features and (projected_patch_embeddings is not None):
734
+ return *language_model_output, projected_patch_embeddings
735
+
736
+ return language_model_output
737
+
738
+ if self.version == 'v1':
739
+ return PrismaticCausalLMOutputWithPast(
740
+ loss=language_model_output.loss,
741
+ past_key_values=language_model_output.past_key_values,
742
+ hidden_states=language_model_output.hidden_states,
743
+ attentions=language_model_output.attentions,
744
+ projector_features=projected_patch_embeddings,
745
+ )
746
+ else:
747
+ return PrismaticCausalLMOutputWithPast(
748
+ loss=language_model_output.loss,
749
+ logits=language_model_output.logits,
750
+ past_key_values=language_model_output.past_key_values,
751
+ hidden_states=language_model_output.hidden_states,
752
+ attentions=language_model_output.attentions,
753
+ projector_features=projected_patch_embeddings,
754
+ )
755
+
756
+ # === GenerationMixin Methods ===
757
+ def prepare_inputs_for_generation(
758
+ self,
759
+ input_ids: Optional[torch.Tensor] = None,
760
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
761
+ inputs_embeds: Optional[torch.FloatTensor] = None,
762
+ pixel_values: Optional[torch.FloatTensor] = None,
763
+ attention_mask: Optional[torch.Tensor] = None,
764
+ **kwargs: str,
765
+ ) -> Dict[str, torch.Tensor]:
766
+ """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
767
+ if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
768
+ (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
769
+ ):
770
+ raise ValueError("Generation with batch size > 1 is not currently supported!")
771
+
772
+ # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
773
+ if past_key_values is not None:
774
+ input_ids = input_ids[:, -1:]
775
+
776
+ # If `input_embeds` are passed, we only want to use them in the 1st generation step
777
+ if inputs_embeds is not None and past_key_values is None:
778
+ model_inputs = {"input_embeds": inputs_embeds}
779
+ else:
780
+ model_inputs = {"input_ids": input_ids}
781
+
782
+ # Make sure `pixel_values` are preserved in `model_inputs`
783
+ model_inputs.update(
784
+ {
785
+ "attention_mask": attention_mask,
786
+ "pixel_values": pixel_values,
787
+ "past_key_values": past_key_values,
788
+ "use_cache": kwargs.get("use_cache"),
789
+ }
790
+ )
791
+
792
+ return model_inputs
793
+
794
+ # Defer to Language Model (all handle this differently, with different return types)
795
+ def _reorder_cache(self, *args, **kwargs) -> Any:
796
+ return self.language_model._reorder_cache(*args, **kwargs)
797
+
798
+
799
+ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
800
+ config_class: PretrainedConfig = OpenVLAConfig
801
+
802
+ def __init__(self, config: OpenVLAConfig) -> None:
803
+ super().__init__(config)
804
+ self.norm_stats = config.norm_stats
805
+ # import pdb; pdb.set_trace()
806
+
807
+ # Compute action bins
808
+ self.bins = np.linspace(-1, 1, config.n_action_bins)
809
+ self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
810
+
811
+ # Compute vocab size for de-tokenization -- revert added "multiple of"
812
+ self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
813
+
814
+ def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
815
+ """Prepares input for action prediction by adding necessary tokens"""
816
+ # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens
817
+ placeholder_action_token_ids = (
818
+ torch.ones((input_ids.shape[0], NUM_TOKENS)).to(input_ids.device).to(input_ids.dtype)
819
+ )
820
+ input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)
821
+
822
+ # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
823
+ stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
824
+ input_ids = torch.cat([input_ids, stop_token_id], dim=-1)
825
+
826
+ # Extend the attention mask to fit the new shape of input
827
+ # Note: Only batch size == 1 supported right now
828
+ mask_extension = (
829
+ torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
830
+ .to(attention_mask.device)
831
+ .to(attention_mask.dtype)
832
+ )
833
+ attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
834
+
835
+ return input_ids, attention_mask
836
+
837
+ def _prepare_labels_for_action_prediction(self, labels, input_ids):
838
+ """Creates labels tensor for action prediction if not provided"""
839
+ # Extend labels tensor with fake action labels
840
+ ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
841
+ labels_extension = (
842
+ torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
843
+ * ARBITRARY_ACTION_TOKEN_IDX
844
+ )
845
+ labels = torch.cat([labels, labels_extension], dim=-1)
846
+
847
+ # Replace last label token with stop token
848
+ labels[:, -1] = STOP_INDEX
849
+
850
+ return labels
851
+
852
+ def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
853
+ """Unnormalize actions using dataset statistics"""
854
+ action_norm_stats = self.get_action_stats(unnorm_key)
855
+
856
+ if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
857
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool))
858
+ action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"])
859
+ elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
860
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
861
+ action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
862
+ else:
863
+ raise ValueError("Unsupported action/proprio normalization type detected!")
864
+
865
+ actions = np.where(
866
+ mask,
867
+ 0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
868
+ normalized_actions,
869
+ )
870
+
871
+ return actions
872
+
873
+ def _run_flow_matching_prediction(
874
+ self,
875
+ input_embeddings,
876
+ all_actions_mask,
877
+ noise,
878
+ action_head,
879
+ projected_patch_embeddings,
880
+ labels,
881
+ attention_mask,
882
+ NUM_PATCHES,
883
+ NUM_PROMPT_TOKENS,
884
+ noisy_action_projector
885
+ ):
886
+ """Run flow matching-based action prediction"""
887
+ # Clone embedding for reuse in each timestep
888
+ # orig_projected_patch_embeddings = projected_patch_embeddings.clone()
889
+
890
+ dt = -1.0 / action_head.num_flow_steps
891
+ dt = torch.tensor(dt, dtype=torch.bfloat16, device=labels.device)
892
+
893
+ curr_noisy_actions = noise
894
+ time = torch.tensor(1.0, dtype=torch.bfloat16, device=labels.device)
895
+ while time >= -dt / 2:
896
+ B = curr_noisy_actions.shape[0]
897
+ orig_curr_noisy_actions_shape = curr_noisy_actions.shape
898
+ curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
899
+ noisy_action_features = noisy_action_projector(curr_noisy_actions)
900
+ curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
901
+
902
+ # Replace action token embeddings with noisy action embeddings
903
+ input_embeddings = self._replace_input_embeddings(
904
+ input_embeddings.clone(), all_actions_mask, noisy_action_features
905
+ )
906
+
907
+ # Build multimodal embeddings and attention mask
908
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
909
+ input_embeddings, projected_patch_embeddings, attention_mask
910
+ )
911
+
912
+ # Forward pass through language model
913
+ language_model_output = self.language_model(
914
+ input_ids=None,
915
+ attention_mask=multimodal_attention_mask,
916
+ position_ids=None,
917
+ past_key_values=None,
918
+ inputs_embeds=multimodal_embeddings,
919
+ labels=None,
920
+ use_cache=None,
921
+ output_attentions=False,
922
+ output_hidden_states=True,
923
+ return_dict=True,
924
+ )
925
+
926
+ # Extract hidden states for action portion of response
927
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
928
+ actions_hidden_states = last_hidden_states[
929
+ :,
930
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
931
+ :,
932
+ ] # (B, act_chunk_len, D)
933
+
934
+ # Predict noise and update noisy actions: x_t -> x_{t-1}
935
+ flow_pred = action_head.predict_flow(actions_hidden_states)
936
+ curr_noisy_actions += dt * flow_pred
937
+ time += dt
938
+ curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
939
+
940
+ # Return final actions
941
+ return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
942
+
943
+
944
+ def _run_diffusion_prediction(
945
+ self,
946
+ input_embeddings,
947
+ all_actions_mask,
948
+ noise,
949
+ action_head,
950
+ projected_patch_embeddings,
951
+ labels,
952
+ attention_mask,
953
+ NUM_PATCHES,
954
+ NUM_PROMPT_TOKENS,
955
+ noisy_action_projector,
956
+ ):
957
+ """Run diffusion-based action prediction"""
958
+ # Set diffusion timestep values
959
+ action_head.noise_scheduler.set_timesteps(action_head.num_diffusion_steps)
960
+ # Clone embedding for reuse in each timestep
961
+ orig_projected_patch_embeddings = projected_patch_embeddings.clone()
962
+ curr_noisy_actions = noise
963
+
964
+ # Reverse diffusion: Iteratively denoise to generate action prediction
965
+ for t in action_head.noise_scheduler.timesteps:
966
+ # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
967
+ # embedding, and diffusion timestep embedding)
968
+ timesteps = torch.Tensor([t]).to(labels.device)
969
+ diffusion_timestep_embeddings = (
970
+ action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
971
+ ) # (B, llm_dim)
972
+ diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
973
+
974
+ # [Diffusion] Replace the embeddings of the action tokens with noisy actions
975
+ # (Later on, the positional embeddings will be added to them)
976
+
977
+ # For simplicity, append diffusion timestep embedding to the end of projected vision tokens
978
+ projected_patch_embeddings = torch.cat(
979
+ (orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
980
+ )
981
+
982
+ # Reshape and project noisy actions into language embedding space
983
+ B = curr_noisy_actions.shape[0]
984
+ orig_curr_noisy_actions_shape = curr_noisy_actions.shape
985
+ curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
986
+ noisy_action_features = noisy_action_projector(curr_noisy_actions)
987
+ curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
988
+
989
+ # Replace action token embeddings with noisy action embeddings
990
+ input_embeddings = self._replace_input_embeddings(
991
+ input_embeddings.clone(), all_actions_mask, noisy_action_features
992
+ )
993
+
994
+ # Build multimodal embeddings and attention mask
995
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
996
+ input_embeddings, projected_patch_embeddings, attention_mask
997
+ )
998
+
999
+ # Forward pass through language model
1000
+ language_model_output = self.language_model(
1001
+ input_ids=None,
1002
+ attention_mask=multimodal_attention_mask,
1003
+ position_ids=None,
1004
+ past_key_values=None,
1005
+ inputs_embeds=multimodal_embeddings,
1006
+ labels=None,
1007
+ use_cache=None,
1008
+ output_attentions=False,
1009
+ output_hidden_states=True,
1010
+ return_dict=True,
1011
+ )
1012
+
1013
+ # Extract hidden states for action portion of response
1014
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
1015
+ actions_hidden_states = last_hidden_states[
1016
+ :,
1017
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1018
+ :,
1019
+ ] # (B, act_chunk_len, D)
1020
+
1021
+ # Predict noise and update noisy actions: x_t -> x_{t-1}
1022
+ noise_pred = action_head.predict_noise(actions_hidden_states)
1023
+ curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
1024
+
1025
+ curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1026
+
1027
+ # Return final actions
1028
+ return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
1029
+
1030
+ def _run_diffusion_prediction_V1(
1031
+ self,
1032
+ input_embeddings,
1033
+ all_actions_mask,
1034
+ noise,
1035
+ action_head,
1036
+ projected_patch_embeddings,
1037
+ labels,
1038
+ attention_mask,
1039
+ NUM_PATCHES,
1040
+ NUM_PROMPT_TOKENS,
1041
+ noisy_action_projector,
1042
+ proprio,
1043
+ proprio_projector,
1044
+ ):
1045
+ """Run diffusion-based action prediction"""
1046
+ # Set diffusion timestep values
1047
+ action_head.noise_scheduler.set_timesteps(action_head.num_diffusion_steps)
1048
+ # Clone embedding for reuse in each timestep
1049
+ curr_noisy_actions = noise
1050
+
1051
+ # import pdb; pdb.set_trace()
1052
+
1053
+ action_queries = self.action_queries.weight # (1, h)
1054
+ action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
1055
+ # Replace action token embeddings with noisy action embeddings
1056
+ input_embeddings = self._replace_input_embeddings(input_embeddings.clone(), all_actions_mask, action_queries)
1057
+ # input_embeddings = torch.cat((input_embeddings, action_queries), dim=1) # (b, n_tokens+chunk_size, h)
1058
+ # action_attention_mask = None
1059
+ # action_attention_mask = torch.full(
1060
+ # (action_queries.shape[0], action_queries.shape[1]),
1061
+ # fill_value=True,
1062
+ # dtype=attention_mask.dtype,
1063
+ # device=attention_mask.device,)
1064
+ # attention_mask = torch.cat([attention_mask, action_attention_mask], dim=1)
1065
+
1066
+ # Build multimodal embeddings and attention mask
1067
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1068
+ input_embeddings, projected_patch_embeddings, attention_mask
1069
+ )
1070
+
1071
+ # import pdb; pdb.set_trace()
1072
+ # Forward pass through language model
1073
+ language_model_output = self.language_model(
1074
+ input_ids=None,
1075
+ attention_mask=multimodal_attention_mask,
1076
+ position_ids=None,
1077
+ past_key_values=None,
1078
+ inputs_embeds=multimodal_embeddings,
1079
+ labels=None,
1080
+ use_cache=None,
1081
+ output_attentions=False,
1082
+ output_hidden_states=True,
1083
+ return_dict=True,
1084
+ )
1085
+ multi_layer_hidden_states = []
1086
+ # import pdb; pdb.set_trace()
1087
+ for item in language_model_output.hidden_states[0:]:
1088
+ # last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
1089
+ # Get hidden states for text portion of prompt+response (after the vision patches)
1090
+ text_hidden_states = item
1091
+ # Get hidden states for action portion of response
1092
+ actions_hidden_states = text_hidden_states[:, NUM_PATCHES+ NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS, :,].reshape(1, 1, NUM_TOKENS, -1).to(torch.bfloat16)
1093
+ # import pdb; pdb.set_trace()
1094
+ batch_size = item.shape[0]
1095
+ task_latten_states = item[:, :NUM_PATCHES].reshape(batch_size, 1, NUM_PATCHES , -1)
1096
+ all_hidden_states = torch.cat((task_latten_states, actions_hidden_states),2)
1097
+ multi_layer_hidden_states.append(all_hidden_states)
1098
+ # import pdb; pdb.set_trace()
1099
+ multi_layer_hidden_states = torch.cat(multi_layer_hidden_states, dim = 1)
1100
+ # import pdb; pdb.set_trace()
1101
+
1102
+
1103
+
1104
+ # Reverse diffusion: Iteratively denoise to generate action prediction
1105
+ for t in action_head.noise_scheduler.timesteps:
1106
+ # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
1107
+ # embedding, and diffusion timestep embedding)
1108
+ timesteps = torch.Tensor([t]).to(labels.device)
1109
+ diffusion_timestep_embeddings = (
1110
+ action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
1111
+ ) # (B, llm_dim)
1112
+ diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
1113
+
1114
+ # [Diffusion] Replace the embeddings of the action tokens with noisy actions
1115
+ # (Later on, the positional embeddings will be added to them)
1116
+
1117
+ # Reshape and project noisy actions into language embedding space
1118
+ B = curr_noisy_actions.shape[0]
1119
+ orig_curr_noisy_actions_shape = curr_noisy_actions.shape
1120
+ curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
1121
+ curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
1122
+
1123
+ # Predict noise and update noisy actions: x_t -> x_{t-1}
1124
+ # noise_pred = action_head.predict_noise(actions_hidden_states)
1125
+ noise_pred = action_head.predict_noise(multi_layer_hidden_states,
1126
+ noisy_actions=curr_noisy_actions,
1127
+ timestep_embeddings = diffusion_timestep_embeddings,
1128
+ noisy_action_projector=noisy_action_projector,
1129
+ proprio=proprio ,
1130
+ proprio_projector=proprio_projector)
1131
+
1132
+ curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
1133
+ curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1134
+
1135
+ # Return final actions
1136
+ return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
1137
+
1138
+ def _regression_or_discrete_prediction_V1(
1139
+ self,
1140
+ input_embeddings,
1141
+ all_actions_mask,
1142
+ projected_patch_embeddings,
1143
+ attention_mask,
1144
+ labels,
1145
+ NUM_PATCHES,
1146
+ NUM_PROMPT_TOKENS,
1147
+ action_head=None,
1148
+ proprio=None,
1149
+ proprio_projector=None,
1150
+ ):
1151
+ """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
1152
+
1153
+ action_queries = self.action_queries.weight # (1, h)
1154
+ action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
1155
+ # Replace action token embeddings with noisy action embeddings
1156
+ input_embeddings = self._replace_input_embeddings(input_embeddings.clone(), all_actions_mask, action_queries)
1157
+
1158
+ # Build multimodal embeddings and attention mask
1159
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1160
+ input_embeddings, projected_patch_embeddings, attention_mask
1161
+ )
1162
+
1163
+ # Forward pass through language model
1164
+ language_model_output = self.language_model(
1165
+ input_ids=None,
1166
+ attention_mask=multimodal_attention_mask,
1167
+ position_ids=None,
1168
+ past_key_values=None,
1169
+ inputs_embeds=multimodal_embeddings,
1170
+ labels=None,
1171
+ use_cache=None,
1172
+ output_attentions=False,
1173
+ output_hidden_states=True,
1174
+ return_dict=True,
1175
+ )
1176
+
1177
+ # Extract hidden states for action tokens
1178
+ multi_layer_hidden_states = []
1179
+ # import pdb; pdb.set_trace()
1180
+ for item in language_model_output.hidden_states[0:]:
1181
+ # last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
1182
+ # Get hidden states for text portion of prompt+response (after the vision patches)
1183
+ text_hidden_states = item
1184
+ # Get hidden states for action portion of response
1185
+ actions_hidden_states = text_hidden_states[:, NUM_PATCHES+ NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS, :,].reshape(1, 1, NUM_TOKENS, -1).to(torch.bfloat16)
1186
+ # import pdb; pdb.set_trace()
1187
+ batch_size = item.shape[0]
1188
+ task_latten_states = item[:, :NUM_PATCHES].reshape(batch_size, 1, NUM_PATCHES , -1)
1189
+ all_hidden_states = torch.cat((task_latten_states, actions_hidden_states),2)
1190
+ multi_layer_hidden_states.append(all_hidden_states)
1191
+ # import pdb; pdb.set_trace()
1192
+ multi_layer_hidden_states = torch.cat(multi_layer_hidden_states, dim = 1)
1193
+ # import pdb; pdb.set_trace()
1194
+
1195
+ # Handle different prediction methods
1196
+ if action_head is not None:
1197
+ # L1 regression prediction
1198
+ normalized_actions = action_head.predict_action(multi_layer_hidden_states,
1199
+ proprio=proprio,
1200
+ proprio_projector=proprio_projector)
1201
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1202
+ normalized_actions = normalized_actions.float().cpu().detach().numpy()
1203
+ else:
1204
+ # Discrete token-based prediction
1205
+ predicted_action_token_ids = (
1206
+ language_model_output.logits[
1207
+ :,
1208
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1209
+ ]
1210
+ .argmax(dim=2)
1211
+ .cpu()
1212
+ .numpy()
1213
+ )
1214
+ discretized_actions = self.vocab_size - predicted_action_token_ids
1215
+ discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
1216
+ normalized_actions = self.bin_centers[discretized_actions]
1217
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1218
+
1219
+ return normalized_actions, actions_hidden_states
1220
+
1221
+ def _regression_or_discrete_prediction(
1222
+ self,
1223
+ input_embeddings,
1224
+ all_actions_mask,
1225
+ projected_patch_embeddings,
1226
+ attention_mask,
1227
+ labels,
1228
+ NUM_PATCHES,
1229
+ NUM_PROMPT_TOKENS,
1230
+ action_head=None,
1231
+ ):
1232
+ """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
1233
+ # Zero out action token embeddings
1234
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
1235
+ input_embeddings = input_embeddings * ~all_actions_mask
1236
+
1237
+ # Build multimodal embeddings and attention mask
1238
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1239
+ input_embeddings, projected_patch_embeddings, attention_mask
1240
+ )
1241
+
1242
+ # Forward pass through language model
1243
+ language_model_output = self.language_model(
1244
+ input_ids=None,
1245
+ attention_mask=multimodal_attention_mask,
1246
+ position_ids=None,
1247
+ past_key_values=None,
1248
+ inputs_embeds=multimodal_embeddings,
1249
+ labels=None,
1250
+ use_cache=None,
1251
+ output_attentions=False,
1252
+ output_hidden_states=True,
1253
+ return_dict=True,
1254
+ )
1255
+
1256
+ # Extract hidden states for action tokens
1257
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
1258
+ actions_hidden_states = last_hidden_states[
1259
+ :,
1260
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1261
+ :,
1262
+ ] # (B, act_chunk_len, D)
1263
+
1264
+ # Handle different prediction methods
1265
+ if action_head is not None:
1266
+ # L1 regression prediction
1267
+ normalized_actions = action_head.predict_action(actions_hidden_states)
1268
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1269
+ normalized_actions = normalized_actions.float().cpu().detach().numpy()
1270
+ else:
1271
+ # Discrete token-based prediction
1272
+ predicted_action_token_ids = (
1273
+ language_model_output.logits[
1274
+ :,
1275
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1276
+ ]
1277
+ .argmax(dim=2)
1278
+ .cpu()
1279
+ .numpy()
1280
+ )
1281
+ discretized_actions = self.vocab_size - predicted_action_token_ids
1282
+ discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
1283
+ normalized_actions = self.bin_centers[discretized_actions]
1284
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1285
+
1286
+ return normalized_actions, actions_hidden_states
1287
+
1288
+ def predict_action(
1289
+ self,
1290
+ input_ids: Optional[torch.LongTensor] = None,
1291
+ unnorm_key: Optional[str] = None,
1292
+ proprio=None,
1293
+ proprio_projector=None,
1294
+ action_head=None,
1295
+ noisy_action_projector=None,
1296
+ use_film: bool = False,
1297
+ **kwargs: str,
1298
+ ) -> np.ndarray:
1299
+ """Predict actions from input sequence, with options for different prediction methods.
1300
+
1301
+ Args:
1302
+ input_ids: Input token ids
1303
+ unnorm_key: Key for unnormalization statistics
1304
+ proprio: Proprioceptive features
1305
+ proprio_projector: Projector for proprioceptive features
1306
+ action_head: Optional head for L1 regression or diffusion-based prediction
1307
+ noisy_action_projector: Projector for noisy actions in diffusion-based prediction
1308
+ use_film: Whether to use FiLM conditioning
1309
+ **kwargs: Additional arguments including pixel_values and attention_mask
1310
+
1311
+ Returns:
1312
+ Tuple of (unnormalized_actions, action_hidden_states)
1313
+ """
1314
+ # import pdb; pdb.set_trace()
1315
+ # If the special empty token ('') does not already appear after the colon (':') token in the prompt
1316
+ # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
1317
+
1318
+ # 如果是 minivla, 不用加这个判断!!!!!
1319
+ # if not torch.all(input_ids[:, -1] == 29871):
1320
+ # input_ids = torch.cat(
1321
+ # (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
1322
+ # )
1323
+
1324
+
1325
+ pixel_values = kwargs["pixel_values"] # [1, 12, 224, 224]
1326
+ attention_mask = kwargs["attention_mask"] #
1327
+
1328
+ # Create fake labels tensor (needed for action mask)
1329
+ labels = input_ids.clone()
1330
+ labels[:] = IGNORE_INDEX
1331
+
1332
+ # Get number of tokens in prompt (excluding the start token)
1333
+ NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token
1334
+
1335
+ # import pdb; pdb.set_trace()
1336
+
1337
+ # Prepare inputs by adding necessary tokens
1338
+ input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
1339
+
1340
+ # Update labels tensor for action mask computation later
1341
+ labels = self._prepare_labels_for_action_prediction(labels, input_ids)
1342
+
1343
+ # Get input embeddings and action masks
1344
+ input_embeddings = self.get_input_embeddings()(input_ids)
1345
+ all_actions_mask = self._process_action_masks(labels)
1346
+
1347
+ # Extract language embeddings
1348
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
1349
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
1350
+ )
1351
+
1352
+ # Process vision features
1353
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
1354
+
1355
+ # Add proprioceptive features if provided
1356
+ use_proprio = proprio_projector is not None and proprio is not None
1357
+ if use_proprio:
1358
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1359
+ if self.version == 'v1':
1360
+ pass
1361
+ else:
1362
+ projected_patch_embeddings = self._process_proprio_features(
1363
+ projected_patch_embeddings, proprio, proprio_projector
1364
+ )
1365
+ # import pdb; pdb.set_trace()
1366
+ # Use diffusion if provided, otherwise use regression or discrete prediction
1367
+ use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
1368
+ use_flow_matching = noisy_action_projector is not None and hasattr(action_head, "sample_actions")
1369
+
1370
+
1371
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
1372
+ NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
1373
+ if self.version == 'v1':
1374
+ # if use_diffusion:
1375
+ # NUM_PATCHES += 1
1376
+ pass
1377
+ else:
1378
+ if use_proprio:
1379
+ NUM_PATCHES += 1
1380
+ if use_diffusion:
1381
+ NUM_PATCHES += 1
1382
+
1383
+ # import pdb; pdb.set_trace()
1384
+ if use_flow_matching:
1385
+ # Sample random noise with shape equal to output action, used as the starting state for flow matching
1386
+ noise = action_head.sample_noise((1, NUM_ACTIONS_CHUNK, ACTION_DIM),device=input_embeddings.device, dtype=input_embeddings.dtype)
1387
+
1388
+ # Run flow matching-based prediction
1389
+ normalized_actions, actions_hidden_states = self._run_flow_matching_prediction(
1390
+ input_embeddings,
1391
+ all_actions_mask,
1392
+ noise,
1393
+ action_head,
1394
+ projected_patch_embeddings,
1395
+ labels,
1396
+ attention_mask,
1397
+ NUM_PATCHES,
1398
+ NUM_PROMPT_TOKENS,
1399
+ noisy_action_projector
1400
+ )
1401
+ elif use_diffusion:
1402
+ # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion
1403
+ noise = torch.randn(
1404
+ size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
1405
+ )
1406
+ # import pdb; pdb.set_trace()
1407
+ if self.version == 'v1':
1408
+
1409
+ # import pdb; pdb.set_trace()
1410
+ # Run diffusion-based prediction
1411
+ normalized_actions, actions_hidden_states = self._run_diffusion_prediction_V1(
1412
+ input_embeddings, # [1, 86, 4096]
1413
+ all_actions_mask, # [1, 86]
1414
+ noise, # [1,8, 7]
1415
+ action_head,
1416
+ projected_patch_embeddings, # [1, 512, 4096]
1417
+ labels, # [1, 86]
1418
+ attention_mask, # [1, 86]
1419
+ NUM_PATCHES, # 512
1420
+ NUM_PROMPT_TOKENS, # 28
1421
+ noisy_action_projector,
1422
+ proprio, # [8]
1423
+ proprio_projector,
1424
+ )
1425
+ else:
1426
+ # Run diffusion-based prediction
1427
+ normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
1428
+ input_embeddings,
1429
+ all_actions_mask,
1430
+ noise,
1431
+ action_head,
1432
+ projected_patch_embeddings,
1433
+ labels,
1434
+ attention_mask,
1435
+ NUM_PATCHES,
1436
+ NUM_PROMPT_TOKENS,
1437
+ noisy_action_projector,
1438
+ )
1439
+
1440
+ else:
1441
+ if self.version == 'v1':
1442
+ # Run regression or discrete token-based prediction
1443
+ normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction_V1(
1444
+ input_embeddings,
1445
+ all_actions_mask,
1446
+ projected_patch_embeddings,
1447
+ attention_mask,
1448
+ labels,
1449
+ NUM_PATCHES,
1450
+ NUM_PROMPT_TOKENS,
1451
+ action_head=action_head,
1452
+ proprio=proprio, # [8]
1453
+ proprio_projector=proprio_projector,
1454
+ )
1455
+ else:
1456
+ # Run regression or discrete token-based prediction
1457
+ normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
1458
+ input_embeddings,
1459
+ all_actions_mask,
1460
+ projected_patch_embeddings,
1461
+ attention_mask,
1462
+ labels,
1463
+ NUM_PATCHES,
1464
+ NUM_PROMPT_TOKENS,
1465
+ action_head,
1466
+ )
1467
+
1468
+ # import pdb; pdb.set_trace()
1469
+ # Unnormalize predicted actions
1470
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1471
+
1472
+ return actions, actions_hidden_states
1473
+
1474
+ @staticmethod
1475
+ def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
1476
+ """Validate and resolve the unnormalization key for action statistics"""
1477
+ if unnorm_key is None:
1478
+ assert len(norm_stats) == 1, (
1479
+ f"Your model was trained on more than one dataset, "
1480
+ f"please pass a `unnorm_key` from the following options to choose the statistics "
1481
+ f"used for un-normalizing actions: {norm_stats.keys()}"
1482
+ )
1483
+ unnorm_key = next(iter(norm_stats.keys()))
1484
+
1485
+ assert unnorm_key in norm_stats, (
1486
+ f"The `unnorm_key` you chose is not in the set of available dataset statistics, "
1487
+ f"please choose from: {norm_stats.keys()}"
1488
+ )
1489
+ return unnorm_key
1490
+
1491
+ def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
1492
+ """Get the dimensionality of the policy's action space."""
1493
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1494
+ return len(self.norm_stats[unnorm_key]["action"]["min"])
1495
+
1496
+ def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
1497
+ """Get all the logged statistics for the given dataset."""
1498
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1499
+ return self.norm_stats[unnorm_key]["action"]
processing_prismatic.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ processing_prismatic.py
3
+
4
+ HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration
5
+ specifies `siglip-224px+7b`.
6
+ """
7
+
8
+ from typing import Any, ClassVar, List, Optional, Tuple, Union
9
+
10
+ import timm.data
11
+ import torch
12
+ import torchvision.transforms.functional as TVF
13
+ from PIL import Image
14
+ from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
15
+ from transformers import PreTrainedTokenizerBase
16
+ from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin
17
+ from transformers.processing_utils import ProcessorMixin
18
+ from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
19
+ from transformers.utils import TensorType
20
+
21
+
22
+ # === Image Processing ===
23
+ def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
24
+ """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
25
+ (w, h), max_wh = image.size, max(image.size)
26
+ horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
27
+ padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
28
+
29
+ return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant")
30
+
31
+
32
+ class PrismaticImageProcessor(ImageProcessingMixin):
33
+ model_input_names: ClassVar[List[str]] = ["pixel_values"]
34
+
35
+ def __init__(
36
+ self,
37
+ use_fused_vision_backbone: bool = False,
38
+ image_resize_strategy: str = "letterbox",
39
+ input_sizes: Optional[List[Tuple[int, int, int]]] = None,
40
+ interpolations: Optional[List[str]] = None,
41
+ means: Optional[List[Tuple[float, float, float]]] = None,
42
+ stds: Optional[List[Tuple[float, float, float]]] = None,
43
+ **kwargs: str,
44
+ ) -> None:
45
+ """
46
+ Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be
47
+ created by TIMM, and edited to follow our custom `image_resize_strategy` logic.
48
+
49
+ @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone
50
+ @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >
51
+ @param input_size: [TIMM :: `data_cfg`] Input image size as tuple (channels, width, height)
52
+ @param interpolation: [TIMM :: `data_cfg`] Interpolation as string (default: "bicubic")
53
+ @param mean: [TIMM :: `data_cfg`] Normalization mean as float tuple (or two-tuple if `fused_backbone`)
54
+ @param std: [TIMM :: `data_cfg`] Normalization std as float tuple (or two-tuple if `fused_backbone`)
55
+ """
56
+ self.use_fused_vision_backbone = use_fused_vision_backbone
57
+ self.image_resize_strategy = image_resize_strategy
58
+
59
+ # Handle `None` default values
60
+ input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes
61
+ means = [(0.5, 0.5, 0.5)] if means is None else means
62
+ stds = [(0.5, 0.5, 0.5)] if stds is None else stds
63
+
64
+ # TIMM `data_cfg` Parameters
65
+ self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds
66
+
67
+ # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
68
+ self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
69
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
70
+
71
+ for idx in range(len(input_sizes)):
72
+ transform = timm.data.create_transform(
73
+ input_size=self.input_sizes[idx],
74
+ interpolation=self.interpolations[idx],
75
+ mean=self.means[idx],
76
+ std=self.stds[idx],
77
+ crop_pct=1.0, # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)
78
+ crop_mode="center", # Default crop mode -- no-op when `crop_pct == 1.0`
79
+ is_training=False, # No image augmentations when loading the transform!
80
+ )
81
+
82
+ # [Validation] Ensure appropriate transform structure, expected sizes
83
+ if not (
84
+ isinstance(transform, Compose)
85
+ and (len(transform.transforms) == 4)
86
+ and isinstance(transform.transforms[0], Resize)
87
+ and isinstance(transform.transforms[1], CenterCrop)
88
+ and isinstance(transform.transforms[2], ToTensor)
89
+ and isinstance(transform.transforms[3], Normalize)
90
+ and (transform.transforms[0].size == self.input_sizes[idx][-1])
91
+ and (transform.transforms[1].size == self.input_sizes[idx][-2:])
92
+ ):
93
+ raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")
94
+
95
+ # HF Image Processors *must* be JSON-serializable; as such, cannot have torchvision. as an attribute.
96
+ # => Instead, we're going to parse the transform and call "torchvision.transforms.functional" (`tvf`)
97
+ resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
98
+ self.tvf_resize_params.append(
99
+ {
100
+ "size": resize_t.size,
101
+ "interpolation": TVF.pil_modes_mapping[resize_t.interpolation],
102
+ "max_size": None,
103
+ "antialias": True,
104
+ }
105
+ )
106
+ self.tvf_crop_params.append({"output_size": crop_t.size})
107
+ self.tvf_normalize_params.append(
108
+ {
109
+ "mean": norm_t.mean.float().numpy().tolist(),
110
+ "std": norm_t.std.float().numpy().tolist(),
111
+ "inplace": False,
112
+ }
113
+ )
114
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
115
+
116
+ # Handle Prismatic `image_resize_strategy`
117
+ if self.image_resize_strategy == "resize-naive":
118
+ self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
119
+ elif self.image_resize_strategy == "letterbox":
120
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])
121
+ elif self.image_resize_strategy == "resize-crop":
122
+ pass
123
+ else:
124
+ raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")
125
+
126
+ # Dispatch **kwargs to super()
127
+ super().__init__(**kwargs)
128
+
129
+ def apply_transform(self, img: Image.Image) -> torch.Tensor:
130
+ """Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])"""
131
+ if self.tvf_do_letterbox:
132
+ img = letterbox_pad_transform(img, self.tvf_letterbox_fill)
133
+
134
+ # [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side!
135
+ imgs_t = []
136
+ for idx in range(len(self.input_sizes)):
137
+ img_idx = TVF.resize(img, **self.tvf_resize_params[idx])
138
+ img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])
139
+ img_idx_t = TVF.to_tensor(img_idx)
140
+ img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])
141
+ imgs_t.append(img_idx_t)
142
+
143
+ # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0
144
+ img_t = torch.vstack(imgs_t)
145
+
146
+ return img_t
147
+
148
+ def preprocess(
149
+ self,
150
+ images: Union[Image.Image, List[Image.Image]],
151
+ return_tensors: Optional[Union[str, TensorType]] = None,
152
+ **_: str,
153
+ ) -> BatchFeature:
154
+ """
155
+ Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor` we
156
+ explicitly only handle PIL.Image.Image instances for simplicity.
157
+
158
+ @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
159
+ @param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray
160
+
161
+ @return: Instance of `transformers :: BatchFeature` with a single key "pixel_values"
162
+ """
163
+ if not isinstance(images, list):
164
+ images = [images]
165
+
166
+ # Apply `self.img_transform` to each image (will return list of torch.Tensors); stack into "batched" Tensor
167
+ pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])
168
+
169
+ # Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert
170
+ return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)
171
+
172
+ def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
173
+ return self.preprocess(images, **kwargs)
174
+
175
+
176
+ # === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===
177
+ # =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
178
+ class PrismaticProcessor(ProcessorMixin):
179
+ attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"]
180
+ image_processor_class: str = "AutoImageProcessor"
181
+ tokenizer_class: str = "AutoTokenizer"
182
+
183
+ def __init__(
184
+ self,
185
+ image_processor: Optional[ImageProcessingMixin] = None,
186
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
187
+ ) -> None:
188
+ super().__init__(image_processor, tokenizer)
189
+
190
+ def __call__(
191
+ self,
192
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
193
+ images: Union[Image.Image, List[Image.Image]],
194
+ padding: Union[bool, str, PaddingStrategy] = False,
195
+ truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
196
+ max_length: Optional[int] = None,
197
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
198
+ ) -> BatchFeature:
199
+ """
200
+ Preprocess a given (batch) of text/images for a Prismatic VLM; forwards text to the underlying LLM's tokenizer,
201
+ forwards images to PrismaticImageProcessor.
202
+
203
+ @param text: The (batch) of text to encode; must be a string or list of strings.
204
+ @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
205
+ @param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False >
206
+ @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified
207
+ @param max_length: Maximum length (in tokens) to truncate
208
+ @param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH)
209
+
210
+ @return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`.
211
+ """
212
+ pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
213
+ text_inputs = self.tokenizer(
214
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
215
+ )
216
+
217
+ # [Validate] Need same number of images and text inputs!
218
+ if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:
219
+ raise ValueError("Batch is malformed; expected same number of images and text inputs!")
220
+
221
+ return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
222
+
223
+ # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
224
+ def batch_decode(
225
+ self,
226
+ sequences: Union[List[int], List[List[int]], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
227
+ skip_special_tokens: bool = False,
228
+ clean_up_tokenization_spaces: Optional[bool] = None,
229
+ **kwargs: str,
230
+ ) -> List[str]:
231
+ return self.tokenizer.batch_decode(
232
+ sequences=sequences,
233
+ skip_special_tokens=skip_special_tokens,
234
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
235
+ **kwargs,
236
+ )
237
+
238
+ def decode(
239
+ self,
240
+ token_ids: Union[int, List[int], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
241
+ skip_special_tokens: bool = False,
242
+ clean_up_tokenization_spaces: Optional[bool] = None,
243
+ **kwargs: str,
244
+ ) -> str:
245
+ return self.tokenizer.decode(
246
+ token_ids=token_ids,
247
+ skip_special_tokens=skip_special_tokens,
248
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
249
+ **kwargs,
250
+ )
251
+
252
+ @property
253
+ def model_input_names(self) -> List[str]:
254
+ tokenizer_input_names = self.tokenizer.model_input_names
255
+ image_processor_input_names = self.image_processor.model_input_names
256
+
257
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
proprio_projector--10000_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8cc720d02dcd9f72446a597dba5fc883c9b9e8390ec95f114e626d656640aa3e
3
+ size 1626096
proprio_projector--checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17525f46b6df034b12c9fae2619442ff8520ee7f19c240c5e62643a4a3d0b793
3
+ size 1626096