"""
modeling_prismatic.py

Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.
Inherits from the default `transformers.PreTrainedModel`. Meant to be standalone and self-contained,
while exactly replicating the logic in `prismatic.models.vlms.prismatic.py`.
"""
|
|
| import logging |
| import os |
| from dataclasses import dataclass |
| from functools import partial |
| from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union |
|
|
| import numpy as np |
| import timm |
| import tokenizers |
| import torch |
| import torch.nn as nn |
| import transformers |
| from timm.models.vision_transformer import LayerScale |
| from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel |
| from transformers.modeling_outputs import ModelOutput |
|
|
# Free-form revision tag for the modeling logic in this file.
# NOTE(review): not referenced in the visible part of this file; presumably read by external tooling — confirm.
MODEL_LOGIC_REV = "2026-03-03-maskgit-12step-remask-gripper-audit"
|
|
| from prismatic.discrete_flow import ( |
| dfm_decode, |
| kappa, |
| kappa_dot, |
| mask_schedule, |
| parallel_decode, |
| ) |
| from prismatic.training.train_utils import ( |
| get_current_action_mask, |
| get_next_actions_mask, |
| ) |
| from prismatic.vla.constants import ( |
| ACTION_DIM, |
| ACTION_PROPRIO_NORMALIZATION_TYPE, |
| ACTION_TOKEN_BEGIN_IDX, |
| IGNORE_INDEX, |
| NUM_ACTIONS_CHUNK, |
| STOP_INDEX, |
| NormalizationType, |
| ) |
|
|
| from .configuration_prismatic import OpenVLAConfig, PrismaticConfig |
|
|
| |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
| |
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """
    Wrap `fn` so that a tuple result is reduced to its first element.

    Any non-tuple result (including lists) is passed through untouched.

    Args:
        fn: Callable to wrap; all positional and keyword arguments are forwarded to it.

    Returns:
        A callable with the same call signature as `fn`.
    """

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        outcome = fn(*args, **kwargs)
        if isinstance(outcome, tuple):
            return outcome[0]
        return outcome

    return wrapper
|
|
|
|
| |
| |
| |
| def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor: |
| return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor |
|
|
|
|
def ls_apply_patch(ls_module: LayerScale):
    """
    Patch a LayerScale module in place so its learnable scale is stored as `scale_factor`.

    A copy of `gamma` is registered as the new `scale_factor` parameter, `forward` is rebound
    to `_ls_new_forward` (which reads `scale_factor`), and the original `gamma` attribute is removed.

    Args:
        ls_module: The LayerScale instance to modify.
    """
    gamma_copy = ls_module.gamma.clone()
    ls_module.scale_factor = nn.Parameter(gamma_copy)

    # Bind the replacement forward to this specific instance.
    bound_forward = _ls_new_forward.__get__(ls_module, LayerScale)
    ls_module.forward = bound_forward

    # Drop `gamma` only after its values have been copied above.
    del ls_module.gamma
