| """ |
| modeling_prismatic.py |
| |
| Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions. |
| Inherits from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained, |
| but exactly replicate the logic in `prismatic.models.vlms.prismatic.py`. |
| """ |

import os
import logging
import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union

import numpy as np
import timm
import tokenizers
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from timm.models.vision_transformer import LayerScale
from transformers import PretrainedConfig, PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from transformers.models.llama import LlamaForCausalLM, LlamaModel, LlamaPreTrainedModel

from prismatic.training.train_utils import (
    get_current_action_mask,
    get_next_actions_mask,
)
from prismatic.vla.constants import (
    ACTION_DIM,
    ACTION_PROPRIO_NORMALIZATION_TYPE,
    ACTION_TOKEN_BEGIN_IDX,
    IGNORE_INDEX,
    NUM_ACTIONS_CHUNK,
    STOP_INDEX,
    NormalizationType,
)

from .configuration_prismatic import OpenVLAConfig, PrismaticConfig

logger = logging.getLogger(__name__)

class TokenPruner(nn.Module):
    """Prompt-conditioned visual token pruner.

    Scores each visual patch token by its first-layer attention affinity to the prompt (task)
    tokens, then keeps only the highest-scoring patches: hard top-k dropping at inference,
    soft or straight-through gating during training.
    """

    def __init__(
        self,
        config,
        num_patches,
    ):
        super().__init__()
        self.num_patches = num_patches
        self.noise_scale = None
        self.scale_factor = 1 / math.sqrt(config.hidden_size)

        # Escape hatch: when disabled, the pruner is a no-op that keeps all tokens as-is
        self.disabled: bool = getattr(config, "prune_disabled", False)

        # Selection strategy: "coverage" (keep the smallest set whose softmax mass reaches
        # `coverage_target`) or "topk" (keep a fixed `top_k` patches per sample)
        self.selection_strategy = getattr(config, "prune_selection_strategy", "coverage")
        self.coverage_temperature = getattr(config, "prune_temperature", 0.1)
        self.coverage_target = getattr(config, "prune_target_coverage", 0.9)
        self.min_keep = getattr(config, "prune_min_keep", 64)
        self.max_keep = getattr(config, "prune_max_keep", None)
        keep_bins = getattr(config, "prune_keep_bins", (64, 96, 128, 160, 192))
        self.keep_bins = tuple(keep_bins) if keep_bins is not None else None
        self.top_k = getattr(config, "prune_top_k", None)
        self.debug = getattr(config, "prune_debug", False)
        self.debug_max_logs = getattr(config, "prune_debug_max_logs", 20)
        self._debug_counter = 0
        self._last_keep_counts: Optional[torch.Tensor] = None

        # Numerical floor for temperatures
        self._coverage_eps = 1e-6

        # How per-prompt-token attention logits are aggregated into a single score per patch:
        # "logsumexp" (a smooth maximum over prompt tokens) or a hard "max"
        self.prompt_aggregation = getattr(config, "prune_prompt_aggregation", "logsumexp")
        self.logsumexp_temperature = getattr(config, "prune_logsumexp_temperature", 1.0)

        # Optional rescaling of soft-gated patches so the mean token magnitude is preserved
        # after down-weighting
        self.soft_rescale_mean_preserve = getattr(
            config, "prune_soft_rescale_mean_preserve", False
        )
        self.soft_rescale_clip = getattr(config, "prune_soft_rescale_clip", None)

        # Training-time straight-through top-k gating (with Gumbel noise at temperature tau)
        self.train_use_st_topk: bool = getattr(config, "prune_train_use_st_topk", False)
        self.train_gumbel_tau: float = getattr(config, "prune_train_gumbel_tau", 1.0)
        self.train_gumbel_tau_min: Optional[float] = getattr(config, "prune_train_gumbel_tau_min", None)

    def set_noise_scale(self, noise_scale):
        self.noise_scale = noise_scale

    def set_coverage_target(self, coverage_target):
        """Dynamically set the coverage target for token pruning."""
        self.coverage_target = coverage_target

    def set_disabled(self, disabled: bool):
        """Enable/disable pruning and gating entirely (keep all tokens as-is)."""
        self.disabled = bool(disabled)

    def set_train_use_st_topk(self, use_st_topk: bool):
        self.train_use_st_topk = bool(use_st_topk)

    def set_train_gumbel_tau(self, tau: float):
        self.train_gumbel_tau = float(tau)

    def rms_norm(self, hidden_states, eps=1e-6):
        """Weightless RMSNorm (mirrors `LlamaRMSNorm`, minus the learned scale parameter)."""
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        return hidden_states.to(input_dtype)

    def get_score(
        self,
        patches,
        prompts,
        q_proj_weight,
        q_proj_bias,
        k_proj_weight,
        k_proj_bias,
        num_heads,
        prompt_mask=None,
    ):
        """Score each patch token by its multi-head attention affinity to the prompt tokens."""
        patches = self.rms_norm(patches)
        prompts = self.rms_norm(prompts)

        # Reuse the first decoder layer's Q/K projections: patches act as queries, prompts as keys
        queries = F.linear(patches, q_proj_weight, q_proj_bias)
        keys = F.linear(prompts, k_proj_weight, k_proj_bias)

        bsz, num_patches, _ = queries.shape
        _, num_tokens, _ = keys.shape

        head_dim = queries.shape[-1] // num_heads
        queries = queries.view(bsz, num_patches, num_heads, head_dim).permute(0, 2, 1, 3)
        keys = keys.view(bsz, num_tokens, num_heads, head_dim).permute(0, 2, 1, 3)

        # Scaled dot-product attention logits: (bsz, num_heads, num_patches, num_tokens)
        attn_logits = torch.matmul(queries, keys.transpose(-1, -2)) / math.sqrt(head_dim)

        # Mask out padded prompt tokens before aggregation
        if prompt_mask is not None:
            prompt_mask = prompt_mask.bool()
            expanded_mask = prompt_mask.unsqueeze(1).unsqueeze(2)
            attn_logits = attn_logits.masked_fill(~expanded_mask, float("-inf"))

        # Aggregate over prompt tokens
        if self.prompt_aggregation == "logsumexp":
            # Smooth, differentiable alternative to the hard max
            token_scores = torch.logsumexp(
                attn_logits / self.logsumexp_temperature, dim=-1
            ) * self.logsumexp_temperature
        else:
            # Hard max over prompt tokens
            token_scores = attn_logits.max(dim=-1).values

        # Average over attention heads -> (bsz, num_patches)
        score = token_scores.mean(dim=1)

        # Guard against -inf/nan scores (e.g., from fully-masked prompts)
        score = torch.where(torch.isfinite(score), score, torch.zeros_like(score))

        return score

    def _budgeted_keep_counts(self, score):
        """Compute per-sample keep counts and the score-sorted patch order."""
        device = score.device
        bsz, num_patches = score.shape

        # Fixed-budget selection: keep exactly `top_k` patches per sample
        if self.selection_strategy == "topk" and self.top_k is not None:
            k = min(self.top_k, num_patches)
            keep_counts = torch.full((bsz,), k, device=device, dtype=torch.int64)
            sorted_indices = score.argsort(dim=-1, descending=True)
            return keep_counts, sorted_indices

        # Coverage-based selection (nucleus/top-p style): keep the smallest prefix of
        # score-sorted patches whose softmax mass reaches `coverage_target`
        temperature = max(float(self.coverage_temperature), self._coverage_eps)
        probs = torch.softmax(score / temperature, dim=-1)
        sorted_probs, sorted_indices = probs.sort(dim=-1, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)

        target = float(self.coverage_target)
        keep_counts = (cumulative < target).sum(dim=-1) + 1

        if self.min_keep is not None:
            keep_counts = torch.maximum(
                keep_counts,
                torch.full_like(keep_counts, min(self.min_keep, num_patches)),
            )

        if self.max_keep is not None:
            keep_counts = torch.minimum(
                keep_counts,
                torch.full_like(keep_counts, min(self.max_keep, num_patches)),
            )

        keep_counts = torch.clamp(keep_counts, min=1, max=num_patches)

        # Snap counts to the nearest configured bin at or above the count (capped at the largest
        # bin); this stabilizes sequence shapes across batches
        if self.keep_bins:
            valid_bins = [min(num_patches, int(bin_val)) for bin_val in self.keep_bins if bin_val > 0]
            if valid_bins:
                bins = torch.tensor(sorted(set(valid_bins)), device=device, dtype=torch.int64)
                search_idx = torch.searchsorted(bins, keep_counts, right=False)
                search_idx = torch.clamp(search_idx, max=bins.numel() - 1)
                keep_counts = bins[search_idx]

        if self.debug:
            self._last_keep_counts = keep_counts.detach().to("cpu")

        return keep_counts, sorted_indices
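
    # Worked example (illustrative only, not executed): suppose a sample's sorted softmax
    # probabilities are [0.50, 0.30, 0.15, 0.03, 0.02] and coverage_target = 0.9. The cumulative
    # sums are [0.50, 0.80, 0.95, 0.98, 1.00], so (cumulative < 0.9).sum() + 1 = 3 patches reach
    # the coverage target; min/max clamping and keep-bin snapping (e.g., up to bin 64 with the
    # default bins) are then applied on top of that count.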

    def score_to_mask(self, score):
        """Convert scores into a boolean keep-mask honoring the per-sample keep counts."""
        bsz, num_patches = score.shape
        mask = torch.zeros(bsz, num_patches, dtype=torch.bool, device=score.device)

        keep_counts, sorted_indices = self._budgeted_keep_counts(score)

        for batch_idx in range(bsz):
            k = int(keep_counts[batch_idx].item())
            topk_indices = sorted_indices[batch_idx, :k]
            mask[batch_idx, topk_indices] = True

        return mask

    def score_to_indices(self, score, patches):
        """Straight-through argmax selection of a single patch (gradient flows via the softmax)."""
        if self.noise_scale is not None:
            score = score + torch.rand_like(score) * self.noise_scale
        hard_score = F.one_hot(score.argmax(dim=-1), num_classes=self.num_patches)
        soft_score = torch.softmax(score, dim=-1)
        score = hard_score + soft_score - soft_score.detach()
        # Select the argmax patch per sample: (bsz, 1, num_patches) @ (bsz, num_patches, dim)
        # -> (bsz, 1, dim) -> (bsz, dim)
        return score.argmax(dim=-1), (score.unsqueeze(1).to(patches.dtype) @ patches).squeeze(1)

    def forward(
        self,
        tokens,
        position_ids,
        attention_mask,
        q_proj_weight,
        q_proj_bias,
        k_proj_weight,
        k_proj_bias,
        num_heads,
    ):
        # No-op path: keep all tokens and record the (full) keep counts for debugging
        if self.disabled:
            bsz = tokens.shape[0]
            self._last_keep_counts = (
                torch.full((bsz,), self.num_patches, dtype=torch.int64, device=tokens.device)
                .detach()
                .to("cpu")
            )
            return tokens, position_ids, attention_mask

        # Sequence layout is [<BOS>/cls token | visual patches | task (prompt/action) tokens]
        bsz, seq_len, dim = tokens.shape
        cls_token, patches, task = torch.split(tokens, [1, self.num_patches, seq_len - self.num_patches - 1], dim=1)
        cls_token_id, patches_id, task_id = torch.split(position_ids, [1, self.num_patches, seq_len - self.num_patches - 1], dim=1)
        if attention_mask is not None:
            cls_token_mask, patches_mask, task_mask = torch.split(attention_mask, [1, self.num_patches, seq_len - self.num_patches - 1], dim=1)

        # Score patches against the task tokens using the first decoder layer's Q/K projections
        score = self.get_score(
            patches,
            task,
            q_proj_weight,
            q_proj_bias,
            k_proj_weight,
            k_proj_bias,
            num_heads,
            prompt_mask=task_mask if attention_mask is not None else None,
        )

        if self.training:
            if self.train_use_st_topk:
                # Straight-through top-k: the forward pass applies a hard 0/1 mask; the backward
                # pass receives gradients through a Gumbel-softmax relaxation
                keep_counts, sorted_indices = self._budgeted_keep_counts(score)

                # Sample Gumbel(0, 1) noise and form relaxed selection probabilities
                gumbel = -torch.log(-torch.log(torch.rand_like(score).clamp(min=1e-6, max=1.0 - 1e-6)))
                tau = float(self.train_gumbel_tau if self.train_gumbel_tau is not None else 1.0)
                tau = max(tau, self._coverage_eps)
                logits = (score + gumbel) / tau
                probs = torch.softmax(logits, dim=-1)

                # Hard mask: 1 for the top keep_counts[b] patches of each sample, 0 otherwise
                bsz, num_patches = score.shape
                m_hard = torch.zeros_like(score)
                for b in range(bsz):
                    k = int(keep_counts[b].item())
                    topk_idx = sorted_indices[b, :k]
                    m_hard[b, topk_idx] = 1.0

                # Straight-through estimator: forward pass uses m_hard, gradients flow via probs
                m = m_hard + probs - probs.detach()

                # Optionally rescale kept patches so the mean token magnitude is preserved
                if self.soft_rescale_mean_preserve:
                    scale = (torch.ones_like(keep_counts, dtype=patches.dtype) * num_patches) / keep_counts.clamp(min=1).to(patches.dtype)
                    scale = scale.view(-1, 1)
                    if self.soft_rescale_clip is not None:
                        scale = torch.clamp(scale, max=float(self.soft_rescale_clip))
                    patches = patches * (m.unsqueeze(-1)) * scale.unsqueeze(-1)
                else:
                    patches = patches * (m.unsqueeze(-1))

                if self.debug:
                    self._last_keep_counts = keep_counts.detach().to("cpu")

                tokens = torch.cat([cls_token, patches, task], dim=1)
                position_ids = torch.cat([cls_token_id, patches_id, task_id], dim=1)
                if attention_mask is not None:
                    attention_mask = torch.cat([cls_token_mask, patches_mask, task_mask], dim=1)
            else:
                # Fully soft gating: down-weight patches by their softmax scores (no token dropped)
                weights = torch.softmax(
                    score / max(self.coverage_temperature, self._coverage_eps), dim=-1
                )
                if self.soft_rescale_mean_preserve:
                    # Rescale so the weights have mean 1 over patches
                    weights = weights * patches.shape[1]
                    if self.soft_rescale_clip is not None:
                        # Clip to avoid a few patches dominating after rescaling
                        weights = torch.clamp(weights, max=float(self.soft_rescale_clip))
                patches = patches * weights.unsqueeze(-1)

                # Record what the hard budget would have been, for debugging parity
                if self.debug:
                    keep_counts, _ = self._budgeted_keep_counts(score)
                    self._last_keep_counts = keep_counts.detach().to("cpu")
                tokens = torch.cat([cls_token, patches, task], dim=1)
                position_ids = torch.cat([cls_token_id, patches_id, task_id], dim=1)
                if attention_mask is not None:
                    attention_mask = torch.cat([cls_token_mask, patches_mask, task_mask], dim=1)

        else:
            # Inference: hard-drop pruned patches (the reshape below assumes equal keep counts
            # across the batch, which holds with keep-bin snapping or batch size 1)
            mask = self.score_to_mask(score)

            patches = patches[mask].view(bsz, -1, dim)
            tokens = torch.cat([cls_token, patches, task], dim=1)

            patches_id = patches_id[mask].view(bsz, -1)
            position_ids = torch.cat([cls_token_id, patches_id, task_id], dim=1)

            if attention_mask is not None:
                patches_mask = patches_mask[mask].view(bsz, -1)
                attention_mask = torch.cat([cls_token_mask, patches_mask, task_mask], dim=1)

        if self.debug and self._debug_counter < self.debug_max_logs:
            keep_counts = self._last_keep_counts
            if keep_counts is not None:
                keep_counts = keep_counts.to(torch.float32)
                logger.info(
                    "TokenPruner debug | keep_counts min=%.0f max=%.0f mean=%.2f | target=%.2f | temp=%.3f | bins=%s",
                    keep_counts.min().item(),
                    keep_counts.max().item(),
                    keep_counts.mean().item(),
                    float(self.coverage_target),
                    float(self.coverage_temperature),
                    self.keep_bins,
                )
            self._debug_counter += 1

        return tokens, position_ids, attention_mask
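
# Hedged usage sketch (illustrative; the tensor shapes and the `attn` module here are
# assumptions, not part of this module's API):
#
#   pruner = TokenPruner(llama_config, num_patches=256)
#   tokens = torch.randn(1, 1 + 256 + 40, llama_config.hidden_size)   # [cls | patches | task]
#   position_ids = torch.arange(tokens.shape[1]).unsqueeze(0)
#   attn = llama_model.layers[0].self_attn                            # supplies Q/K projections
#   tokens, position_ids, attention_mask = pruner(
#       tokens, position_ids, None,
#       attn.q_proj.weight, attn.q_proj.bias, attn.k_proj.weight, attn.k_proj.bias,
#       attn.num_heads,
#   )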


class PrunedLlamaModel(LlamaModel):
    """`LlamaModel` variant that prunes visual patch tokens before the first decoder layer."""

    def __init__(self, config, num_patches):
        super().__init__(config)

        self.pruner = TokenPruner(
            config,
            num_patches,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # Decoder outputs to (optionally) accumulate
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # NOTE: this simplified forward always treats the input as a full (uncached) sequence;
        # incoming `position_ids`/`cache_position` are recomputed from scratch
        past_seen_tokens = 0

        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=inputs_embeds.device
        )

        position_ids = cache_position.unsqueeze(0).expand(hidden_states.shape[0], -1)

        # Prune visual patch tokens, scoring them with the first layer's Q/K projections
        first_layer_attn = self.layers[0].self_attn
        hidden_states, position_ids, attention_mask = self.pruner(
            hidden_states,
            position_ids,
            attention_mask,
            first_layer_attn.q_proj.weight,
            first_layer_attn.q_proj.bias,
            first_layer_attn.k_proj.weight,
            first_layer_attn.k_proj.bias,
            first_layer_attn.num_heads,
        )

        # Recompute cache positions for the (possibly shorter) pruned sequence
        past_seen_tokens = 0

        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
        )

        causal_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, past_seen_tokens)

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # Add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
            )
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class PrunedLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config, num_patches):
        # Skip `LlamaForCausalLM.__init__` so we can install the pruned base model instead
        super(LlamaPreTrainedModel, self).__init__(config)
        self.model = PrunedLlamaModel(config, num_patches)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            # Token pruning shortens the sequence, so align labels to the *tail* of the label
            # tensor; this is safe because dropped patch positions are labeled IGNORE_INDEX
            # (-100) and are skipped by the cross-entropy loss anyway
            shift_labels = labels[..., -shift_logits.shape[-2]:].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = F.cross_entropy(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """Wrap `fn` so that a singleton tuple return value is unwrapped to its first element."""

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, tuple) else result

    return wrapper
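
# Example: TIMM's `get_intermediate_layers` returns a one-element tuple of feature maps; wrapping
# it with `unpack_tuple` (as done in `PrismaticVisionBackbone._create_featurizer` below) makes the
# patched `featurizer.forward` return the feature tensor directly.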


# HF Transformers overwrites parameters with names containing "gamma"; we therefore patch TIMM's
# `LayerScale` to use a differently named parameter (`scale_factor`) and a matching forward method
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor


def ls_apply_patch(ls_module: LayerScale):
    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    del ls_module.gamma


class PrismaticVisionBackbone(nn.Module):
    """
    Vision backbone for Prismatic models that handles image feature extraction.

    Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.
    For fused backbones, features from both models are concatenated along the feature dimension.
    """

    def __init__(
        self,
        use_fused_vision_backbone: bool,
        image_sizes: List[int],
        timm_model_ids: List[str],
        timm_override_act_layers: List[Optional[str]],
    ) -> None:
        """
        Initialize the vision backbone.

        Args:
            use_fused_vision_backbone: Whether to use two backbones and fuse their features
            image_sizes: List of image sizes for each backbone
            timm_model_ids: List of TIMM model IDs to use for each backbone
            timm_override_act_layers: List of activation layer overrides for each backbone
        """
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.num_images_in_input = 2  # Default to 2 input images; override via `set_num_images_in_input()`

        # Validate number of (fused) vision backbones
        if len(timm_model_ids) > 2:
            raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")

        # Create primary featurizer
        self.featurizer = self._create_featurizer(
            model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]
        )
        self.embed_dim = self.featurizer.embed_dim

        # Create secondary featurizer if using a fused backbone
        if self.use_fused_vision_backbone:
            self.fused_featurizer = self._create_featurizer(
                model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]
            )
            self.embed_dim += self.fused_featurizer.embed_dim

        # Patch LayerScale modules for HF compatibility
        self._patch_layer_scales()

    def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
        """
        Create a TIMM-based featurizer model with appropriate configurations.

        Args:
            model_id: The TIMM model ID to load
            img_size: Input image size for the model
            act_layer: Override for the activation layer type

        Returns:
            A configured featurizer model
        """
        featurizer = timm.create_model(
            model_id,
            pretrained=False,
            num_classes=0,
            img_size=img_size,
            act_layer=act_layer,
        )

        # Monkey-patch `forward` to return patch features from the second-to-last block
        num_blocks = len(featurizer.blocks)
        featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))

        return featurizer

    def _patch_layer_scales(self) -> None:
        """
        Patch all LayerScale modules to be compatible with HF's parameter naming.

        HF Transformers overwrites parameters with names containing 'gamma',
        so we need to rename and modify the forward method.
        """
        # Patch primary featurizer
        for module in self.featurizer.modules():
            if isinstance(module, LayerScale):
                ls_apply_patch(module)

        # Patch secondary featurizer if it exists
        if self.use_fused_vision_backbone:
            for module in self.fused_featurizer.modules():
                if isinstance(module, LayerScale):
                    ls_apply_patch(module)

    def get_num_patches(self) -> int:
        """
        Returns the number of vision patches output by the vision backbone.

        Returns:
            Number of patches per image
        """
        return self.featurizer.patch_embed.num_patches

    def get_num_images_in_input(self) -> int:
        """
        Returns the number of input images for the vision backbone.

        Returns:
            Number of images expected in the input
        """
        return self.num_images_in_input

    def set_num_images_in_input(self, num_images_in_input: int) -> None:
        """
        Sets the number of input images for the vision backbone.

        Args:
            num_images_in_input: Number of images to expect in the input
        """
        self.num_images_in_input = num_images_in_input

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Implements the forward pass for the vision backbone.

        If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features
        (otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).

        Args:
            pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W).
        """
        if self.num_images_in_input == 1:
            if not self.use_fused_vision_backbone:
                return self.featurizer(pixel_values)

            # Split the stacked channels into the two backbone inputs and concatenate patch features
            img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
            patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)

            return torch.cat([patches, patches_fused], dim=2)

        else:
            assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!"

            # Split `pixel_values` into individual images (each with 6 channels: 3 per featurizer)
            images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)

            # Process each image and collect patches
            all_patches = []
            for img in images:
                # Split each image further into two stacks of channels (each with 3 channels)
                img_regular, img_fused = torch.split(img, [3, 3], dim=1)

                # Get patches from both featurizers
                patches = self.featurizer(img_regular)
                patches_fused = self.fused_featurizer(img_fused)

                # Concatenate the patch features along the feature dimension
                combined_patches = torch.cat([patches, patches_fused], dim=2)
                all_patches.append(combined_patches)

            # Concatenate all images' patches along the patch (sequence) dimension
            return torch.cat(all_patches, dim=1)


class PrismaticProjector(nn.Module):
    def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.vision_dim, self.llm_dim = vision_dim, llm_dim

        # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors
        if not self.use_fused_vision_backbone:
            self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
            self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()
        else:
            initial_projection_dim = 4 * vision_dim
            self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
            self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
            self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()
            self.act_fn2 = nn.GELU()

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        if not self.use_fused_vision_backbone:
            projected_features = self.fc1(img_patches)
            projected_features = self.act_fn1(projected_features)
            projected_features = self.fc2(projected_features)
        else:
            projected_features = self.fc1(img_patches)
            projected_features = self.act_fn1(projected_features)
            projected_features = self.fc2(projected_features)
            projected_features = self.act_fn2(projected_features)
            projected_features = self.fc3(projected_features)

        return projected_features
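
# Hedged shape sketch (the concrete dimensions are assumptions for illustration): with a SigLIP
# featurizer (1152-d) fused with DINOv2 (1024-d), vision_dim = 2176 and the fused projector maps
#   (B, 256, 2176) -fc1-> (B, 256, 8704) -fc2-> (B, 256, llm_dim) -fc3-> (B, 256, llm_dim),
# with GELU activations between the linear layers.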


@dataclass
class PrismaticCausalLMOutputWithPast(ModelOutput):
    """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Additions for VLMs
    projector_features: Optional[torch.FloatTensor] = None


class PrismaticPreTrainedModel(PreTrainedModel):
    config_class: PretrainedConfig = PrismaticConfig
    base_model_prefix: str = "model"
    supports_gradient_checkpointing: bool = True

    _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
    _skip_keys_device_placement: str = "past_key_values"
    _supports_flash_attn_2: bool = False

    def _init_weights(self, module: nn.Module) -> None:
        # Note: this HF-style initialization is only exercised when training from scratch;
        # loading a pretrained checkpoint overwrites all of these weights
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self) -> bool:
        """Check whether the underlying LLM supports SDPA attention."""
        return self.language_model._supports_sdpa


class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
    def __init__(self, config: PrismaticConfig) -> None:
        super().__init__(config)

        # [Validation] Lightweight validation of `config` fields and dependency versions
        if config.use_fused_vision_backbone is None:
            raise ValueError("Missing config field `use_fused_vision_backbone`")

        if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
            raise NotImplementedError(
                "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
                "if you urgently need support for latest TIMM versions."
            )

        if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
            logger.warning(
                f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
                f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
                f"there might be inference-time regressions due to dependency changes. If in doubt, please "
                f"use the above versions."
            )

        # Instantiate the vision backbone
        self.vision_backbone = PrismaticVisionBackbone(
            config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
        )

        # Create the multimodal projector (vision embedding space -> language embedding space)
        self.projector = PrismaticProjector(
            config.use_fused_vision_backbone,
            vision_dim=self.vision_backbone.embed_dim,
            llm_dim=config.text_config.hidden_size,
        )

        # Instantiate the (pruned) LLM backbone
        num_patches = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
        # Propagate pruning-related fields from the top-level config to the LLM's text config
        for key in (
            "prune_selection_strategy",
            "prune_temperature",
            "prune_target_coverage",
            "prune_min_keep",
            "prune_max_keep",
            "prune_keep_bins",
            "prune_top_k",
            "prune_debug",
            "prune_debug_max_logs",
            "prune_prompt_aggregation",
            "prune_logsumexp_temperature",
            "prune_soft_rescale_mean_preserve",
            "prune_soft_rescale_clip",
            "prune_disabled",
        ):
            if hasattr(config, key) and not hasattr(config.text_config, key):
                setattr(config.text_config, key, getattr(config, key))
        self.language_model = PrunedLlamaForCausalLM(config.text_config, num_patches)
        self.vocab_size = config.text_config.vocab_size
        self.pad_token_id = config.pad_token_id
        self.llm_dim = config.text_config.hidden_size

        # HF boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
        self.post_init()
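
    # Hedged configuration sketch (the field names follow the `prune_*` keys consumed above; the
    # checkpoint path is a placeholder, not a real ID):
    #
    #   config = OpenVLAConfig.from_pretrained("<path-or-hub-id>")
    #   config.prune_target_coverage = 0.9
    #   config.prune_keep_bins = (64, 96, 128, 160, 192)
    #   model = OpenVLAForActionPrediction.from_pretrained("<path-or-hub-id>", config=config)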

    # === `PreTrainedModel` Boilerplate ===
    def get_input_embeddings(self) -> nn.Module:
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.language_model.set_output_embeddings(new_embeddings)

    def get_decoder(self) -> nn.Module:
        return self.language_model.get_decoder()

    def set_decoder(self, decoder: nn.Module) -> None:
        self.language_model.set_decoder(decoder)

    def tie_weights(self) -> None:
        self.language_model.tie_weights()

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

        # Update config/instance variables
        self.config.text_config.vocab_size = updated_embeddings.num_embeddings
        self.vocab_size = updated_embeddings.num_embeddings

        return updated_embeddings

    def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):
        """
        Replace embeddings in input_embeddings at positions where all_actions_mask is True
        with embeddings from noisy_action_features, using vectorized operations.

        Args:
            input_embeddings: Tensor of shape (B, S, D)
            all_actions_mask: Boolean tensor of shape (B, S)
            noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in the mask per sample

        Returns:
            Modified input embeddings tensor
        """
        # Clone the input to avoid modifying the original tensor
        new_input_embeddings = input_embeddings.clone()

        # Create a same-shaped tensor to hold the repositioned action features
        repositioned_noisy_action_features = torch.zeros_like(input_embeddings)

        # Batch indices aligned with the K masked positions per sample
        batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)
        batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])

        # Per-sample positions of the masked (action) tokens; assumes the same K for every sample
        masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])

        # Scatter the noisy action features into their target positions
        repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features

        # Take the scattered features at masked positions and the original embeddings elsewhere
        new_input_embeddings = torch.where(
            all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings
        )

        return new_input_embeddings

    def _process_action_masks(self, labels):
        """Helper to get action masks from labels."""
        current_action_mask = get_current_action_mask(labels)
        next_actions_mask = get_next_actions_mask(labels)
        all_actions_mask = current_action_mask | next_actions_mask  # (B, seq_len)
        return all_actions_mask

    def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False):
        """Process vision features, with optional FiLM conditioning."""
        if use_film:
            # FiLM: infuse language inputs into the visual features (assumes a FiLM-enabled backbone)
            patch_features = self.vision_backbone(pixel_values, language_embeddings)
        else:
            patch_features = self.vision_backbone(pixel_values)

        # Project patch embeddings into the language embedding space
        return self.projector(patch_features)

    def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):
        """Process proprioceptive features and append them to the vision features."""
        if proprio_projector is not None and proprio is not None:
            # Project proprio features into the language embedding space and append as one token:
            # (B, proprio_dim) -> (B, 1, llm_dim)
            proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1)
            proprio_features = proprio_projector(proprio)
            proprio_features = proprio_features.unsqueeze(dim=1)
            # For simplicity, just append the proprio token to the end of the projected vision patch tokens
            return torch.cat((projected_patch_embeddings, proprio_features), dim=1)
        return projected_patch_embeddings

    def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
        """Build multimodal embeddings and attention mask (patches are inserted after the <BOS> token)."""
        # Update attention mask
        projected_patch_attention_mask = None
        if attention_mask is not None:
            projected_patch_attention_mask = torch.full(
                (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                fill_value=True,
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )

        # Build multimodal embeddings: insert patch embeddings right after the <BOS> token
        multimodal_embeddings = torch.cat(
            [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
        )

        multimodal_attention_mask = None
        if attention_mask is not None:
            multimodal_attention_mask = torch.cat(
                [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
            )

        return multimodal_embeddings, multimodal_attention_mask

    def _build_multimodal_labels(self, labels, projected_patch_embeddings):
        """Build multimodal labels, marking all patch positions with IGNORE_INDEX."""
        if labels is not None:
            projected_patch_labels = torch.full(
                (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                fill_value=IGNORE_INDEX,
                dtype=labels.dtype,
                device=labels.device,
            )
            return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
        return None
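
    # Layout sketch: for a prompt of S tokens and P projected patches, the multimodal sequence is
    #   [<BOS>] + [patch_1 ... patch_P] + [prompt tokens 2..S],
    # and the multimodal labels mark all P patch positions with IGNORE_INDEX so they never
    # contribute to the language-modeling loss.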

    # === Core Prismatic VLM `forward()` Logic ===
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_projector_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        proprio=None,
        proprio_projector=None,
        noisy_actions=None,
        noisy_action_projector=None,
        diffusion_timestep_embeddings=None,
        use_film: bool = False,
    ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
        """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_projector_features = output_projector_features if output_projector_features is not None else False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
        use_cache = use_cache and not self.training

        # Instantiate placeholder for projector features
        projected_patch_embeddings = None

        # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_key_values` ===
        if input_ids.shape[1] == 1:
            assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
            assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
            assert labels is None, "Unexpected key `labels` provided during cached generation!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=None,
                position_ids=None,
                past_key_values=past_key_values,
                inputs_embeds=None,
                labels=None,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Unimodal (Language-Only) Forward ===
        elif pixel_values is None:
            assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
            assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Multimodal Forward ===
        elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
            assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"

            # Get input embeddings (from language model embeddings)
            input_embeddings = self.get_input_embeddings()(input_ids)  # (B, seq_len, llm_dim)

            # Extract action masks
            all_actions_mask = self._process_action_masks(labels)

            # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
            language_embeddings = input_embeddings[~all_actions_mask].reshape(
                input_embeddings.shape[0], -1, input_embeddings.shape[2]
            )

            # Get visual features (optionally conditioned on language via FiLM)
            projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)

            # Add proprioceptive state if provided
            projected_patch_embeddings = self._process_proprio_features(
                projected_patch_embeddings, proprio, proprio_projector
            )

            # [Diffusion] Add diffusion timestep embedding if provided
            if diffusion_timestep_embeddings is not None:
                # For simplicity, just append the timestep embedding to the projected vision patch tokens
                projected_patch_embeddings = torch.cat(
                    (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
                )

            # [Diffusion] Replace action token embeddings with noisy action embeddings
            if noisy_actions is not None:
                # Get mask corresponding to all action tokens
                all_actions_mask = self._process_action_masks(labels)

                # Reshape noisy actions into individual action token inputs:
                # (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1)
                B = noisy_actions.shape[0]
                noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1)

                # Project noisy action tokens into the language model embedding space
                noisy_action_features = noisy_action_projector(noisy_actions)

                # Replace embeddings of the action tokens with noisy action embeddings
                input_embeddings = self._replace_input_embeddings(
                    input_embeddings, all_actions_mask, noisy_action_features
                )
            else:
                # Zero out the action token embeddings (positional embeddings are added later)
                all_actions_mask = all_actions_mask.unsqueeze(-1)  # (B, seq_len, 1)
                input_embeddings = input_embeddings * ~all_actions_mask

            # Build multimodal embeddings & attention mask
            multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
                input_embeddings, projected_patch_embeddings, attention_mask
            )

            # Build labels for the multimodal sequence if needed
            multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)

            # Dispatch to the language model
            language_model_output = self.language_model(
                input_ids=None,
                attention_mask=multimodal_attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=multimodal_embeddings,
                labels=multimodal_labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Otherwise =>> Assume Invalid! ===
        elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
            raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")

        else:
            raise ValueError(
                "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
                f"=> `input_ids` = {input_ids is not None}\n"
                f"=> `attention_mask` = {attention_mask is not None}\n"
                f"=> `pixel_values` = {pixel_values is not None}\n"
                f"=> `labels` = {labels is not None}\n"
                f"=> `inputs_embeds` = {inputs_embeds is not None}\n"
                f"=> `past_key_values` = {past_key_values is not None}\n"
                f"=> `use_cache` = {use_cache}"
            )

        # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
        if not return_dict:
            if output_projector_features and (projected_patch_embeddings is not None):
                return *language_model_output, projected_patch_embeddings

            return language_model_output

        return PrismaticCausalLMOutputWithPast(
            loss=language_model_output.loss,
            logits=language_model_output.logits,
            past_key_values=language_model_output.past_key_values,
            hidden_states=language_model_output.hidden_states,
            attentions=language_model_output.attentions,
            projector_features=projected_patch_embeddings,
        )

    # === `GenerationMixin` Methods ===
    def prepare_inputs_for_generation(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: str,
    ) -> Dict[str, torch.Tensor]:
        """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
        if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
            (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
        ):
            raise ValueError("Generation with batch size > 1 is not currently supported!")

        # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # If `inputs_embeds` are passed, only use them in the first generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # Make sure `pixel_values`, `attention_mask`, and the cache are carried through
        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
            }
        )

        return model_inputs

    # Defer to the language model (each LLM handles this differently, with different return types)
    def _reorder_cache(self, *args, **kwargs) -> Any:
        return self.language_model._reorder_cache(*args, **kwargs)


class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
    config_class: PretrainedConfig = OpenVLAConfig

    def __init__(self, config: OpenVLAConfig) -> None:
        super().__init__(config)
        self.norm_stats = config.norm_stats

        # Compute action bins
        self.bins = np.linspace(-1, 1, config.n_action_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0

        # Compute vocab size for de-tokenization -- revert added "multiple of" padding
        self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
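
        # Example: with n_action_bins = 256 (the OpenVLA default; an assumption here), `bins`
        # holds 256 uniformly spaced edges over [-1, 1] and `bin_centers` the 255 midpoints used
        # to decode discrete action tokens back to continuous values.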

    def set_num_images_in_input(self, num_images_in_input):
        self.vision_backbone.set_num_images_in_input(num_images_in_input)
        # Keep the token pruner's patch count in sync with the number of input images
        self.language_model.model.pruner.num_patches = (
            self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
        )

    def get_num_patches(self):
        return self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()

    def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
        """Prepares input for action prediction by adding the necessary tokens."""
        # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder action tokens to `input_ids`
        placeholder_action_token_ids = (
            torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype)
        )
        input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)

        # Add stop token to the sequence, as it appears at train time
        stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
        input_ids = torch.cat([input_ids, stop_token_id], dim=-1)

        # Extend the attention mask to cover the newly appended tokens
        # Note: only batch size == 1 is supported right now
        mask_extension = (
            torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
            .to(attention_mask.device)
            .to(attention_mask.dtype)
        )
        attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)

        return input_ids, attention_mask

    def _prepare_labels_for_action_prediction(self, labels, input_ids):
        """Creates labels tensor for action prediction if not provided."""
        # Extend labels with arbitrary action token ids; the specific values are unimportant,
        # only the positions matter for computing the action masks
        ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
        labels_extension = (
            torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
            * ARBITRARY_ACTION_TOKEN_IDX
        )
        labels = torch.cat([labels, labels_extension], dim=-1)

        # Replace the final label token with the stop token
        labels[:, -1] = STOP_INDEX

        return labels
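
    # Example layout (ACTION_DIM = 7 and NUM_ACTIONS_CHUNK = 8 are the repo defaults, assumed
    # here): the model input becomes [prompt tokens][56 placeholder action tokens][STOP], and the
    # labels mark exactly those 56 positions as action positions for `_process_action_masks`.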

    def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
        """Unnormalize actions using dataset statistics."""
        action_norm_stats = self.get_action_stats(unnorm_key)

        if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
            mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool))
            action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"])
        elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
            mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
            action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
        else:
            raise ValueError("Unsupported action/proprio normalization type detected!")

        # Map normalized actions in [-1, 1] back to the dataset's action range
        actions = np.where(
            mask,
            0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
            normalized_actions,
        )

        return actions
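
    # Worked example (illustrative): with q01 = -0.5, q99 = 0.5 and a normalized action of 0.2,
    # the unnormalized value is 0.5 * (0.2 + 1) * (0.5 - (-0.5) + 1e-8) + (-0.5) ~= 0.1.
    # Dimensions where `mask` is False (e.g., a binary gripper dimension) pass through unchanged.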

    def _run_diffusion_prediction(
        self,
        input_embeddings,
        all_actions_mask,
        noise,
        action_head,
        projected_patch_embeddings,
        labels,
        attention_mask,
        NUM_PROMPT_TOKENS,
        noisy_action_projector,
    ):
        """Run diffusion-based action prediction (reverse-diffusion sampling loop)."""
        # Keep a clean copy of the patch embeddings; a fresh timestep embedding is appended each step
        orig_projected_patch_embeddings = projected_patch_embeddings.clone()
        curr_noisy_actions = noise

        # Reverse diffusion: iteratively denoise the sampled action chunk
        for t in action_head.noise_scheduler.timesteps:
            # Embed the current diffusion timestep
            timesteps = torch.Tensor([t]).to(labels.device)
            diffusion_timestep_embeddings = (
                action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
            )
            diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1)  # (B, 1, llm_dim)

            # For simplicity, append the timestep embedding to the projected vision patch tokens
            projected_patch_embeddings = torch.cat(
                (orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
            )

            # Project the current noisy actions into the language embedding space
            B = curr_noisy_actions.shape[0]
            orig_curr_noisy_actions_shape = curr_noisy_actions.shape
            curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
            noisy_action_features = noisy_action_projector(curr_noisy_actions)
            curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)

            # Replace the embeddings of the action tokens with noisy action embeddings
            input_embeddings = self._replace_input_embeddings(
                input_embeddings.clone(), all_actions_mask, noisy_action_features
            )

            # Build multimodal embeddings and attention mask
            multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
                input_embeddings, projected_patch_embeddings, attention_mask
            )

            # Forward pass through the language model
            language_model_output = self.language_model(
                input_ids=None,
                attention_mask=multimodal_attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=multimodal_embeddings,
                labels=None,
                use_cache=None,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )

            # Extract hidden states for the action-token portion of the response
            # (the last ACTION_DIM * NUM_ACTIONS_CHUNK positions)
            last_hidden_states = language_model_output.hidden_states[-1]
            actions_hidden_states = last_hidden_states[
                :,
                -ACTION_DIM * NUM_ACTIONS_CHUNK:,
                :,
            ]

            # Predict the noise and take one denoising step: x_t -> x_{t-1}
            noise_pred = action_head.predict_noise(actions_hidden_states)
            curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample

        curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)

        # Return the final (normalized) denoised actions plus the last action hidden states
        return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states

    def _regression_or_discrete_prediction(
        self,
        input_embeddings,
        all_actions_mask,
        projected_patch_embeddings,
        attention_mask,
        labels,
        NUM_PROMPT_TOKENS,
        action_head=None,
    ):
        """Run L1 regression-based continuous action prediction or discrete action token prediction."""
        # Zero out all action token embeddings so that actions are predicted from hidden states alone
        all_actions_mask = all_actions_mask.unsqueeze(-1)  # (B, seq_len, 1)
        input_embeddings = input_embeddings * ~all_actions_mask

        # Build multimodal embeddings and attention mask
        multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
            input_embeddings, projected_patch_embeddings, attention_mask
        )

        # Forward pass through the language model
        language_model_output = self.language_model(
            input_ids=None,
            attention_mask=multimodal_attention_mask,
            position_ids=None,
            past_key_values=None,
            inputs_embeds=multimodal_embeddings,
            labels=None,
            use_cache=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )

        # Extract hidden states for the action-token portion of the response
        # (the last ACTION_DIM * NUM_ACTIONS_CHUNK positions)
        last_hidden_states = language_model_output.hidden_states[-1]
        actions_hidden_states = last_hidden_states[
            :,
            -ACTION_DIM * NUM_ACTIONS_CHUNK:,
            :,
        ]

        if action_head is not None:
            # L1 regression: predict continuous actions directly from the hidden states
            normalized_actions = action_head.predict_action(actions_hidden_states)
            normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
            normalized_actions = normalized_actions.float().cpu().detach().numpy()
        else:
            # Discrete prediction: argmax over the action-token logits, then map token ids back
            # to bin indices and bin centers
            predicted_action_token_ids = (
                language_model_output.logits[
                    :,
                    -ACTION_DIM * NUM_ACTIONS_CHUNK:,
                ]
                .argmax(dim=2)
                .cpu()
                .numpy()
            )
            discretized_actions = self.vocab_size - predicted_action_token_ids
            discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
            normalized_actions = self.bin_centers[discretized_actions]
            normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)

        return normalized_actions, actions_hidden_states
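
    # Decoding sketch (illustrative; the concrete sizes are assumptions): with a 32000-token base
    # vocabulary (Llama-2) and 256 action bins, a predicted token id of 31872 maps to bin index
    # 32000 - 31872 - 1 = 127, whose bin center in [-1, 1] is then unnormalized per dimension by
    # `_unnormalize_actions`.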

    def predict_action(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        unnorm_key: Optional[str] = None,
        proprio=None,
        proprio_projector=None,
        action_head=None,
        noisy_action_projector=None,
        use_film: bool = False,
        **kwargs: str,
    ) -> Tuple[np.ndarray, torch.Tensor]:
        """Predict actions from input sequence, with options for different prediction methods.

        Args:
            input_ids: Input token ids
            unnorm_key: Key for unnormalization statistics
            proprio: Proprioceptive features
            proprio_projector: Projector for proprioceptive features
            action_head: Optional head for L1 regression or diffusion-based prediction
            noisy_action_projector: Projector for noisy actions in diffusion-based prediction
            use_film: Whether to use FiLM conditioning
            **kwargs: Additional arguments including pixel_values and attention_mask

        Returns:
            Tuple of (unnormalized_actions, action_hidden_states)
        """
        # If the prompt does not already end with the special empty token (id 29871 in the
        # Llama-2 tokenizer), append it to match the inputs seen at training time
        if not torch.all(input_ids[:, -1] == 29871):
            input_ids = torch.cat(
                (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
            )

        pixel_values = kwargs["pixel_values"]
        attention_mask = kwargs["attention_mask"]

        # Create fake labels tensor (needed for action mask computation)
        labels = input_ids.clone()
        labels[:] = IGNORE_INDEX

        # Get number of tokens in the prompt (excluding the <BOS> token)
        NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1

        # Prepare inputs by adding the necessary tokens
        input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)

        # Update the labels tensor for action mask computation later
        labels = self._prepare_labels_for_action_prediction(labels, input_ids)

        # Get input embeddings and action masks
        input_embeddings = self.get_input_embeddings()(input_ids)
        all_actions_mask = self._process_action_masks(labels)

        # Extract language embeddings
        language_embeddings = input_embeddings[~all_actions_mask].reshape(
            input_embeddings.shape[0], -1, input_embeddings.shape[2]
        )

        # Process vision features
        projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)

        # Add proprioceptive features if provided
        use_proprio = proprio_projector is not None and proprio is not None
        if use_proprio:
            proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
            projected_patch_embeddings = self._process_proprio_features(
                projected_patch_embeddings, proprio, proprio_projector
            )

        # Use diffusion if a noisy-action projector is provided and the action head has a noise scheduler
        use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")

        if use_diffusion:
            # Sample random noise with the same shape as the output action chunk, used as the
            # starting state for reverse diffusion
            noise = torch.randn(
                size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
            )

            # Run diffusion-based prediction
            normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
                input_embeddings,
                all_actions_mask,
                noise,
                action_head,
                projected_patch_embeddings,
                labels,
                attention_mask,
                NUM_PROMPT_TOKENS,
                noisy_action_projector,
            )
        else:
            # Run regression- or discrete token-based prediction
            normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
                input_embeddings,
                all_actions_mask,
                projected_patch_embeddings,
                attention_mask,
                labels,
                NUM_PROMPT_TOKENS,
                action_head,
            )

        # Unnormalize the predicted actions
        actions = self._unnormalize_actions(normalized_actions, unnorm_key)

        return actions, actions_hidden_states
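
    # Hedged usage sketch (the `processor`, prompt format, device handling, and dataset key are
    # assumptions, not part of this class's API):
    #
    #   inputs = processor("In: What action should the robot take to pick up the cup?\nOut:", image)
    #   actions, _ = model.predict_action(
    #       input_ids=inputs["input_ids"].to(device),
    #       attention_mask=inputs["attention_mask"].to(device),
    #       pixel_values=inputs["pixel_values"].to(device, dtype=torch.bfloat16),
    #       unnorm_key="bridge_orig",
    #   )
    #   # actions: (NUM_ACTIONS_CHUNK, ACTION_DIM) numpy array of unnormalized actions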

    @staticmethod
    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
        """Validate and resolve the unnormalization key for action statistics."""
        if unnorm_key is None:
            assert len(norm_stats) == 1, (
                f"Your model was trained on more than one dataset; please pass an `unnorm_key` "
                f"from the following options to choose the statistics used for un-normalizing "
                f"actions: {norm_stats.keys()}"
            )
            unnorm_key = next(iter(norm_stats.keys()))

        assert unnorm_key in norm_stats, (
            f"The `unnorm_key` you chose is not in the set of available dataset statistics; "
            f"please choose from: {norm_stats.keys()}"
        )
        return unnorm_key

    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
        """Get the dimensionality of the policy's action space."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return len(self.norm_stats[unnorm_key]["action"]["min"])

    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
        """Get all the logged statistics for the given dataset."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return self.norm_stats[unnorm_key]["action"]
|