|
|
|
|
| |
class PrismaticVisionBackbone(nn.Module):
    """
    Image feature extractor for Prismatic models.

    Works with either one TIMM backbone or two "fused" TIMM backbones. In the fused case each
    backbone receives its own 3-channel slice of the input and the resulting patch features are
    concatenated along the last (feature) dimension.
    """

    def __init__(
        self,
        use_fused_vision_backbone: bool,
        image_sizes: List[int],
        timm_model_ids: List[str],
        timm_override_act_layers: List[Optional[str]],
    ) -> None:
        """
        Build the backbone(s) and patch their LayerScale modules.

        Args:
            use_fused_vision_backbone: If True, a second backbone is created from index 1 of each list.
            image_sizes: Image size per backbone.
            timm_model_ids: TIMM model ID per backbone (at most two entries).
            timm_override_act_layers: Activation-layer override per backbone.

        Raises:
            ValueError: If more than two TIMM model IDs are given.
        """
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.num_images_in_input = 1

        if len(timm_model_ids) > 2:
            raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")

        # Primary backbone is always present.
        self.featurizer = self._create_featurizer(
            model_id=timm_model_ids[0],
            img_size=image_sizes[0],
            act_layer=timm_override_act_layers[0],
        )
        self.embed_dim = self.featurizer.embed_dim

        # Optional second backbone; its width is added to the exposed embedding dimension.
        if self.use_fused_vision_backbone:
            self.fused_featurizer = self._create_featurizer(
                model_id=timm_model_ids[1],
                img_size=image_sizes[1],
                act_layer=timm_override_act_layers[1],
            )
            self.embed_dim += self.fused_featurizer.embed_dim

        self._patch_layer_scales()

    def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
        """
        Instantiate a TIMM model (no pretrained weights, no classifier head) and redirect its forward.

        Args:
            model_id: TIMM model ID.
            img_size: Input image size passed to TIMM.
            act_layer: Activation-layer override passed to TIMM.

        Returns:
            The TIMM model whose `forward` now returns the selected intermediate-layer output.
        """
        backbone = timm.create_model(
            model_id,
            pretrained=False,
            num_classes=0,
            img_size=img_size,
            act_layer=act_layer,
        )

        # `n` is deliberately a one-element set holding the index of the second-to-last block.
        # NOTE(review): relies on timm treating a non-int `n` as explicit block indices — confirm for your timm version.
        penultimate_idx = len(backbone.blocks) - 2
        intermediate_fn = partial(backbone.get_intermediate_layers, n={penultimate_idx})
        backbone.forward = unpack_tuple(intermediate_fn)

        return backbone

    def _patch_layer_scales(self) -> None:
        """
        Apply `ls_apply_patch` to every LayerScale module of the backbone(s).

        The patch renames the `gamma` parameter to `scale_factor`.
        NOTE(review): motivated by HF Transformers rewriting parameter names containing 'gamma' — confirm.
        """
        backbones = [self.featurizer]
        if self.use_fused_vision_backbone:
            backbones.append(self.fused_featurizer)

        for backbone in backbones:
            for submodule in backbone.modules():
                if isinstance(submodule, LayerScale):
                    ls_apply_patch(submodule)

    def get_num_patches(self) -> int:
        """Return the primary featurizer's `patch_embed.num_patches`."""
        return self.featurizer.patch_embed.num_patches

    def get_num_images_in_input(self) -> int:
        """Return how many images `forward` expects to be stacked in its input."""
        return self.num_images_in_input

    def set_num_images_in_input(self, num_images_in_input: int) -> None:
        """
        Set how many images `forward` expects to be stacked in its input.

        Args:
            num_images_in_input: New number of images.
        """
        self.num_images_in_input = num_images_in_input

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Extract patch features from the input image(s).

        With a single backbone the input is passed through unchanged. With fused backbones each
        image occupies 6 channels along dim 1: the first 3 go to `featurizer`, the last 3 to
        `fused_featurizer`, and the two outputs are concatenated on dim 2. Multiple images
        (fused backbones only) are processed one after another and concatenated on dim 1.

        Args:
            pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W).

        Returns:
            Patch features for all images.
        """
        if self.num_images_in_input == 1:
            if self.use_fused_vision_backbone:
                primary_px, fused_px = torch.split(pixel_values, [3, 3], dim=1)
                return torch.cat([self.featurizer(primary_px), self.fused_featurizer(fused_px)], dim=2)
            return self.featurizer(pixel_values)

        assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!"

        per_image_features = []
        for image in torch.split(pixel_values, [6] * self.num_images_in_input, dim=1):
            primary_px, fused_px = torch.split(image, [3, 3], dim=1)
            primary_feats = self.featurizer(primary_px)
            fused_feats = self.fused_featurizer(fused_px)
            per_image_features.append(torch.cat([primary_feats, fused_feats], dim=2))

        return torch.cat(per_image_features, dim=1)
|
|
|
|
| |
class PrismaticProjector(nn.Module):
    """
    MLP that maps vision patch features to the language model's hidden size.

    Without a fused backbone: Linear(vision_dim -> llm_dim), GELU, Linear(llm_dim -> llm_dim).
    With a fused backbone: Linear(vision_dim -> 4 * vision_dim), GELU, Linear(-> llm_dim), GELU,
    Linear(llm_dim -> llm_dim).
    """

    def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
        """
        Args:
            use_fused_vision_backbone: Selects the three-layer (True) or two-layer (False) variant.
            vision_dim: Size of the incoming patch features.
            llm_dim: Size of the projected features.
        """
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.vision_dim, self.llm_dim = vision_dim, llm_dim

        if self.use_fused_vision_backbone:
            widened_dim = 4 * vision_dim
            self.fc1 = nn.Linear(self.vision_dim, widened_dim, bias=True)
            self.fc2 = nn.Linear(widened_dim, self.llm_dim, bias=True)
            self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()
            self.act_fn2 = nn.GELU()
        else:
            self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
            self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        """Project `img_patches` through the MLP and return the projected features."""
        # The first two linear layers (with one GELU in between) are shared by both variants.
        hidden = self.fc2(self.act_fn1(self.fc1(img_patches)))
        if self.use_fused_vision_backbone:
            hidden = self.fc3(self.act_fn2(hidden))
        return hidden
|
|
|
|
| |
@dataclass
class PrismaticCausalLMOutputWithPast(ModelOutput):
    """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features."""

    # Standard causal-LM output fields; every field defaults to None.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Visual features exposed alongside the LM outputs.
    # NOTE(review): presumably the projector output for the image patches — confirm against `forward`.
    projector_features: Optional[torch.FloatTensor] = None

    # Labels returned with the output.
    # NOTE(review): presumably the (possibly masked) multimodal labels used for the loss — confirm.
    labels: Optional[torch.LongTensor] = None

    # Discrete-flow-matching diagnostics, keyed by name.
    # NOTE(review): exact keys are not visible here — confirm against `forward`.
    dfm_stats: Optional[Dict[str, torch.FloatTensor]] = None
    dfm_trace: Optional[Dict[str, torch.FloatTensor]] = None
|
|
|
|
class PrismaticPreTrainedModel(PreTrainedModel):
    """Common HF `PreTrainedModel` base for Prismatic models: class-level flags plus weight initialization."""

    config_class: PretrainedConfig = PrismaticConfig
    base_model_prefix: str = "model"
    supports_gradient_checkpointing: bool = True

    _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
    _skip_keys_device_placement: str = "past_key_values"
    _supports_flash_attn_2: bool = True

    def _init_weights(self, module: nn.Module) -> None:
        """
        Initialize `module` with a zero-mean normal distribution.

        The standard deviation is `config.initializer_range` when the config has that attribute,
        otherwise `config.text_config.initializer_range`. Linear/Conv2d biases and the embedding
        row at `padding_idx` are zeroed.

        Args:
            module: The module to initialize.
        """
        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range
        else:
            std = self.config.text_config.initializer_range

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if not isinstance(module, nn.Embedding):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self) -> bool:
        """Report whether the wrapped language model supports SDPA attention."""
        return self.language_model._supports_sdpa
|
|
|
|
| class PrismaticForConditionalGeneration(PrismaticPreTrainedModel): |
def __init__(self, config: PrismaticConfig) -> None:
    """
    Build the vision backbone, projector and language model described by `config`.

    Args:
        config: Model configuration. Reads `use_fused_vision_backbone`, `image_sizes`,
            `timm_model_ids`, `timm_override_act_layers`, `text_config`, `use_discrete_diffusion`,
            `mask_token_id`, `pad_token_id` and, optionally, `use_discrete_flow_matching`
            and `action_vocab_anchor`.

    Raises:
        ValueError: If `use_fused_vision_backbone` is None, if both discrete diffusion and
            discrete flow matching are enabled, or if action-vocabulary validation fails.
        NotImplementedError: If the installed TIMM version is not one of the supported releases.
    """
    super().__init__(config)

    # [Validation] Lightweight validation of config fields and dependency versions.
    if config.use_fused_vision_backbone is None:
        raise ValueError("Missing config field `use_fused_vision_backbone`")

    if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
        raise NotImplementedError(
            "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
            "if you urgently need support for latest TIMM versions."
        )

    if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
        # Fixed: the original concatenation dropped the space between "please" and "use".
        logger.warning(
            f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
            f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
            f"there might be inference-time regressions due to dependency changes. If in doubt, please "
            f"use the above versions."
        )

    # Vision backbone (single or fused TIMM featurizers).
    self.vision_backbone = PrismaticVisionBackbone(
        config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
    )

    # Projector from vision feature width to the LLM hidden size.
    self.projector = PrismaticProjector(
        config.use_fused_vision_backbone,
        vision_dim=self.vision_backbone.embed_dim,
        llm_dim=config.text_config.hidden_size,
    )

    # Language model, created from config (no weights are loaded here).
    self.language_model = AutoModelForCausalLM.from_config(
        config.text_config, attn_implementation=config._attn_implementation
    )

    # Discrete diffusion / flow-matching switches; the two modes are mutually exclusive.
    self.use_discrete_diffusion = config.use_discrete_diffusion
    self.use_discrete_flow_matching = getattr(config, "use_discrete_flow_matching", False)
    self.mask_token_id = config.mask_token_id
    self.action_vocab_anchor = getattr(config, "action_vocab_anchor", "pad")

    if self.use_discrete_diffusion and self.use_discrete_flow_matching:
        raise ValueError("Cannot enable both discrete diffusion and discrete flow matching.")

    self.vocab_size = config.text_config.vocab_size
    self.pad_token_id = config.pad_token_id
    self.llm_dim = config.text_config.hidden_size

    # Needs vocab_size / pad_token_id / mask_token_id, so it must run after the assignments above.
    self._validate_action_vocab()

    # HF hook that finalizes initialization.
    self.post_init()
|
|
| |
def get_input_embeddings(self) -> nn.Module:
    """Return the input embedding module of the wrapped language model."""
    return self.language_model.get_input_embeddings()
|
|
def set_input_embeddings(self, value: nn.Module) -> None:
    """Set `value` as the input embedding module of the wrapped language model."""
    self.language_model.set_input_embeddings(value)
|
|
def get_output_embeddings(self) -> nn.Module:
    """Return the output embedding module of the wrapped language model."""
    return self.language_model.get_output_embeddings()
|
|
def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
    """Set `new_embeddings` as the output embedding module of the wrapped language model."""
    self.language_model.set_output_embeddings(new_embeddings)
|
|
def get_decoder(self) -> nn.Module:
    """Return the decoder of the wrapped language model."""
    return self.language_model.get_decoder()
|
|
def set_decoder(self, decoder: nn.Module) -> None:
    """Set `decoder` as the decoder of the wrapped language model."""
    self.language_model.set_decoder(decoder)
|
|
def tie_weights(self) -> None:
    """Delegate weight tying to the wrapped language model."""
    self.language_model.tie_weights()
|
|
def resize_token_embeddings(
    self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
) -> nn.Embedding:
    """
    Resize the language model's token embeddings and keep the cached vocabulary sizes in sync.

    Args:
        new_num_tokens: Forwarded to the language model's `resize_token_embeddings`.
        pad_to_multiple_of: Forwarded to the language model's `resize_token_embeddings`.

    Returns:
        The embedding module returned by the language model.
    """
    resized = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

    # Mirror the new size on both the text config and this wrapper.
    new_vocab_size = resized.num_embeddings
    self.config.text_config.vocab_size = new_vocab_size
    self.vocab_size = new_vocab_size

    return resized
|
|
| def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features): |
| """ |
| Replace embeddings in input_embeddings at positions where all_actions_mask is True |
| with embeddings from noisy_action_features, using vectorized operations. |
| |
| Args: |
| input_embeddings: Tensor of shape (B, S, D) |
| all_actions_mask: Boolean tensor of shape (B, S) |
| noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample |
| |
| Returns: |
| Modified input_embeddings tensor |
| """ |
| |
| new_input_embeddings = input_embeddings.clone() |
|
|
| |
| repositioned_noisy_action_features = torch.zeros_like(input_embeddings) |
|
|
| |
| batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device) |
| batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1]) |
|
|
| |
| masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask]) |
|
|
| |
| repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features |
|
|
| |
| new_input_embeddings = torch.where( |
| all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings |
| ) |
|
|
| return new_input_embeddings |
|
|
| def _action_vocab_range(self) -> Tuple[int, int, int]: |
| """Return (action_begin, action_end, n_bins) for current config.""" |
| n_bins = getattr(self.config, "n_action_bins", None) |
| if n_bins is None and hasattr(self, "bin_centers"): |
| n_bins = int(self.bin_centers.shape[0]) |
| if n_bins is None: |
| raise ValueError("n_action_bins must be set on config or via bin_centers.") |
| n_bins = int(n_bins) |
| if getattr(self.config, "legacy_eval_mode", False) or getattr(self.config, "legacy_train_mode", False): |
| |
| |
| |
| action_begin = int(ACTION_TOKEN_BEGIN_IDX + 1) |
| action_end = int(action_begin + n_bins) |
| return action_begin, action_end, n_bins |
| begin_override = getattr(self.config, "action_token_begin_idx", None) |
| if begin_override is not None: |
| action_begin = int(begin_override) |
| action_end = int(action_begin + n_bins) |
| return action_begin, action_end, n_bins |
|
|
| anchor = getattr(self.config, "action_vocab_anchor", "pad") |
| if anchor == "legacy": |
| action_begin = int(ACTION_TOKEN_BEGIN_IDX) |
| action_end = int(action_begin + n_bins) |
| return action_begin, action_end, n_bins |
| if anchor == "pad": |
| action_end = int(self.pad_token_id) |
| elif anchor == "vocab_size": |
| action_end = int(self.vocab_size) |
| else: |
| raise ValueError(f"Unknown action_vocab_anchor: {anchor}") |
| action_begin = int(action_end - n_bins) |
| return action_begin, action_end, n_bins |
|
|
| def _validate_action_vocab(self) -> None: |
| """Validate that special tokens do not overlap action bins.""" |
| if getattr(self.config, "legacy_eval_mode", False) or getattr(self.config, "legacy_train_mode", False): |
| return |
| if not hasattr(self.config, "n_action_bins") and not hasattr(self, "bin_centers"): |
| return |
| action_begin, action_end, _ = self._action_vocab_range() |
| if action_begin < 0: |
| raise ValueError(f"Action vocab begin ({action_begin}) is negative; check n_action_bins/pad_token_id.") |
| if action_begin <= self.pad_token_id < action_end: |
| raise ValueError( |
| f"pad_token_id ({self.pad_token_id}) overlaps action range [{action_begin}, {action_end})." |
| ) |
| if action_begin <= self.mask_token_id < action_end: |
| raise ValueError( |
| f"mask_token_id ({self.mask_token_id}) overlaps action range [{action_begin}, {action_end})." |
| ) |
|
|
def _process_action_masks(self, labels):
    """
    Build the combined (current | next) action-token mask from `labels`.

    In legacy eval/train mode the mask helpers are called with `labels` only; otherwise the
    action token range from `_action_vocab_range` is passed along as well.
    """
    cfg = self.config
    if getattr(cfg, "legacy_eval_mode", False) or getattr(cfg, "legacy_train_mode", False):
        range_args = ()
    else:
        action_begin, action_end, _ = self._action_vocab_range()
        range_args = (action_begin, action_end)

    current_mask = get_current_action_mask(labels, *range_args)
    next_mask = get_next_actions_mask(labels, *range_args)
    return current_mask | next_mask
|
|
| def _compute_language_embeddings(self, input_embeddings, attention_mask, all_actions_mask): |
| """Compute a safe language embedding summary for FiLM conditioning.""" |
| if attention_mask is None: |
| attention_mask = input_embeddings.new_ones(input_embeddings.shape[:2], dtype=torch.bool) |
| else: |
| attention_mask = attention_mask.to(dtype=torch.bool) |
|
|
| language_mask = attention_mask & (~all_actions_mask) |
| if not torch.any(language_mask): |
| return input_embeddings.new_zeros((input_embeddings.shape[0], 1, input_embeddings.shape[2])) |
|
|
| denom = language_mask.sum(dim=1).clamp(min=1).unsqueeze(-1) |
| summed = (input_embeddings * language_mask.unsqueeze(-1)).sum(dim=1) |
| return (summed / denom).unsqueeze(1) |
|
|
| def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False): |
| """Process vision features with optional FiLM conditioning""" |
| if use_film: |
| |
| patch_features = self.vision_backbone(pixel_values, language_embeddings) |
| else: |
| patch_features = self.vision_backbone(pixel_values) |
|
|
| |
| return self.projector(patch_features) |
|
|
| def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector): |
| """Process proprioceptive features and append to vision features""" |
| if proprio_projector is not None and proprio is not None: |
| |
| |
| proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1) |
| proprio_features = proprio_projector(proprio) |
| proprio_features = proprio_features.unsqueeze(dim=1) |
| |
| return torch.cat((projected_patch_embeddings, proprio_features), dim=1) |
| return projected_patch_embeddings |
|
|
| def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask): |
| """Build multimodal embeddings and attention mask""" |
| |
| projected_patch_attention_mask = None |
| if attention_mask is not None: |
| projected_patch_attention_mask = torch.full( |
| (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]), |
| fill_value=True, |
| dtype=attention_mask.dtype, |
| device=attention_mask.device, |
| ) |
|
|
| |
| multimodal_embeddings = torch.cat( |
| [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1 |
| ) |
|
|
| multimodal_attention_mask = None |
| if attention_mask is not None: |
| multimodal_attention_mask = torch.cat( |
| [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1 |
| ) |
|
|
| return multimodal_embeddings, multimodal_attention_mask |
|
|
| def _build_multimodal_labels(self, labels, projected_patch_embeddings): |
| """Build multimodal labels with IGNORE_INDEX for patch embeddings""" |
| if labels is not None: |
| projected_patch_labels = torch.full( |
| (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]), |
| fill_value=IGNORE_INDEX, |
| dtype=labels.dtype, |
| device=labels.device, |
| ) |
| return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1) |
| return None |
|
|
| @staticmethod |
| def _build_multimodal_token_ids(token_ids, projected_patch_embeddings, patch_fill_id: int = 0): |
| """Build multimodal token IDs with patch fill tokens inserted after BOS.""" |
| if token_ids is None: |
| return None |
| patch_len = projected_patch_embeddings.shape[1] if projected_patch_embeddings is not None else 0 |
| if patch_len > 0: |
| patch_ids = torch.full( |
| (token_ids.shape[0], patch_len), |
| fill_value=patch_fill_id, |
| dtype=token_ids.dtype, |
| device=token_ids.device, |
| ) |
| return torch.cat([token_ids[:, :1], patch_ids, token_ids[:, 1:]], dim=1) |
| return token_ids |
|
|
| @staticmethod |
| def _expand_mask_with_patches(mask, projected_patch_embeddings): |
| """Expand a (B, L) mask to include patch positions after BOS.""" |
| if mask is None: |
| return None |
| patch_len = projected_patch_embeddings.shape[1] if projected_patch_embeddings is not None else 0 |
| if patch_len > 0: |
| patch_mask = torch.zeros( |
| (mask.shape[0], patch_len), device=mask.device, dtype=mask.dtype |
| ) |
| return torch.cat([mask[:, :1], patch_mask, mask[:, 1:]], dim=1) |
| return mask |
|
|
| def _get_eos_pos(self, all_actions_mask): |
| """Prepare loss mask for discrete diffusion""" |
| loss_mask_full = all_actions_mask.clone() |
|
|
| |
| batch_size, seq_len = loss_mask_full.shape |
| |
| last_true = seq_len - 1 - torch.argmax( |
| loss_mask_full.flip(dims=[1]).int(), dim=1 |
| ) |
|
|
| |
| next_pos = last_true + 1 |
| |
| valid = next_pos < seq_len |
| batch_idx = torch.arange(batch_size, device=loss_mask_full.device)[valid] |
|
|
| |
| loss_mask_full[batch_idx, next_pos[valid]] = True |
| return next_pos[valid] |
|
|
def apply_mask_diffusion(
    self,
    input_ids: torch.LongTensor,
    input_embeddings: torch.Tensor,
    labels: torch.LongTensor,
    loss_mask_full: torch.BoolTensor,
    mask_token_id: int,
    eos_pos: Optional[torch.LongTensor] = None,
    no_mask_token_prob: float = 0.0,
):
    """
    Randomly replace a subset of the maskable positions with the mask token.

    Args:
        input_ids: Original token IDs, shape (B, L).
        input_embeddings: Original embeddings, one vector per position of `input_ids`.
        labels: Target token IDs, same shape as `input_ids`.
        loss_mask_full: Boolean (B, L) mask of positions that may be masked.
        mask_token_id: ID of the special mask token.
        eos_pos: Accepted for interface compatibility; not used by this implementation.
        no_mask_token_prob: Probability with which a selected position is un-masked again.

    Returns:
        masked_input_ids: `input_ids` with `mask_token_id` at the masked positions.
        masked_input_embeddings: `input_embeddings` with the mask-token embedding at the masked positions.
        masked_labels: `labels` at the masked positions, IGNORE_INDEX everywhere else.
        loss_mask: Float mask, 1 at masked positions and 0 elsewhere.
    """
    B, L = input_ids.shape
    device = input_ids.device

    # Number of maskable positions per sample.
    total_unknown = loss_mask_full.float().sum(dim=1)

    # One random time per sample, turned into a mask ratio by the schedule.
    # NOTE(review): `mask_schedule` is defined in prismatic.discrete_flow; assumed to return ratios in [0, 1] — confirm.
    rand_time = torch.rand(B, device=device)
    mask_ratios = mask_schedule(rand_time, total_unknown, method="cosine")

    # Number of positions to mask per sample; at least one is requested.
    num_mask = torch.clamp((total_unknown * mask_ratios).round(), min=1).long()

    # Random score per position; non-maskable positions get +inf so they sort last.
    vals = torch.rand(B, L, device=device)
    large = float("inf")
    vals = torch.where(loss_mask_full, vals, vals + large)

    # Rank of every position within its row (0 = smallest score).
    perm = vals.argsort(dim=1)
    ranks = perm.argsort(dim=1)

    # Select the `num_mask` lowest-ranked positions. The intersection with `loss_mask_full`
    # guarantees that a non-maskable position is never selected: without it, the min=1 clamp
    # above would mask an arbitrary token in rows that have no maskable position at all.
    masked_mask = (ranks < num_mask[:, None]) & loss_mask_full

    if no_mask_token_prob > 0:
        # Randomly un-mask some of the selected positions.
        prob = torch.rand(B, L, device=device)
        unmask = (prob < no_mask_token_prob) & masked_mask
        masked_mask = masked_mask & (~unmask)

    # Labels are kept only where the input was masked.
    ignore_labels = torch.full_like(labels, fill_value=IGNORE_INDEX, dtype=labels.dtype, device=device)
    masked_labels = torch.where(masked_mask, labels, ignore_labels)

    masked_input_ids = torch.where(masked_mask, mask_token_id, input_ids)

    # Swap in the mask-token embedding at the masked positions.
    masked_input_embeddings = input_embeddings.clone()
    mask_emb = self.get_input_embeddings()(torch.tensor([mask_token_id], device=device))
    mask_emb = mask_emb.view(1, 1, -1).expand(B, L, -1)
    masked_input_embeddings = torch.where(masked_mask.unsqueeze(-1), mask_emb, masked_input_embeddings)

    loss_mask = masked_mask.float()

    return masked_input_ids, masked_input_embeddings, masked_labels, loss_mask
|
|
| @staticmethod |
| def _sample_mixture_mask(loss_mask_full: torch.BoolTensor, kappa_t: torch.Tensor) -> torch.BoolTensor: |
| """Sample per-coordinate Bernoulli mask for mixture path corruption.""" |
| p_mask = (1.0 - kappa_t).view(-1, 1) |
| rand = torch.rand(loss_mask_full.shape, device=loss_mask_full.device) |
| return (rand < p_mask) & loss_mask_full |
|
|
def apply_mask_flow_matching(
    self,
    input_ids: torch.LongTensor,
    input_embeddings: torch.Tensor,
    labels: torch.LongTensor,
    loss_mask_full: torch.BoolTensor,
    mask_token_id: int,
    schedule: str = "cosine",
    time_eps: float = 1e-3,
    t_min: float = 0.0,
    t_max: float = 1.0,
    t_bias_alpha: float = 1.0,
):
    """
    Corrupt the maskable positions with mask tokens along a DFM mixture path.

    A time `t` is drawn per sample from [max(t_min, time_eps), min(t_max, 1 - time_eps)]
    (a uniform draw, optionally raised to the power `t_bias_alpha`), and every maskable
    position is masked independently with probability `1 - kappa(t)`.

    Args:
        input_ids: Original token IDs, shape (B, L).
        input_embeddings: Original embeddings, one vector per position of `input_ids`.
        labels: Target token IDs, same shape as `input_ids`.
        loss_mask_full: Boolean (B, L) mask of positions that may be masked.
        mask_token_id: ID of the special mask token.
        schedule: Schedule name forwarded to `kappa` / `kappa_dot`.
        time_eps: Margin kept from both ends of the time range.
        t_min: Lower bound for the sampled time.
        t_max: Upper bound for the sampled time.
        t_bias_alpha: Exponent applied to the uniform draw; None or 1.0 leaves it unchanged.

    Returns:
        (masked_input_ids, masked_input_embeddings, masked_labels, loss_mask, kappa_t, kdot_t, t)

    Raises:
        ValueError: If the clamped time range is empty.
    """
    batch, seq_len = input_ids.shape
    dev = input_ids.device

    # Clamp the time range away from 0 and 1.
    lo = max(t_min, time_eps)
    hi = min(t_max, 1.0 - time_eps)
    if hi <= lo:
        raise ValueError("Invalid DFM time range after applying eps clamp.")

    # Sample one time per sample, with optional power-law bias.
    uniform = torch.rand(batch, device=dev)
    needs_bias = t_bias_alpha is not None and t_bias_alpha != 1.0
    if needs_bias:
        uniform = torch.pow(uniform, t_bias_alpha)
    t = uniform * (hi - lo) + lo

    kappa_t = kappa(t, schedule=schedule)
    kdot_t = kappa_dot(t, schedule=schedule)

    masked_mask = self._sample_mixture_mask(loss_mask_full, kappa_t)

    # Labels survive only where the input was masked.
    ignored = torch.full_like(labels, fill_value=IGNORE_INDEX, dtype=labels.dtype, device=dev)
    masked_labels = torch.where(masked_mask, labels, ignored)
    masked_input_ids = torch.where(masked_mask, mask_token_id, input_ids)

    # Broadcast the mask-token embedding over the sequence and swap it in at masked positions.
    mask_token = torch.tensor([mask_token_id], device=dev)
    mask_emb = self.get_input_embeddings()(mask_token).view(1, 1, -1).expand(batch, seq_len, -1)
    masked_input_embeddings = torch.where(masked_mask.unsqueeze(-1), mask_emb, input_embeddings)

    loss_mask = masked_mask.float()
    return masked_input_ids, masked_input_embeddings, masked_labels, loss_mask, kappa_t, kdot_t, t
|
|
| @staticmethod |
| def _dfm_generalized_kl_loss( |
| *, |
| shift_logits: torch.Tensor, |
| x1: torch.Tensor, |
| xt: torch.Tensor, |
| action_mask: torch.Tensor, |
| kappa_t: torch.Tensor, |
| kdot_t: torch.Tensor, |
| action_begin: int, |
| action_end: int, |
| mask_id: int, |
| weight_clip: float = 20.0, |
| eps: float = 1e-8, |
| ) -> torch.Tensor: |
| """Generalized-KL loss for mixture path on action coordinates.""" |
| B, _, _ = shift_logits.shape |
| am = action_mask.bool() |
|
|
| action_ids = torch.arange(action_begin, action_end, device=shift_logits.device) |
| allowed_ids = torch.cat([action_ids, torch.tensor([mask_id], device=shift_logits.device)], dim=0) |
| K = allowed_ids.numel() |
|
|
| logits = shift_logits.index_select(dim=-1, index=allowed_ids) |
|
|
| x1_safe = torch.where(am, x1, torch.full_like(x1, action_begin)) |
| xt_safe = torch.where(am, xt, torch.full_like(xt, action_begin)) |
|
|
| x1_idx = x1_safe - action_begin |
| xt_idx = torch.where(xt_safe == mask_id, torch.full_like(xt_safe, K - 1), xt_safe - action_begin) |
|
|
| log_p = torch.log_softmax(logits, dim=-1) |
| log_p_x1 = log_p.gather(-1, x1_idx.unsqueeze(-1)).squeeze(-1) |
| log_p_xt = log_p.gather(-1, xt_idx.unsqueeze(-1)).squeeze(-1) |
| p_xt = torch.exp(log_p_xt) |
|
|
| delta = (xt_safe == x1_safe).to(log_p.dtype) |
|
|
| denom = (1.0 - kappa_t).clamp(min=eps) |
| w = (kdot_t / denom).clamp(min=0.0, max=weight_clip).view(B, 1) |
|
|
| loss_pos = -w * (p_xt - delta + (1.0 - delta) * log_p_x1) |
|
|
| return (loss_pos * am).sum() / am.sum().clamp(min=1.0) |
|
|
| |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| pixel_values: Optional[torch.FloatTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| past_key_values: Optional[List[torch.FloatTensor]] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| output_projector_features: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| proprio=None, |
| proprio_projector=None, |
| noisy_actions=None, |
| noisy_action_projector=None, |
| diffusion_timestep_embeddings=None, |
| use_film: bool = False, |
| dfm_schedule: Optional[str] = None, |
| dfm_time_eps: float = 1e-3, |
| dfm_t_min: float = 0.0, |
| dfm_t_max: float = 1.0, |
| dfm_loss_mode: Optional[str] = None, |
| dfm_weight_clip: float = 20.0, |
| dfm_train_mode: Optional[str] = None, |
| dfm_t_bias_alpha: Optional[float] = None, |
| dfm_log_mask_stats: bool = False, |
| ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]: |
| """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance.""" |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| output_projector_features = output_projector_features if output_projector_features is not None else False |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| |
| use_cache = use_cache and not self.training |
|
|
| |
| projected_patch_embeddings = None |
| dfm_loss_mask = None |
| dfm_weight = None |
| multimodal_labels = None |
| dfm_stats = None |
| dfm_trace = None |
| dfm_action_token_count = None |
| dfm_t = None |
| multimodal_x1_labels = None |
| multimodal_xt_ids = None |
| multimodal_actions_mask = None |
|
|
| |
| dfm_schedule = dfm_schedule or getattr(self.config, "dfm_schedule", "cosine") |
| dfm_loss_mode = dfm_loss_mode or getattr(self.config, "dfm_loss_mode", "generalized_kl") |
| dfm_train_mode = dfm_train_mode or "flow" |
| if dfm_t_bias_alpha is None: |
| dfm_t_bias_alpha = getattr(self.config, "dfm_t_bias_alpha", 1.0) |
|
|
| |
| if input_ids.shape[1] == 1: |
| assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!" |
| assert past_key_values is not None, "You must provide `past_key_values` during cached generation!" |
| assert labels is None, "Unexpected key `labels` provided during cached generation!" |
|
|
| language_model_output = self.language_model( |
| input_ids=input_ids, |
| attention_mask=None, |
| position_ids=None, |
| past_key_values=past_key_values, |
| inputs_embeds=None, |
| labels=None, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| |
| elif pixel_values is None: |
| assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!" |
| assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!" |
|
|
| language_model_output = self.language_model( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| position_ids=None, |
| past_key_values=None, |
| inputs_embeds=None, |
| labels=labels, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| |
| elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]): |
| assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!" |
|
|
| |
| input_embeddings = self.get_input_embeddings()(input_ids) |
|
|
| |
| all_actions_mask = self._process_action_masks(labels) |
|
|
| |
| language_embeddings = None |
| if use_film: |
| language_embeddings = self._compute_language_embeddings( |
| input_embeddings, attention_mask, all_actions_mask |
| ) |
|
|
| |
| projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film) |
|
|
| |
| projected_patch_embeddings = self._process_proprio_features( |
| projected_patch_embeddings, proprio, proprio_projector |
| ) |
|
|
| |
| if diffusion_timestep_embeddings is not None: |
| |
| projected_patch_embeddings = torch.cat( |
| (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1 |
| ) |
|
|
| |
| dfm_loss_mask = None |
| dfm_weight = None |
|
|
| if noisy_actions is not None: |
| |
| all_actions_mask = self._process_action_masks(labels) |
|
|
| |
| |
| B = noisy_actions.shape[0] |
| noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1) |
|
|
| |
| noisy_action_features = noisy_action_projector(noisy_actions) |
|
|
| |
| input_embeddings = self._replace_input_embeddings( |
| input_embeddings, all_actions_mask, noisy_action_features |
| ) |
| elif self.use_discrete_diffusion: |
| |
| loss_mask_full = all_actions_mask |
| eos_pos = self._get_eos_pos(all_actions_mask) |
| input_ids, input_embeddings, labels, loss_mask = self.apply_mask_diffusion( |
| input_ids=input_ids, |
| input_embeddings=input_embeddings, |
| labels=labels, |
| loss_mask_full=loss_mask_full, |
| mask_token_id=self.mask_token_id, |
| no_mask_token_prob=0.0, |
| ) |
|
|
| |
| labels[torch.arange(labels.shape[0]), eos_pos] = STOP_INDEX |
|
|
| elif self.use_discrete_flow_matching: |
| labels_x1 = labels.clone() if labels is not None else None |
| loss_mask_full = all_actions_mask |
| dfm_action_token_count = loss_mask_full.sum(dim=1) |
| if dfm_train_mode == "diffusion_like": |
| eos_pos = self._get_eos_pos(all_actions_mask) |
| ( |
| input_ids, |
| input_embeddings, |
| labels, |
| dfm_loss_mask, |
| ) = self.apply_mask_diffusion( |
| input_ids=input_ids, |
| input_embeddings=input_embeddings, |
| labels=labels, |
| loss_mask_full=loss_mask_full, |
| mask_token_id=self.mask_token_id, |
| no_mask_token_prob=0.0, |
| ) |
| mask_ratio = dfm_loss_mask.sum(dim=1) / dfm_action_token_count.clamp(min=1.0) |
| kappa_t = (1.0 - mask_ratio).clamp(min=0.0, max=1.0) |
| kdot_t = torch.zeros_like(kappa_t) |
| dfm_weight = torch.ones_like(kappa_t) |
| dfm_t = None |
| dfm_loss_mode = "masked_ce" |
| |
| labels[torch.arange(labels.shape[0]), eos_pos] = STOP_INDEX |
| dfm_loss_mask[torch.arange(dfm_loss_mask.shape[0]), eos_pos] = 1.0 |
| else: |
| ( |
| input_ids, |
| input_embeddings, |
| labels, |
| dfm_loss_mask, |
| kappa_t, |
| kdot_t, |
| dfm_t, |
| ) = self.apply_mask_flow_matching( |
| input_ids=input_ids, |
| input_embeddings=input_embeddings, |
| labels=labels, |
| loss_mask_full=loss_mask_full, |
| mask_token_id=self.mask_token_id, |
| schedule=dfm_schedule, |
| time_eps=dfm_time_eps, |
| t_min=dfm_t_min, |
| t_max=dfm_t_max, |
| t_bias_alpha=dfm_t_bias_alpha, |
| ) |
| denom = (1.0 - kappa_t).clamp(min=1e-8) |
| dfm_weight = (kdot_t / denom).clamp(min=0.0, max=dfm_weight_clip) |
| mask_ratio = (1.0 - kappa_t).clamp(min=0.0, max=1.0) |
| if os.environ.get("VLA_DFM_DEBUG", "0") == "1": |
| self._dfm_debug_step = getattr(self, "_dfm_debug_step", 0) + 1 |
| debug_every = int(os.environ.get("VLA_DFM_DEBUG_EVERY", "200")) |
| if self._dfm_debug_step % debug_every == 0: |
| is_rank0 = (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == 0 |
| if is_rank0: |
| with torch.no_grad(): |
| supervised_mask = labels != IGNORE_INDEX |
| action_mask = loss_mask_full |
| supervised_in_action = supervised_mask & action_mask |
| supervised_outside_action = supervised_mask & (~action_mask) |
| action_count = action_mask.sum().item() |
| supervised_count = supervised_in_action.sum().item() |
| outside_count = supervised_outside_action.sum().item() |
| supervised_frac = supervised_count / max(action_count, 1) |
| per_batch_frac = ( |
| dfm_loss_mask.sum(dim=1) / dfm_action_token_count.clamp(min=1.0) |
| ).mean().item() |
| t_mean = dfm_t.mean().item() if dfm_t is not None else float("nan") |
| t_min = dfm_t.min().item() if dfm_t is not None else float("nan") |
| t_max = dfm_t.max().item() if dfm_t is not None else float("nan") |
| mr_mean = mask_ratio.mean().item() |
| mr_min = mask_ratio.min().item() |
| mr_max = mask_ratio.max().item() |
| w_mean = dfm_weight.mean().item() |
| w_min = dfm_weight.min().item() |
| w_max = dfm_weight.max().item() |
| logger.info( |
| "[DFM DEBUG] supervised_action_tokens=%d/%d (%.4f) " |
| "per_batch_supervised_frac=%.4f supervised_outside_action=%d " |
| "dfm_loss_mask_sum=%d t_mean=%.4f t_min=%.4f t_max=%.4f " |
| "mask_ratio_mean=%.4f mask_ratio_min=%.4f mask_ratio_max=%.4f " |
| "w_mean=%.4f w_min=%.4f w_max=%.4f", |
| supervised_count, |
| action_count, |
| supervised_frac, |
| per_batch_frac, |
| outside_count, |
| dfm_loss_mask.sum().item(), |
| t_mean, |
| t_min, |
| t_max, |
| mr_mean, |
| mr_min, |
| mr_max, |
| w_mean, |
| w_min, |
| w_max, |
| ) |
| if supervised_count == 0: |
| logger.warning( |
| "[DFM DEBUG] No supervised action tokens (all -100). Check masking/labels." |
| ) |
| if outside_count > 0: |
| logger.warning( |
| "[DFM DEBUG] Found supervised tokens outside action mask. Check loss_mask_full/labels." |
| ) |
|
|
| else: |
| |
| |
| all_actions_mask = all_actions_mask.unsqueeze(-1) |
| input_embeddings = input_embeddings * ~all_actions_mask |
|
|
| |
| multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention( |
| input_embeddings, projected_patch_embeddings, attention_mask |
| ) |
|
|
| |
| multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings) |
| if self.use_discrete_flow_matching and labels_x1 is not None: |
| multimodal_x1_labels = self._build_multimodal_labels(labels_x1, projected_patch_embeddings) |
| multimodal_xt_ids = self._build_multimodal_token_ids( |
| input_ids, projected_patch_embeddings, patch_fill_id=0 |
| ) |
| multimodal_actions_mask = self._expand_mask_with_patches(all_actions_mask, projected_patch_embeddings) |
| if os.environ.get("VLA_DFM_DEBUG", "0") == "1": |
| if multimodal_x1_labels is not None and multimodal_xt_ids is not None: |
| assert ( |
| multimodal_x1_labels.shape == multimodal_xt_ids.shape == multimodal_actions_mask.shape |
| ), "DFM multimodal tensors are misaligned." |
|
|
| |
| language_model_output = self.language_model( |
| input_ids=None, |
| attention_mask=multimodal_attention_mask, |
| position_ids=None, |
| past_key_values=None, |
| inputs_embeds=multimodal_embeddings, |
| labels=multimodal_labels, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| |
| elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]): |
| raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!") |
|
|
| else: |
| raise ValueError( |
| "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n" |
| f"=> `input_ids` = {input_ids is not None}\n" |
| f"=> `attention_mask` = {attention_mask is not None}\n" |
| f"=> `pixel_values` = {pixel_values is not None}\n" |
| f"=> `labels` = {labels is not None}\n" |
| f"=> `input_embeds` = {inputs_embeds is not None}\n" |
| f"=> `past_key_values` = {past_key_values is not None}\n" |
| f"=> `use_cache` = {use_cache}" |
| ) |
|
|
| |
| if not return_dict: |
| if output_projector_features and (projected_patch_embeddings is not None): |
| return *language_model_output, projected_patch_embeddings |
|
|
| return language_model_output |
|
|
| |
| lm_loss = language_model_output.loss |
| if self.use_discrete_flow_matching and (dfm_loss_mask is not None): |
| |
| patch_len = projected_patch_embeddings.shape[1] if projected_patch_embeddings is not None else 0 |
| if patch_len > 0: |
| patch_mask = torch.zeros( |
| (dfm_loss_mask.shape[0], patch_len), device=dfm_loss_mask.device, dtype=dfm_loss_mask.dtype |
| ) |
| multimodal_loss_mask = torch.cat([dfm_loss_mask[:, :1], patch_mask, dfm_loss_mask[:, 1:]], dim=1) |
| else: |
| multimodal_loss_mask = dfm_loss_mask |
|
|
| skip_dfm_loss_override = dfm_train_mode == "diffusion_like" |
| if not skip_dfm_loss_override: |
| logits = language_model_output.logits |
| |
| shift_logits = logits[:, :-1, :] |
| if dfm_loss_mode == "generalized_kl": |
| if multimodal_x1_labels is None or multimodal_xt_ids is None or multimodal_actions_mask is None: |
| raise ValueError("Missing multimodal bookkeeping for generalized_kl loss.") |
| x1 = multimodal_x1_labels[:, 1:] |
| xt = multimodal_xt_ids[:, 1:] |
| am = multimodal_actions_mask[:, 1:] |
| action_begin, action_end, _ = self._action_vocab_range() |
| if action_begin <= self.mask_token_id < action_end: |
| raise ValueError( |
| "mask_token_id overlaps action vocab range; generalized_kl requires distinct mask token." |
| ) |
| lm_loss = self._dfm_generalized_kl_loss( |
| shift_logits=shift_logits, |
| x1=x1, |
| xt=xt, |
| action_mask=am, |
| kappa_t=kappa_t, |
| kdot_t=kdot_t, |
| action_begin=action_begin, |
| action_end=action_end, |
| mask_id=self.mask_token_id, |
| weight_clip=dfm_weight_clip, |
| ) |
| else: |
| vocab_size = logits.shape[-1] |
| shift_labels = multimodal_labels[:, 1:] |
| shift_loss_mask = multimodal_loss_mask[:, 1:] |
|
|
| token_losses = torch.nn.functional.cross_entropy( |
| shift_logits.reshape(-1, vocab_size), |
| shift_labels.reshape(-1), |
| reduction="none", |
| ignore_index=IGNORE_INDEX, |
| ).view(shift_labels.shape) |
|
|
| if dfm_loss_mode == "masked_ce": |
| weight = torch.ones_like(shift_loss_mask) |
| else: |
| weight = dfm_weight.view(-1, 1).expand_as(shift_loss_mask) |
|
|
| masked_loss = token_losses * shift_loss_mask * weight |
| denom = (shift_loss_mask * weight).sum().clamp(min=1.0) |
| lm_loss = masked_loss.sum() / denom |
|
|
| |
| with torch.no_grad(): |
| if dfm_action_token_count is None: |
| mask_frac = multimodal_loss_mask.sum(dim=1) / multimodal_loss_mask.shape[1] |
| else: |
| mask_frac = dfm_loss_mask.sum(dim=1) / dfm_action_token_count.clamp(min=1.0) |
| mask_ratio = (1.0 - kappa_t).clamp(min=0.0, max=1.0) |
| if dfm_log_mask_stats: |
| dfm_t_trace = dfm_t if dfm_t is not None else torch.full_like(kappa_t, float("nan")) |
| dfm_trace = { |
| "t": dfm_t_trace.detach(), |
| "kappa": kappa_t.detach(), |
| "mask_frac": mask_frac.detach(), |
| } |
| dfm_stats = { |
| "kappa_mean": kappa_t.mean().detach(), |
| "t_mean": dfm_t.mean().detach() if dfm_t is not None else torch.tensor(float("nan")), |
| "t_min": dfm_t.min().detach() if dfm_t is not None else torch.tensor(float("nan")), |
| "t_max": dfm_t.max().detach() if dfm_t is not None else torch.tensor(float("nan")), |
| "mask_ratio_mean": mask_ratio.mean().detach(), |
| "mask_ratio_min": mask_ratio.min().detach(), |
| "mask_ratio_max": mask_ratio.max().detach(), |
| "mask_frac_mean": mask_frac.mean().detach(), |
| "w_mean": dfm_weight.mean().detach(), |
| "w_min": dfm_weight.min().detach(), |
| "w_max": dfm_weight.max().detach(), |
| "frac_w_clipped": (dfm_weight >= dfm_weight_clip).float().mean().detach(), |
| "num_supervised_tokens": dfm_loss_mask.sum().detach(), |
| } |
| if dfm_loss_mode == "generalized_kl" and multimodal_x1_labels is not None: |
| shift_logits = language_model_output.logits[:, :-1, :] |
| shift_x1 = multimodal_x1_labels[:, 1:] |
| shift_xt = multimodal_xt_ids[:, 1:] |
| shift_am = multimodal_actions_mask[:, 1:].bool() |
| action_begin, action_end, _ = self._action_vocab_range() |
| action_ids = torch.arange(action_begin, action_end, device=shift_logits.device) |
| allowed_ids = torch.cat( |
| [action_ids, torch.tensor([self.mask_token_id], device=shift_logits.device)], dim=0 |
| ) |
| K = allowed_ids.numel() |
| logits = shift_logits.index_select(dim=-1, index=allowed_ids) |
| log_p = torch.log_softmax(logits, dim=-1) |
| x1_idx = torch.where( |
| shift_am, shift_x1 - action_begin, torch.zeros_like(shift_x1) |
| ) |
| xt_idx = torch.where( |
| shift_am, |
| torch.where( |
| shift_xt == self.mask_token_id, |
| torch.full_like(shift_xt, K - 1), |
| shift_xt - action_begin, |
| ), |
| torch.zeros_like(shift_xt), |
| ) |
| log_p_x1 = log_p.gather(-1, x1_idx.unsqueeze(-1)).squeeze(-1) |
| log_p_xt = log_p.gather(-1, xt_idx.unsqueeze(-1)).squeeze(-1) |
| p_xt = torch.exp(log_p_xt) |
| masked_xt = shift_xt == self.mask_token_id |
| denom = shift_am.sum().clamp(min=1.0) |
| frac_same = ((shift_xt == shift_x1) & shift_am).sum().float() / denom |
| mean_logp_x1_masked = ( |
| log_p_x1[shift_am & masked_xt].mean() if (shift_am & masked_xt).any() else torch.tensor(0.0) |
| ) |
| mean_p_xt_masked = ( |
| p_xt[shift_am & masked_xt].mean() if (shift_am & masked_xt).any() else torch.tensor(0.0) |
| ) |
| jump_coeff = (kdot_t / (1.0 - kappa_t).clamp(min=1e-8)).clamp(max=dfm_weight_clip) |
| dfm_stats.update( |
| { |
| "frac_same": frac_same.detach(), |
| "jump_coeff_mean": jump_coeff.mean().detach(), |
| "jump_coeff_p95": torch.quantile(jump_coeff, 0.95).detach(), |
| "mean_logp_x1_masked": mean_logp_x1_masked.detach(), |
| "mean_p_xt_masked": mean_p_xt_masked.detach(), |
| } |
| ) |
|
|
| |
| if os.environ.get("VLA_DFM_DEBUG", "0") == "1": |
| debug_every = int(os.environ.get("VLA_DFM_DEBUG_EVERY", "200")) |
| debug_step = getattr(self, "_dfm_debug_step", 0) |
| if debug_step % debug_every == 0: |
| is_rank0 = (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == 0 |
| if is_rank0: |
| with torch.no_grad(): |
| hf_loss = language_model_output.loss.detach() |
| if skip_dfm_loss_override: |
| |
| logits = language_model_output.logits |
| vocab_size = logits.shape[-1] |
| shift_logits = logits[:, :-1, :] |
| shift_labels = multimodal_labels[:, 1:] |
| shift_loss_mask = multimodal_loss_mask[:, 1:] |
| token_losses = torch.nn.functional.cross_entropy( |
| shift_logits.reshape(-1, vocab_size), |
| shift_labels.reshape(-1), |
| reduction="none", |
| ignore_index=IGNORE_INDEX, |
| ).view(shift_labels.shape) |
| if dfm_loss_mode == "masked_ce": |
| weight = torch.ones_like(shift_loss_mask) |
| else: |
| weight = dfm_weight.view(-1, 1).expand_as(shift_loss_mask) |
| masked_loss = token_losses * shift_loss_mask * weight |
| denom = (shift_loss_mask * weight).sum().clamp(min=1.0) |
| dfm_loss_check = masked_loss.sum() / denom |
| else: |
| dfm_loss_check = lm_loss.detach() |
| logger.info( |
| "[DFM DEBUG] hf_loss=%.6f dfm_loss_check=%.6f skip_dfm_loss_override=%s", |
| hf_loss.item(), |
| dfm_loss_check.item(), |
| str(skip_dfm_loss_override), |
| ) |
|
|
| return PrismaticCausalLMOutputWithPast( |
| loss=lm_loss, |
| logits=language_model_output.logits, |
| past_key_values=language_model_output.past_key_values, |
| hidden_states=language_model_output.hidden_states, |
| attentions=language_model_output.attentions, |
| projector_features=projected_patch_embeddings, |
| labels=labels if labels is not None else None, |
| dfm_stats=dfm_stats, |
| dfm_trace=dfm_trace, |
| ) |
|
|
| |
def prepare_inputs_for_generation(
    self,
    input_ids: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic.

    Builds the keyword arguments for one generation step.

    Args:
        input_ids: Token ids; when a cache is supplied only the last column is kept.
        past_key_values: Cache from earlier steps, or None on the first step.
        inputs_embeds: Precomputed embeddings; only used when no cache is supplied.
        pixel_values: Image tensor, passed through unchanged.
        attention_mask: Attention mask, passed through unchanged.
        **kwargs: Extra generation arguments; only `use_cache` is read.

    Returns:
        Dict with either `inputs_embeds` or `input_ids`, plus `attention_mask`,
        `pixel_values`, `past_key_values` and `use_cache`.

    Raises:
        ValueError: If `input_ids` or `inputs_embeds` has a batch dimension larger than 1.
    """
    if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
        (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
    ):
        raise ValueError("Generation with batch size > 1 is not currently supported!")

    # With a cache present, only the most recent token still needs to be processed.
    if past_key_values is not None:
        input_ids = input_ids[:, -1:]

    # Embeddings are only used on the first (cache-less) step. The key must be `inputs_embeds`
    # -- the name the forward pass reads -- not the previous, misspelled `input_embeds`.
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    # Carry the image, mask and cache settings along with the token inputs.
    model_inputs.update(
        {
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
        }
    )

    return model_inputs
|
|
| |
| def _reorder_cache(self, *args, **kwargs) -> Any: |
| return self.language_model._reorder_cache(*args, **kwargs) |
|
|
|
|
| class OpenVLAForActionPrediction(PrismaticForConditionalGeneration): |
| config_class: PretrainedConfig = OpenVLAConfig |
|
|
def __init__(self, config: OpenVLAConfig) -> None:
    """Set up normalization statistics, action discretization bins, and vocabulary bookkeeping.

    Args:
        config: Model configuration; `norm_stats`, `n_action_bins`, `text_config.vocab_size`
            and `pad_to_multiple_of` are read, along with the optional `legacy_train_mode`,
            `legacy_eval_mode` and `topk_filter_thres` attributes.
    """
    super().__init__(config)
    self.norm_stats = config.norm_stats

    # Legacy modes place `n_action_bins` edges on [-1, 1]; otherwise one extra edge is used,
    # which yields `n_action_bins` bin centers.
    uses_legacy_bins = any(getattr(config, flag, False) for flag in ("legacy_train_mode", "legacy_eval_mode"))
    num_edges = config.n_action_bins if uses_legacy_bins else config.n_action_bins + 1
    self.bins = np.linspace(-1, 1, num_edges)
    self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0

    # Optional top-k filtering threshold; 0.0 when the config does not define it.
    self.topk_filter_thres = getattr(config, "topk_filter_thres", 0.0)

    # Vocabulary size of the text model with `pad_to_multiple_of` subtracted.
    self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
|
|
def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
    """Append action placeholder tokens and a stop token to the prompt, and widen the mask to match.

    Args:
        input_ids: Prompt token ids; indexed as (batch, sequence).
        attention_mask: Attention mask for `input_ids`.

    Returns:
        Tuple `(input_ids, attention_mask)` where `input_ids` gained
        `ACTION_DIM * NUM_ACTIONS_CHUNK` placeholder columns (value 1) followed by one
        `STOP_INDEX` column, and `attention_mask` was padded with ones to the same width.
    """
    batch_size = input_ids.shape[0]

    # Placeholder slots for the action tokens, sharing dtype/device with `input_ids`.
    action_placeholders = input_ids.new_ones((batch_size, ACTION_DIM * NUM_ACTIONS_CHUNK))
    # A single trailing stop token per sequence.
    stop_column = input_ids.new_ones((batch_size, 1)) * STOP_INDEX
    input_ids = torch.cat([input_ids, action_placeholders, stop_column], dim=-1)

    # Every appended position is attended to, so the mask is extended with ones.
    pad_width = input_ids.shape[-1] - attention_mask.shape[-1]
    mask_tail = attention_mask.new_ones((attention_mask.shape[0], pad_width))
    attention_mask = torch.cat([attention_mask, mask_tail], dim=-1)

    return input_ids, attention_mask
|
|
def _prepare_labels_for_action_prediction(self, labels, input_ids):
    """Pad `labels` to the width of `input_ids` with a filler action token and end with `STOP_INDEX`.

    Args:
        labels: Existing label tensor; indexed as (batch, sequence).
        input_ids: Already extended token ids whose last dimension sets the target width.

    Returns:
        A new label tensor as wide as `input_ids`; the added columns hold a filler action
        token id and the final column is `STOP_INDEX`.
    """
    legacy_mode = getattr(self.config, "legacy_eval_mode", False) or getattr(
        self.config, "legacy_train_mode", False
    )
    # The filler value only has to be some action token id: legacy modes use
    # `ACTION_TOKEN_BEGIN_IDX + 1`, otherwise the first id of the action vocab range.
    if legacy_mode:
        filler_token_id = ACTION_TOKEN_BEGIN_IDX + 1
    else:
        filler_token_id, _, _ = self._action_vocab_range()

    missing_columns = input_ids.shape[-1] - labels.shape[-1]
    filler = labels.new_ones((labels.shape[0], missing_columns)) * filler_token_id
    labels = torch.cat([labels, filler], dim=-1)

    # The last position is always the stop token.
    labels[:, -1] = STOP_INDEX

    return labels
|
|
def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
    """Rescale normalized actions back to the range given by the dataset statistics.

    Args:
        normalized_actions: Array of normalized actions.
        unnorm_key: Key passed to `self.get_action_stats` to select the statistics.

    Returns:
        Array where masked dimensions are mapped to
        `0.5 * (a + 1) * (high - low + 1e-8) + low` and unmasked dimensions are unchanged.

    Raises:
        ValueError: If `ACTION_PROPRIO_NORMALIZATION_TYPE` is neither `BOUNDS` nor `BOUNDS_Q99`.
    """
    stats = self.get_action_stats(unnorm_key)

    # Choose which statistics act as the lower / upper bound of the action range.
    if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
        low_key, high_key = "min", "max"
    elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
        low_key, high_key = "q01", "q99"
    else:
        raise ValueError("Unsupported action/proprio normalization type detected!")

    # Without an explicit mask every dimension is un-normalized.
    mask = stats.get("mask", np.ones_like(stats[low_key], dtype=bool))
    action_high = np.array(stats[high_key])
    action_low = np.array(stats[low_key])

    rescaled = 0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low
    return np.where(mask, rescaled, normalized_actions)
|
|
def _run_diffusion_prediction(
    self,
    input_embeddings,
    all_actions_mask,
    noise,
    action_head,
    projected_patch_embeddings,
    labels,
    attention_mask,
    NUM_PATCHES,
    NUM_PROMPT_TOKENS,
    noisy_action_projector,
):
    """Denoise `noise` into an action chunk by stepping through the scheduler's timesteps.

    Each step appends a timestep embedding to the patch embeddings, writes the projected
    noisy actions into the action positions of the input embeddings, runs the language
    model, and feeds the action-position hidden states to `action_head.predict_noise`
    before calling `action_head.noise_scheduler.step`.

    Returns:
        Tuple of the final sample, reshaped to `(NUM_ACTIONS_CHUNK, ACTION_DIM)` and
        converted to a float NumPy array, and the action hidden states of the last step.
    """
    # Untouched copy of the patch embeddings; a fresh timestep embedding is appended per step.
    patch_embeddings_without_time = projected_patch_embeddings.clone()
    sample = noise
    scheduler = action_head.noise_scheduler

    # Positions of the action tokens in the multimodal sequence.
    action_start = NUM_PATCHES + NUM_PROMPT_TOKENS
    action_stop = action_start + ACTION_DIM * NUM_ACTIONS_CHUNK

    for t in scheduler.timesteps:
        # Encode the current timestep and append it after the patch embeddings.
        step_input = torch.Tensor([t]).to(labels.device)
        time_embedding = action_head.time_encoder(step_input)
        time_embedding = time_embedding.to(sample.dtype).to(sample.device).unsqueeze(1)
        projected_patch_embeddings = torch.cat((patch_embeddings_without_time, time_embedding), dim=1)

        # Flatten the sample to one scalar per row for the projector, then restore its shape.
        sample_shape = sample.shape
        flat_sample = sample.reshape(sample.shape[0], -1).unsqueeze(-1)
        noisy_action_features = noisy_action_projector(flat_sample)
        sample = flat_sample.reshape(sample_shape)

        # Write the projected noisy actions into the action positions.
        input_embeddings = self._replace_input_embeddings(
            input_embeddings.clone(), all_actions_mask, noisy_action_features
        )

        multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
            input_embeddings, projected_patch_embeddings, attention_mask
        )

        language_model_output = self.language_model(
            input_ids=None,
            attention_mask=multimodal_attention_mask,
            position_ids=None,
            past_key_values=None,
            inputs_embeds=multimodal_embeddings,
            labels=None,
            use_cache=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )

        # Last-layer hidden states at the action positions.
        final_layer = language_model_output.hidden_states[-1]
        actions_hidden_states = final_layer[:, action_start:action_stop, :]

        # Predict the noise and take one scheduler step.
        noise_pred = action_head.predict_noise(actions_hidden_states)
        sample = scheduler.step(noise_pred, t, sample).prev_sample

    sample = sample.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)

    return sample.float().cpu().detach().numpy(), actions_hidden_states
|
|
def _regression_or_discrete_prediction(
    self,
    input_embeddings,
    all_actions_mask,
    projected_patch_embeddings,
    attention_mask,
    labels,
    NUM_PATCHES,
    NUM_PROMPT_TOKENS,
    action_head=None,
):
    """Predict an action chunk in one forward pass, via `action_head` or via discrete tokens.

    The embeddings at action positions are zeroed before the language model runs. With an
    `action_head`, its `predict_action` output is used; without one, the argmax token id at
    each action position is mapped to a bin center.

    Returns:
        Tuple of normalized actions (NumPy array reshaped to `(NUM_ACTIONS_CHUNK, ACTION_DIM)`)
        and the last-layer hidden states at the action positions.
    """
    # Zero the embeddings wherever the action mask is set.
    keep_mask = ~all_actions_mask.unsqueeze(-1)
    input_embeddings = input_embeddings * keep_mask

    multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
        input_embeddings, projected_patch_embeddings, attention_mask
    )

    language_model_output = self.language_model(
        input_ids=None,
        attention_mask=multimodal_attention_mask,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=multimodal_embeddings,
        labels=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=True,
        return_dict=True,
    )

    # Positions of the action tokens in the multimodal sequence.
    action_start = NUM_PATCHES + NUM_PROMPT_TOKENS
    action_stop = action_start + ACTION_DIM * NUM_ACTIONS_CHUNK
    actions_hidden_states = language_model_output.hidden_states[-1][:, action_start:action_stop, :]

    if action_head is None:
        # Discrete path: most likely token per action position, mapped to a bin center.
        token_ids = language_model_output.logits[:, action_start:action_stop].argmax(dim=2).cpu().numpy()
        _, action_end, _ = self._action_vocab_range()
        bin_indices = np.clip(action_end - token_ids - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
        normalized_actions = self.bin_centers[bin_indices]
        normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
    else:
        # Continuous path: the action head maps hidden states to actions.
        predicted = action_head.predict_action(actions_hidden_states)
        normalized_actions = predicted.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM).float().cpu().detach().numpy()

    return normalized_actions, actions_hidden_states
|
|
| def _discrete_diffusion_prediction( |
| self, |
| input_embeddings, |
| all_actions_mask, |
| projected_patch_embeddings, |
| attention_mask, |
| labels, |
| NUM_PATCHES, |
| NUM_PROMPT_TOKENS, |
| action_head=None, |
| input_ids=None, |
| ): |
| """Run L1 regression-based continuous action prediction or discrete action tokens prediction.""" |
| assert input_ids is not None, "Input IDs must be provided for discrete diffusion prediction!" |
|
|
| |
| if action_head is not None: |
| pass |
| |
| |
| |
| |
| else: |
| legacy_eval_mode = getattr(self.config, "legacy_eval_mode", False) |
|
|
| def tokens_to_logits(suffix_seq: torch.LongTensor) -> torch.Tensor: |
|
|
| prefix = masked_input_ids[:, :1+NUM_PROMPT_TOKENS] |
| suffix = masked_input_ids[:, 1+NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK:] |
| full_seqs = torch.cat([prefix, suffix_seq, suffix], dim=1) |
|
|
| |
| input_embeddings = self.get_input_embeddings()(full_seqs) |
|
|
| |
| multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention( |
| input_embeddings, projected_patch_embeddings, attention_mask |
| ) |
|
|
| language_model_output = self.language_model( |
| input_ids=None, |
| attention_mask=multimodal_attention_mask, |
| position_ids=None, |
| past_key_values=None, |
| inputs_embeds=multimodal_embeddings, |
| labels=None, |
| use_cache=None, |
| output_attentions=False, |
| output_hidden_states=True, |
| return_dict=True, |
| ) |
|
|
| logits = language_model_output.logits |
| |
| |
| |
| filtered_logits = logits |
|
|
| full_logits = filtered_logits[ |
| :, |
| NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK, |
| :self.vocab_size |
| ] |
| if not legacy_eval_mode: |
| action_begin, action_end, _ = self._action_vocab_range() |
| neg_inf = torch.finfo(full_logits.dtype).min |
| if action_begin > 0: |
| full_logits[..., :action_begin] = neg_inf |
| if action_end < full_logits.shape[-1]: |
| full_logits[..., action_end:] = neg_inf |
|
|
| |
| last_hidden_states = language_model_output.hidden_states[-1] |
| actions_hidden_states = last_hidden_states[ |
| :, |
| NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK, |
| :, |
| ] |
|
|
| return full_logits, actions_hidden_states |
|
|
| |
| |
| mask_token_id = self.mask_token_id |
| |
| if not getattr(self, "_dfm_mask_collision_warned", False): |
| if legacy_eval_mode: |
| n_bins = self.bin_centers.shape[0] + 1 |
| action_low = self.vocab_size - n_bins |
| action_high = self.vocab_size - 1 |
| if action_low <= mask_token_id <= action_high: |
| logger.warning( |
| "mask_token_id (%d) overlaps action-token range [%d, %d]. " |
| "DFM decoding will forbid mask-token sampling, but training/tokenizer config should be fixed.", |
| mask_token_id, |
| action_low, |
| action_high, |
| ) |
| else: |
| action_begin, action_end, _ = self._action_vocab_range() |
| if action_begin <= mask_token_id < action_end: |
| logger.warning( |
| "mask_token_id (%d) overlaps action-token range [%d, %d]. " |
| "DFM decoding will forbid mask-token sampling, but training/tokenizer config should be fixed.", |
| mask_token_id, |
| action_begin, |
| action_end - 1, |
| ) |
| self._dfm_mask_collision_warned = True |
| masked_input_ids = torch.where( |
| all_actions_mask, torch.tensor(mask_token_id, device=input_ids.device), input_ids |
| ) |
|
|
| cur_seqs = masked_input_ids[ |
| :, 1+NUM_PROMPT_TOKENS:1+NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK |
| ] |
|
|
| if legacy_eval_mode: |
| final_iters, actions_hidden_states = parallel_decode.legacy_decode( |
| init_ids=cur_seqs, |
| tokens_to_logits=tokens_to_logits, |
| mask_token_id=self.mask_token_id, |
| num_iter=12, |
| choice_temperature=1.0, |
| mask_scheduling_method="cosine", |
| use_remask=False, |
| ) |
| predicted_action_token_ids = final_iters[:, -1, :].cpu().numpy() |
| discretized_actions = self.vocab_size - predicted_action_token_ids |
| discretized_actions = np.clip( |
| discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1 |
| ) |
| normalized_actions = self.bin_centers[discretized_actions] |
| normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM) |
| else: |
| final_iters, actions_hidden_states = parallel_decode.decode( |
| init_ids=cur_seqs, |
| tokens_to_logits=tokens_to_logits, |
| mask_token_id=self.mask_token_id, |
| num_iter=12, |
| choice_temperature=1.0, |
| mask_scheduling_method="cosine", |
| use_remask=False, |
| ) |
| predicted_action_token_ids = final_iters[:, -1, :].cpu().numpy() |
| action_begin, action_end, _ = self._action_vocab_range() |
| discretized_actions = action_end - predicted_action_token_ids |
| discretized_actions = np.clip( |
| discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1 |
| ) |
| normalized_actions = self.bin_centers[discretized_actions] |
| normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM) |
|
|
| return normalized_actions, actions_hidden_states |
|
|
| def _discrete_flow_matching_prediction( |
| self, |
| input_embeddings, |
| all_actions_mask, |
| projected_patch_embeddings, |
| attention_mask, |
| labels, |
| NUM_PATCHES, |
| NUM_PROMPT_TOKENS, |
| action_head=None, |
| input_ids=None, |
| dfm_num_steps: int = 12, |
| dfm_maskgit_num_steps: int = 12, |
| dfm_maskgit_schedule: str = "cosine", |
| dfm_schedule: str = "cosine", |
| dfm_temperature: float = 1.0, |
| dfm_temperature_anneal: str = "none", |
| dfm_adaptive_step: bool = True, |
| dfm_step_min: float = 1e-4, |
| dfm_step_max: float = 0.2, |
| dfm_time_eps: float = 1e-3, |
| dfm_early_exit: bool = True, |
| dfm_early_exit_frac: float = 0.0, |
| dfm_corrector: bool = False, |
| dfm_corrector_iters: int = 1, |
| dfm_corrector_remask_frac: float = 0.1, |
| dfm_clamp_mask: bool = False, |
| dfm_clamp_values: Optional[torch.LongTensor] = None, |
| return_debug: bool = False, |
| dfm_debug_level: int = 1, |
| dfm_decode_mode: str = "ctmc", |
| dfm_log_mask_stats: bool = False, |
| dfm_divfree_eta: float = 0.0, |
| dfm_divfree_design: str = "general", |
| dfm_divfree_allow_reopen: bool = True, |
| dfm_divfree_debug: bool = False, |
| dfm_divfree_jump_tol: float = 1e-8, |
| ): |
| """CTMC discrete flow matching prediction.""" |
| assert input_ids is not None, "Input IDs must be provided for DFM prediction!" |
|
|
| if action_head is not None: |
| |
| pass |
| else: |
|
|
| def tokens_to_logits(suffix_seq: torch.LongTensor) -> torch.Tensor: |
| prefix = masked_input_ids[:, : 1 + NUM_PROMPT_TOKENS] |
| suffix = masked_input_ids[:, 1 + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK :] |
| full_seqs = torch.cat([prefix, suffix_seq, suffix], dim=1) |
|
|
| input_embeddings = self.get_input_embeddings()(full_seqs) |
|
|
| multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention( |
| input_embeddings, projected_patch_embeddings, attention_mask |
| ) |
|
|
| language_model_output = self.language_model( |
| input_ids=None, |
| attention_mask=multimodal_attention_mask, |
| position_ids=None, |
| past_key_values=None, |
| inputs_embeds=multimodal_embeddings, |
| labels=None, |
| use_cache=None, |
| output_attentions=False, |
| output_hidden_states=True, |
| return_dict=True, |
| ) |
|
|
| logits = language_model_output.logits |
|
|
| full_logits = logits[ |
| :, |
| NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK, |
| : self.vocab_size, |
| ] |
| action_begin, action_end, _ = self._action_vocab_range() |
| neg_inf = torch.finfo(full_logits.dtype).min |
| if action_begin > 0: |
| full_logits[..., :action_begin] = neg_inf |
| if action_end < full_logits.shape[-1]: |
| full_logits[..., action_end:] = neg_inf |
|
|
| last_hidden_states = language_model_output.hidden_states[-1] |
| actions_hidden_states = last_hidden_states[ |
| :, |
| NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK, |
| :, |
| ] |
|
|
| return full_logits, actions_hidden_states |
|
|
| mask_token_id = self.mask_token_id |
| masked_input_ids = torch.where( |
| all_actions_mask, torch.tensor(mask_token_id, device=input_ids.device), input_ids |
| ) |
|
|
| cur_seqs = masked_input_ids[ |
| :, 1 + NUM_PROMPT_TOKENS : 1 + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK |
| ] |
| init_snapshot = None |
| if return_debug and dfm_debug_level >= 2: |
| init_snapshot = cur_seqs[0, :16].detach().cpu().tolist() |
|
|
| clamp_values = None |
| if dfm_clamp_values is not None: |
| clamp_values = dfm_clamp_values.to(cur_seqs.device) |
| if clamp_values.dim() == 1: |
| clamp_values = clamp_values.unsqueeze(0).expand(cur_seqs.size(0), -1) |
| if clamp_values.shape != cur_seqs.shape: |
| raise ValueError( |
| f"dfm_clamp_values must have shape {tuple(cur_seqs.shape)}, got {tuple(clamp_values.shape)}" |
| ) |
|
|
| clamp_mask = None |
| if isinstance(dfm_clamp_mask, torch.Tensor): |
| clamp_mask = dfm_clamp_mask.to(cur_seqs.device) |
| elif clamp_values is not None: |
| clamp_mask = clamp_values != mask_token_id |
| elif dfm_clamp_mask: |
| clamp_mask = cur_seqs != mask_token_id |
|
|
| final_ids, actions_hidden_states, dfm_stats = dfm_decode( |
| init_ids=cur_seqs, |
| tokens_to_logits=tokens_to_logits, |
| mask_token_id=mask_token_id, |
| num_steps=dfm_num_steps, |
| maskgit_num_steps=dfm_maskgit_num_steps, |
| maskgit_schedule=dfm_maskgit_schedule, |
| schedule=dfm_schedule, |
| temperature=dfm_temperature, |
| temperature_anneal=dfm_temperature_anneal, |
| adaptive_step=dfm_adaptive_step, |
| step_min=dfm_step_min, |
| step_max=dfm_step_max, |
| time_eps=dfm_time_eps, |
| early_exit=dfm_early_exit, |
| early_exit_frac=dfm_early_exit_frac, |
| corrector=dfm_corrector, |
| corrector_iters=dfm_corrector_iters, |
| corrector_remask_frac=dfm_corrector_remask_frac, |
| clamp_mask=clamp_mask, |
| clamp_values=clamp_values, |
| debug_level=dfm_debug_level if return_debug else 0, |
| decode_mode=dfm_decode_mode, |
| log_mask_stats=dfm_log_mask_stats, |
| dfm_divfree_eta=dfm_divfree_eta, |
| dfm_divfree_design=dfm_divfree_design, |
| dfm_divfree_allow_reopen=dfm_divfree_allow_reopen, |
| dfm_divfree_debug=dfm_divfree_debug, |
| dfm_divfree_jump_tol=dfm_divfree_jump_tol, |
| ) |
| |
| action_begin, action_end, n_bins = self._action_vocab_range() |
| in_action = (final_ids >= action_begin) & (final_ids < action_end) |
| dfm_stats["dfm_in_action_frac_final"] = in_action.float().mean().item() |
| self.last_dfm_stats = dfm_stats |
| if return_debug: |
| self.last_dfm_debug = None |
|
|
| predicted_action_token_ids = final_ids.cpu().numpy() |
| discretized_actions = action_end - predicted_action_token_ids |
| discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1) |
| normalized_actions = self.bin_centers[discretized_actions] |
| normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM) |
|
|
| debug = None |
| if return_debug: |
| action_start = 1 + NUM_PROMPT_TOKENS |
| action_span_end = action_start + ACTION_DIM * NUM_ACTIONS_CHUNK |
| prefix = input_ids[:, :action_start] |
| suffix = input_ids[:, action_span_end:] |
| full_seq = torch.cat([prefix, final_ids, suffix], dim=1) |
| action_span_mask = torch.zeros_like(full_seq, dtype=torch.bool) |
| action_span_mask[:, action_start:action_span_end] = True |
| changed_off_action = (full_seq != input_ids) & (~action_span_mask) |
| changed_off_action_count = int(changed_off_action.sum().item()) |
| stop_token_corrupted = bool((full_seq[:, -1] != input_ids[:, -1]).any().item()) |
|
|
| |
| actions_tensor = torch.as_tensor(normalized_actions) |
| nan_frac = torch.isnan(actions_tensor).float().mean().item() |
| actions_safe = torch.nan_to_num(actions_tensor, nan=0.0) |
| action_min = actions_safe.min().item() |
| action_max = actions_safe.max().item() |
| action_mean = actions_safe.mean().item() |
| action_std = actions_safe.std().item() |
| clip_frac = ((actions_tensor <= -1.0) | (actions_tensor >= 1.0)).float().mean().item() |
| debug = { |
| "mask_token_id": int(mask_token_id), |
| "mask_frac_final": dfm_stats.get("dfm_mask_frac_final"), |
| "valid_action_frac_final": dfm_stats.get("dfm_in_action_frac_final"), |
| "num_unique_action_tokens": int(final_ids.unique().numel()), |
| "action_stats": { |
| "min": action_min, |
| "max": action_max, |
| "mean": action_mean, |
| "std": action_std, |
| "nan_frac": nan_frac, |
| "clip_frac": clip_frac, |
| }, |
| "dfm_stats": dfm_stats, |
| "changed_off_action_count": changed_off_action_count, |
| "stop_token_corrupted": stop_token_corrupted, |
| "action_vocab_range": { |
| "low": int(action_begin), |
| "high": int(action_end), |
| "n_bins": int(n_bins), |
| }, |
| } |
| if dfm_debug_level >= 2: |
| debug["token_snapshot"] = { |
| "init": init_snapshot, |
| "final": final_ids[0, :16].detach().cpu().tolist(), |
| } |
| self.last_dfm_debug = debug |
|
|
| if return_debug: |
| return normalized_actions, actions_hidden_states, debug |
| return normalized_actions, actions_hidden_states |
|
|
def predict_action(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    unnorm_key: Optional[str] = None,
    proprio=None,
    proprio_projector=None,
    action_head=None,
    noisy_action_projector=None,
    use_film: bool = False,
    use_discrete_diffusion: bool = False,
    use_discrete_flow_matching: bool = False,
    dfm_num_steps: int = 12,
    dfm_maskgit_num_steps: int = 12,
    dfm_maskgit_schedule: str = "cosine",
    dfm_schedule: str = "cosine",
    dfm_temperature: float = 1.0,
    dfm_temperature_anneal: str = "none",
    dfm_adaptive_step: bool = True,
    dfm_step_min: float = 1e-4,
    dfm_step_max: float = 0.2,
    dfm_time_eps: float = 1e-3,
    dfm_early_exit: bool = True,
    dfm_early_exit_frac: float = 0.0,
    dfm_corrector: bool = False,
    dfm_corrector_iters: int = 1,
    dfm_corrector_remask_frac: float = 0.1,
    dfm_clamp_mask: bool = False,
    dfm_clamp_values: Optional[torch.LongTensor] = None,
    return_debug: bool = False,
    dfm_debug_level: int = 1,
    dfm_decode_mode: str = "ctmc",
    dfm_log_mask_stats: bool = False,
    dfm_divfree_eta: float = 0.0,
    dfm_divfree_design: str = "general",
    dfm_divfree_allow_reopen: bool = True,
    dfm_divfree_debug: bool = False,
    dfm_divfree_jump_tol: float = 1e-8,
    **kwargs: Any,
) -> Tuple[Any, ...]:
    """Predict an action chunk from a prompt, dispatching to one of several decoding strategies.

    Strategy selection:
        * continuous diffusion when ``noisy_action_projector`` is given and ``action_head`` exposes a
          ``noise_scheduler``;
        * discrete flow matching when ``use_discrete_flow_matching`` is True;
        * discrete diffusion when ``use_discrete_diffusion`` is True;
        * otherwise regression through ``action_head`` or plain discrete-token prediction.

    Args:
        input_ids: Prompt token ids. Token 29871 is appended when the prompt does not already end with it.
        unnorm_key: Key selecting the statistics used to un-normalize the predicted actions.
        proprio: Proprioceptive features; only used together with ``proprio_projector``.
        proprio_projector: Projector for proprioceptive features.
        action_head: Optional head for regression or diffusion-based prediction.
        noisy_action_projector: Projector for noisy actions in diffusion-based prediction.
        use_film: Whether to compute language embeddings and pass them to the vision feature processing.
        use_discrete_diffusion: Select discrete diffusion decoding.
        use_discrete_flow_matching: Select discrete flow matching decoding.
        dfm_*: Forwarded unchanged to ``_discrete_flow_matching_prediction``.
        return_debug: When True a third element is returned. It is the diagnostics object produced by the
            discrete-flow-matching path and ``None`` for every other strategy.
        dfm_debug_level: Forwarded to ``_discrete_flow_matching_prediction``.
        **kwargs: Must contain ``pixel_values`` and ``attention_mask``.

    Returns:
        ``(actions, actions_hidden_states)``, or ``(actions, actions_hidden_states, debug)`` when
        ``return_debug`` is True. ``actions`` is the output of ``_unnormalize_actions``.

    Raises:
        ValueError: If both ``use_discrete_diffusion`` and ``use_discrete_flow_matching`` are set.
        KeyError: If ``pixel_values`` or ``attention_mask`` is missing from ``kwargs``.
    """
    # Reject contradictory decoder flags before doing any tensor work.
    if use_discrete_diffusion and use_discrete_flow_matching:
        raise ValueError("Cannot enable both discrete diffusion and discrete flow matching.")

    # NOTE(review): 29871 is presumably the tokenizer's empty-string token that prompts ended with during
    # training -- confirm against the tokenizer in use.
    trailing_token_id = 29871
    if not torch.all(input_ids[:, -1] == trailing_token_id):
        trailing_token = torch.tensor([[trailing_token_id]], dtype=torch.long, device=input_ids.device)
        input_ids = torch.cat((input_ids, trailing_token), dim=1)

    pixel_values = kwargs["pixel_values"]
    attention_mask = kwargs["attention_mask"]

    # Placeholder labels: every prompt position is ignored.
    labels = input_ids.clone()
    labels[:] = IGNORE_INDEX

    # Number of prompt tokens, not counting the first token of the sequence.
    NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1

    # Extend the token ids / attention mask and the labels, then derive the action-position mask.
    input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
    labels = self._prepare_labels_for_action_prediction(labels, input_ids)

    input_embeddings = self.get_input_embeddings()(input_ids)
    all_actions_mask = self._process_action_masks(labels)

    # Language embeddings are only needed for FiLM conditioning of the vision features.
    language_embeddings = None
    if use_film:
        language_embeddings = self._compute_language_embeddings(
            input_embeddings, attention_mask, all_actions_mask
        )

    projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)

    use_proprio = proprio_projector is not None and proprio is not None
    if use_proprio:
        proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
        projected_patch_embeddings = self._process_proprio_features(
            projected_patch_embeddings, proprio, proprio_projector
        )

    use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")

    # Count the non-text positions that precede the prompt; proprio and the diffusion timestep each add one.
    NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
    if use_proprio:
        NUM_PATCHES += 1
    if use_diffusion:
        NUM_PATCHES += 1

    debug = None
    if use_diffusion:
        assert use_discrete_diffusion is False, "Discrete diffusion has not been supported in this method!"
        assert use_discrete_flow_matching is False, "DFM is not supported with diffusion action head!"

        # Start the reverse diffusion process from Gaussian noise.
        noise = torch.randn(
            size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
        )

        normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
            input_embeddings,
            all_actions_mask,
            noise,
            action_head,
            projected_patch_embeddings,
            labels,
            attention_mask,
            NUM_PATCHES,
            NUM_PROMPT_TOKENS,
            noisy_action_projector,
        )
    else:
        if use_discrete_flow_matching:
            dfm_result = self._discrete_flow_matching_prediction(
                input_embeddings,
                all_actions_mask,
                projected_patch_embeddings,
                attention_mask,
                labels,
                NUM_PATCHES,
                NUM_PROMPT_TOKENS,
                action_head,
                input_ids=input_ids,
                dfm_num_steps=dfm_num_steps,
                dfm_maskgit_num_steps=dfm_maskgit_num_steps,
                dfm_maskgit_schedule=dfm_maskgit_schedule,
                dfm_schedule=dfm_schedule,
                dfm_temperature=dfm_temperature,
                dfm_temperature_anneal=dfm_temperature_anneal,
                dfm_adaptive_step=dfm_adaptive_step,
                dfm_step_min=dfm_step_min,
                dfm_step_max=dfm_step_max,
                dfm_time_eps=dfm_time_eps,
                dfm_early_exit=dfm_early_exit,
                dfm_early_exit_frac=dfm_early_exit_frac,
                dfm_corrector=dfm_corrector,
                dfm_corrector_iters=dfm_corrector_iters,
                dfm_corrector_remask_frac=dfm_corrector_remask_frac,
                dfm_clamp_mask=dfm_clamp_mask,
                dfm_clamp_values=dfm_clamp_values,
                return_debug=return_debug,
                dfm_debug_level=dfm_debug_level,
                dfm_decode_mode=dfm_decode_mode,
                dfm_log_mask_stats=dfm_log_mask_stats,
                dfm_divfree_eta=dfm_divfree_eta,
                dfm_divfree_design=dfm_divfree_design,
                dfm_divfree_allow_reopen=dfm_divfree_allow_reopen,
                dfm_divfree_debug=dfm_divfree_debug,
                dfm_divfree_jump_tol=dfm_divfree_jump_tol,
            )
            # The DFM helper returns a 3-tuple only when debugging was requested.
            if return_debug:
                normalized_actions, actions_hidden_states, debug = dfm_result
            else:
                normalized_actions, actions_hidden_states = dfm_result
        elif use_discrete_diffusion:
            normalized_actions, actions_hidden_states = self._discrete_diffusion_prediction(
                input_embeddings,
                all_actions_mask,
                projected_patch_embeddings,
                attention_mask,
                labels,
                NUM_PATCHES,
                NUM_PROMPT_TOKENS,
                action_head,
                input_ids=input_ids,
            )
        else:
            normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
                input_embeddings,
                all_actions_mask,
                projected_patch_embeddings,
                attention_mask,
                labels,
                NUM_PATCHES,
                NUM_PROMPT_TOKENS,
                action_head,
            )

    actions = self._unnormalize_actions(normalized_actions, unnorm_key)

    if return_debug:
        return actions, actions_hidden_states, debug
    return actions, actions_hidden_states
|
|
| @staticmethod |
| def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str: |
| """Validate and resolve the unnormalization key for action statistics""" |
| if unnorm_key is None: |
| assert len(norm_stats) == 1, ( |
| f"Your model was trained on more than one dataset, " |
| f"please pass a `unnorm_key` from the following options to choose the statistics " |
| f"used for un-normalizing actions: {norm_stats.keys()}" |
| ) |
| unnorm_key = next(iter(norm_stats.keys())) |
|
|
| assert unnorm_key in norm_stats, ( |
| f"The `unnorm_key` you chose is not in the set of available dataset statistics, " |
| f"please choose from: {norm_stats.keys()}" |
| ) |
| return unnorm_key |
|
|
def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
    """Return the number of action dimensions, taken from the length of the dataset's ``min`` statistic."""
    dataset_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
    action_stats = self.norm_stats[dataset_key]["action"]
    return len(action_stats["min"])
|
|
def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
    """Return the ``"action"`` statistics entry stored for the selected dataset."""
    dataset_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
    dataset_stats = self.norm_stats[dataset_key]
    return dataset_stats["action"]
|
|
|
|
| class DiscreteDiffusionForActionPrediction(PrismaticForConditionalGeneration): |
| config_class: PretrainedConfig = OpenVLAConfig |
|
|
def __init__(self, config: OpenVLAConfig) -> None:
    """Set up normalization statistics, the action discretization bins and the usable vocabulary size.

    Args:
        config: Model configuration. Reads ``norm_stats``, ``n_action_bins`` and, when present, the
            ``legacy_train_mode`` / ``legacy_eval_mode`` flags.
    """
    super().__init__(config)
    self.norm_stats = config.norm_stats

    # Legacy mode places `n_action_bins` edges on [-1, 1]; otherwise one extra edge is used so that the
    # number of bin centres equals `n_action_bins`.
    uses_legacy_bins = getattr(config, "legacy_train_mode", False) or getattr(config, "legacy_eval_mode", False)
    num_edges = config.n_action_bins if uses_legacy_bins else config.n_action_bins + 1
    self.bins = np.linspace(-1, 1, num_edges)
    self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0

    # Vocabulary size with the `pad_to_multiple_of` padding amount subtracted.
    text_vocab_size = self.config.text_config.vocab_size
    self.vocab_size = text_vocab_size - self.config.pad_to_multiple_of
|
|
def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
    """Append action placeholder tokens and a stop token, and grow the attention mask to match.

    Args:
        input_ids: Prompt token ids; the batch size is read from dimension 0.
        attention_mask: Attention mask for the prompt.

    Returns:
        ``(input_ids, attention_mask)`` where ``input_ids`` gained ``ACTION_DIM * NUM_ACTIONS_CHUNK``
        placeholder tokens (value 1) followed by one ``STOP_INDEX`` token, and ``attention_mask`` was padded
        with ones up to the new sequence length.
    """
    batch_size = input_ids.shape[0]
    token_options = {"device": input_ids.device, "dtype": input_ids.dtype}

    # Placeholders for the action tokens followed by the stop token.
    action_placeholders = torch.ones((batch_size, ACTION_DIM * NUM_ACTIONS_CHUNK), **token_options)
    stop_column = torch.ones((batch_size, 1), **token_options) * STOP_INDEX
    extended_ids = torch.cat([input_ids, action_placeholders, stop_column], dim=-1)

    # Every newly added position is attended to. The padding width is measured against the extended ids so
    # it also covers any token that was added to `input_ids` before this call.
    num_missing = extended_ids.shape[-1] - attention_mask.shape[-1]
    mask_padding = torch.ones(
        (attention_mask.shape[0], num_missing), device=attention_mask.device, dtype=attention_mask.dtype
    )
    extended_mask = torch.cat([attention_mask, mask_padding], dim=-1)

    return extended_ids, extended_mask
|
|
def _prepare_labels_for_action_prediction(self, labels, input_ids):
    """Extend ``labels`` to the length of ``input_ids`` with a filler action token and a final stop token.

    Args:
        labels: Labels covering only the leading part of ``input_ids``.
        input_ids: Token ids whose last dimension gives the target length.

    Returns:
        A new labels tensor: the original labels, then filler action-token ids, with the last position set
        to ``STOP_INDEX``.
    """
    # The filler only needs to be some action token id; which one depends on legacy mode.
    uses_legacy_mode = getattr(self.config, "legacy_eval_mode", False) or getattr(
        self.config, "legacy_train_mode", False
    )
    if uses_legacy_mode:
        filler_token_id = ACTION_TOKEN_BEGIN_IDX + 1
    else:
        filler_token_id, _, _ = self._action_vocab_range()

    num_new_positions = input_ids.shape[-1] - labels.shape[-1]
    filler = (
        torch.ones((labels.shape[0], num_new_positions), device=labels.device, dtype=labels.dtype)
        * filler_token_id
    )
    extended_labels = torch.cat([labels, filler], dim=-1)

    # The final position corresponds to the stop token.
    extended_labels[:, -1] = STOP_INDEX

    return extended_labels
|
|
def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
    """Map normalized actions back to the dataset's action range.

    Args:
        normalized_actions: Actions in normalized units.
        unnorm_key: Key selecting the dataset statistics (see ``get_action_stats``).

    Returns:
        Array where dimensions selected by the statistics' ``mask`` are rescaled with
        ``0.5 * (a + 1) * (high - low + 1e-8) + low`` and the remaining dimensions are passed through.
        When no ``mask`` is stored, every dimension is rescaled.

    Raises:
        ValueError: If ``ACTION_PROPRIO_NORMALIZATION_TYPE`` is neither ``BOUNDS`` nor ``BOUNDS_Q99``.
    """
    stats = self.get_action_stats(unnorm_key)

    # Choose which pair of statistics defines the range.
    if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
        low_key, high_key = "min", "max"
    elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
        low_key, high_key = "q01", "q99"
    else:
        raise ValueError("Unsupported action/proprio normalization type detected!")

    mask = stats.get("mask", np.ones_like(stats[low_key], dtype=bool))
    action_high = np.array(stats[high_key])
    action_low = np.array(stats[low_key])

    # The small epsilon mirrors the original formula exactly.
    span = action_high - action_low + 1e-8
    rescaled = 0.5 * (normalized_actions + 1) * span + action_low

    return np.where(mask, rescaled, normalized_actions)
|
|
def _run_diffusion_prediction(
    self,
    input_embeddings,
    all_actions_mask,
    noise,
    action_head,
    projected_patch_embeddings,
    labels,
    attention_mask,
    NUM_PATCHES,
    NUM_PROMPT_TOKENS,
    noisy_action_projector,
):
    """Denoise an action chunk by iterating over the action head's noise-scheduler timesteps.

    Args:
        input_embeddings: Token embeddings; action positions are overwritten with noisy-action features.
        all_actions_mask: Mask of action positions passed to ``_replace_input_embeddings``.
        noise: Initial noisy actions.
        action_head: Provides ``noise_scheduler``, ``time_encoder`` and ``predict_noise``.
        projected_patch_embeddings: Patch embeddings to which the timestep embedding is appended each step.
        labels: Only used for its device when building the timestep tensor.
        attention_mask: Attention mask passed to ``_build_multimodal_attention``.
        NUM_PATCHES: Number of non-text positions preceding the prompt in the language-model output.
        NUM_PROMPT_TOKENS: Number of prompt tokens preceding the action span.
        noisy_action_projector: Projects the flattened noisy actions into embedding space.

    Returns:
        ``(actions, actions_hidden_states)`` where ``actions`` is a float NumPy array of shape
        ``(NUM_ACTIONS_CHUNK, ACTION_DIM)`` and the hidden states come from the final denoising step.
    """
    # Keep the pristine patch embeddings so each step appends exactly one timestep embedding.
    base_patch_embeddings = projected_patch_embeddings.clone()
    scheduler = action_head.noise_scheduler
    span_start = NUM_PATCHES + NUM_PROMPT_TOKENS
    action_span = slice(span_start, span_start + ACTION_DIM * NUM_ACTIONS_CHUNK)
    noisy_actions = noise

    for step_t in scheduler.timesteps:
        # Encode the current timestep and append it as one extra position after the patches.
        timestep_tensor = torch.Tensor([step_t]).to(labels.device)
        time_embedding = action_head.time_encoder(timestep_tensor)
        time_embedding = time_embedding.to(noisy_actions.dtype).to(noisy_actions.device).unsqueeze(1)
        step_patch_embeddings = torch.cat((base_patch_embeddings, time_embedding), dim=1)

        # Flatten the noisy actions to one scalar per position before projecting them.
        batch = noisy_actions.shape[0]
        flat_noisy_actions = noisy_actions.reshape(batch, -1).unsqueeze(-1)
        noisy_action_features = noisy_action_projector(flat_noisy_actions)

        # The embeddings are carried over between steps, with the action positions replaced each time.
        input_embeddings = self._replace_input_embeddings(
            input_embeddings.clone(), all_actions_mask, noisy_action_features
        )

        multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
            input_embeddings, step_patch_embeddings, attention_mask
        )

        language_model_output = self.language_model(
            input_ids=None,
            attention_mask=multimodal_attention_mask,
            position_ids=None,
            past_key_values=None,
            inputs_embeds=multimodal_embeddings,
            labels=None,
            use_cache=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )

        # Hidden states at the action positions drive the noise prediction.
        actions_hidden_states = language_model_output.hidden_states[-1][:, action_span, :]

        noise_pred = action_head.predict_noise(actions_hidden_states)
        noisy_actions = scheduler.step(noise_pred, step_t, noisy_actions).prev_sample

    denoised_actions = noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)

    return denoised_actions.float().cpu().detach().numpy(), actions_hidden_states
|
|
def _regression_or_discrete_prediction(
    self,
    input_embeddings,
    all_actions_mask,
    projected_patch_embeddings,
    attention_mask,
    labels,
    NUM_PATCHES,
    NUM_PROMPT_TOKENS,
    action_head=None,
):
    """Predict actions in a single forward pass, via ``action_head`` or via argmax over token logits.

    Args:
        input_embeddings: Token embeddings; action positions are zeroed before the forward pass.
        all_actions_mask: Boolean mask of action positions.
        projected_patch_embeddings: Patch embeddings passed to ``_build_multimodal_attention``.
        attention_mask: Attention mask passed to ``_build_multimodal_attention``.
        labels: Unused here.
        NUM_PATCHES: Number of non-text positions preceding the prompt in the language-model output.
        NUM_PROMPT_TOKENS: Number of prompt tokens preceding the action span.
        action_head: When given, its ``predict_action`` maps hidden states to continuous actions; when None,
            actions are decoded from the arg-max token ids.

    Returns:
        ``(normalized_actions, actions_hidden_states)`` with ``normalized_actions`` a NumPy array of shape
        ``(NUM_ACTIONS_CHUNK, ACTION_DIM)``.
    """
    # Zero out the embeddings at the action positions.
    keep_mask = ~all_actions_mask.unsqueeze(-1)
    blanked_embeddings = input_embeddings * keep_mask

    multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
        blanked_embeddings, projected_patch_embeddings, attention_mask
    )

    language_model_output = self.language_model(
        input_ids=None,
        attention_mask=multimodal_attention_mask,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=multimodal_embeddings,
        labels=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=True,
        return_dict=True,
    )

    # Positions of the language-model output that correspond to the action tokens.
    span_start = NUM_PATCHES + NUM_PROMPT_TOKENS
    action_span = slice(span_start, span_start + ACTION_DIM * NUM_ACTIONS_CHUNK)
    actions_hidden_states = language_model_output.hidden_states[-1][:, action_span, :]

    if action_head is None:
        # Discrete path: arg-max token id -> bin index -> bin centre.
        token_ids = language_model_output.logits[:, action_span].argmax(dim=2).cpu().numpy()
        _, action_end, _ = self._action_vocab_range()
        bin_indices = np.clip(action_end - token_ids - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
        normalized_actions = self.bin_centers[bin_indices].reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
    else:
        # Continuous path: the action head maps hidden states to actions.
        head_output = action_head.predict_action(actions_hidden_states)
        normalized_actions = head_output.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM).float().cpu().detach().numpy()

    return normalized_actions, actions_hidden_states
|
|
def predict_action(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    unnorm_key: Optional[str] = None,
    proprio=None,
    proprio_projector=None,
    action_head=None,
    noisy_action_projector=None,
    use_film: bool = False,
    **kwargs: Any,
) -> Tuple[Any, ...]:
    """Predict an action chunk from a prompt using diffusion, regression or discrete-token decoding.

    Continuous diffusion is used when ``noisy_action_projector`` is given and ``action_head`` exposes a
    ``noise_scheduler``; otherwise ``_regression_or_discrete_prediction`` is used.

    Args:
        input_ids: Prompt token ids. Token 29871 is appended when the prompt does not already end with it.
        unnorm_key: Key selecting the statistics used to un-normalize the predicted actions.
        proprio: Proprioceptive features; only used together with ``proprio_projector``.
        proprio_projector: Projector for proprioceptive features.
        action_head: Optional head for regression or diffusion-based prediction.
        noisy_action_projector: Projector for noisy actions in diffusion-based prediction.
        use_film: Whether to compute language embeddings and pass them to the vision feature processing.
        **kwargs: Must contain ``pixel_values`` and ``attention_mask``.

    Returns:
        Tuple ``(unnormalized_actions, actions_hidden_states)``.

    Raises:
        KeyError: If ``pixel_values`` or ``attention_mask`` is missing from ``kwargs``.
    """
    # NOTE(review): 29871 is presumably the tokenizer's empty-string token that prompts ended with during
    # training -- confirm against the tokenizer in use.
    trailing_token_id = 29871
    if not torch.all(input_ids[:, -1] == trailing_token_id):
        trailing_token = torch.tensor([[trailing_token_id]], dtype=torch.long, device=input_ids.device)
        input_ids = torch.cat((input_ids, trailing_token), dim=1)

    pixel_values = kwargs["pixel_values"]
    attention_mask = kwargs["attention_mask"]

    # Placeholder labels: every prompt position is ignored.
    labels = input_ids.clone()
    labels[:] = IGNORE_INDEX

    # Number of prompt tokens, not counting the first token of the sequence.
    NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1

    # Extend the token ids / attention mask and the labels, then derive the action-position mask.
    input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
    labels = self._prepare_labels_for_action_prediction(labels, input_ids)

    input_embeddings = self.get_input_embeddings()(input_ids)
    all_actions_mask = self._process_action_masks(labels)

    # Language embeddings are only needed for FiLM conditioning of the vision features.
    language_embeddings = None
    if use_film:
        language_embeddings = self._compute_language_embeddings(
            input_embeddings, attention_mask, all_actions_mask
        )

    projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)

    use_proprio = proprio_projector is not None and proprio is not None
    if use_proprio:
        proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
        projected_patch_embeddings = self._process_proprio_features(
            projected_patch_embeddings, proprio, proprio_projector
        )

    use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")

    # Count the non-text positions that precede the prompt; proprio and the diffusion timestep each add one.
    NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
    if use_proprio:
        NUM_PATCHES += 1
    if use_diffusion:
        NUM_PATCHES += 1

    if use_diffusion:
        # Start the reverse diffusion process from Gaussian noise.
        noise = torch.randn(
            size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
        )

        normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
            input_embeddings,
            all_actions_mask,
            noise,
            action_head,
            projected_patch_embeddings,
            labels,
            attention_mask,
            NUM_PATCHES,
            NUM_PROMPT_TOKENS,
            noisy_action_projector,
        )
    else:
        normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
            input_embeddings,
            all_actions_mask,
            projected_patch_embeddings,
            attention_mask,
            labels,
            NUM_PATCHES,
            NUM_PROMPT_TOKENS,
            action_head,
        )

    actions = self._unnormalize_actions(normalized_actions, unnorm_key)

    return actions, actions_hidden_states
|
|
| @staticmethod |
| def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str: |
| """Validate and resolve the unnormalization key for action statistics""" |
| if unnorm_key is None: |
| assert len(norm_stats) == 1, ( |
| f"Your model was trained on more than one dataset, " |
| f"please pass a `unnorm_key` from the following options to choose the statistics " |
| f"used for un-normalizing actions: {norm_stats.keys()}" |
| ) |
| unnorm_key = next(iter(norm_stats.keys())) |
|
|
| assert unnorm_key in norm_stats, ( |
| f"The `unnorm_key` you chose is not in the set of available dataset statistics, " |
| f"please choose from: {norm_stats.keys()}" |
| ) |
| return unnorm_key |
|
|
def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
    """Return how many action dimensions the selected dataset's ``min`` statistic has."""
    dataset_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
    minimums = self.norm_stats[dataset_key]["action"]["min"]
    return len(minimums)
|
|
def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
    """Return the statistics stored under ``"action"`` for the selected dataset."""
    dataset_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
    per_dataset = self.norm_stats[dataset_key]
    return per_dataset["action"]
|
|