Upload 5 files

- README.md +99 -0
- configuration_prismatic.py +140 -0
- modeling_prismatic.py +562 -0
- processing_prismatic.py +257 -0
- processor_config.json +6 -0
README.md
ADDED
@@ -0,0 +1,99 @@
---
library_name: transformers
tags:
- robotics
- vla
- image-text-to-text
- multimodal
- pretraining
license: mit
language:
- en
pipeline_tag: image-text-to-text
---

# OpenVLA 7B

OpenVLA 7B (`openvla-7b`) is an open vision-language-action model trained on 970K robot manipulation episodes from the [Open X-Embodiment](https://robotics-transformer-x.github.io/) dataset.
The model takes language instructions and camera images as input and generates robot actions. It supports controlling multiple robots out-of-the-box, and can be quickly adapted for new robot domains via (parameter-efficient) fine-tuning.

All OpenVLA checkpoints, as well as our [training codebase](https://github.com/openvla/openvla), are released under an MIT License.

For full details, please read [our paper](https://arxiv.org/abs/2406.09246) and see [our project page](https://openvla.github.io/).

## Model Summary

- **Developed by:** The OpenVLA team, consisting of researchers from Stanford, UC Berkeley, Google DeepMind, and the Toyota Research Institute.
- **Model type:** Vision-language-action (language, image => robot actions)
- **Language(s) (NLP):** en
- **License:** MIT
- **Finetuned from:** [`prism-dinosiglip-224px`](https://github.com/TRI-ML/prismatic-vlms), a VLM trained from:
  + **Vision Backbone**: DINOv2 ViT-L/14 and SigLIP ViT-So400M/14
  + **Language Model**: Llama-2
- **Pretraining Dataset:** [Open X-Embodiment](https://robotics-transformer-x.github.io/) -- specific component datasets can be found [here](https://github.com/openvla/openvla).
- **Repository:** [https://github.com/openvla/openvla](https://github.com/openvla/openvla)
- **Paper:** [OpenVLA: An Open-Source Vision-Language-Action Model](https://arxiv.org/abs/2406.09246)
- **Project Page & Videos:** [https://openvla.github.io/](https://openvla.github.io/)

## Uses

OpenVLA models take a language instruction and a camera image of a robot workspace as input, and predict (normalized) robot actions consisting of 7-DoF end-effector deltas of the form (x, y, z, roll, pitch, yaw, gripper). To execute on an actual robot platform, actions need to be *un-normalized* subject to statistics computed on a per-robot, per-dataset basis. See [our repository](https://github.com/openvla/openvla) for more information.
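
As a concrete, illustrative sketch of this un-normalization step (mirroring the logic in `modeling_prismatic.py`, included in this upload): the percentile statistics below are made-up placeholders, while the real per-dataset values ship with the checkpoint under `config.norm_stats`.

```python
import numpy as np

# Hypothetical 1st/99th-percentile action statistics (placeholder values only);
# the real values live in `config.norm_stats[unnorm_key]["action"]`
q01 = np.array([-0.05, -0.05, -0.05, -0.5, -0.5, -0.5, 0.0])
q99 = np.array([0.05, 0.05, 0.05, 0.5, 0.5, 0.5, 1.0])

# A normalized model prediction in [-1, 1] for (x, y, z, roll, pitch, yaw, gripper)
normalized_action = np.array([0.2, -0.1, 0.0, 0.3, -0.2, 0.1, 1.0])

# Affine map from [-1, 1] back to the dataset's action range (as in `predict_action`)
action = 0.5 * (normalized_action + 1) * (q99 - q01) + q01
```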

OpenVLA models can be used zero-shot to control robots for specific combinations of embodiments and domains seen in the Open-X pretraining mixture (e.g., for [BridgeV2 environments with a Widow-X robot](https://rail-berkeley.github.io/bridgedata/)). They can also be efficiently *fine-tuned* for new tasks and robot setups given minimal demonstration data; [see here](https://github.com/openvla/openvla/blob/main/scripts/finetune.py).

**Out-of-Scope:** OpenVLA models do not zero-shot generalize to new (unseen) robot embodiments, or to setups that are not represented in the pretraining mix; in these cases, we suggest collecting a dataset of demonstrations on the desired setup and fine-tuning OpenVLA models instead.

## Getting Started

OpenVLA 7B can be used out-of-the-box to control multiple robots for domains represented in the pretraining mixture. For example, here is how to load `openvla-7b` for zero-shot instruction following in [BridgeV2 environments](https://rail-berkeley.github.io/bridgedata/) with a Widow-X robot:

```python
# Install minimal dependencies (`torch`, `transformers`, `timm`, `tokenizers`, ...)
# > pip install -r https://raw.githubusercontent.com/openvla/openvla/main/requirements-min.txt
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image

import torch

# Load Processor & VLA
processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b",
    attn_implementation="flash_attention_2",  # [Optional] Requires `flash_attn`
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
).to("cuda:0")

# Grab image input & format prompt
image: Image.Image = get_from_camera(...)
prompt = "In: What action should the robot take to {<INSTRUCTION>}?\nOut:"

# Predict Action (7-DoF; un-normalize for BridgeV2)
inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)

# Execute...
robot.act(action, ...)
```

For more examples, including scripts for fine-tuning OpenVLA models on your own robot demonstration datasets, see [our training repository](https://github.com/openvla/openvla).

## Citation

**BibTeX:**

```bibtex
@article{kim24openvla,
    title={OpenVLA: An Open-Source Vision-Language-Action Model},
    author={{Moo Jin} Kim and Karl Pertsch and Siddharth Karamcheti and Ted Xiao and Ashwin Balakrishna and Suraj Nair and Rafael Rafailov and Ethan Foster and Grace Lam and Pannag Sanketi and Quan Vuong and Thomas Kollar and Benjamin Burchfiel and Russ Tedrake and Dorsa Sadigh and Sergey Levine and Percy Liang and Chelsea Finn},
    journal = {arXiv preprint arXiv:2406.09246},
    year={2024}
}
```
configuration_prismatic.py
ADDED
@@ -0,0 +1,140 @@
"""
configuration_prismatic.py

HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`.
Default configuration specifies `siglip-224px+7b`.
"""

from typing import Any, Dict, List, Optional

from transformers import PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING

# === Utilities for Mapping Prismatic names to HF names ===
# fmt: off
VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = {
    "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224],

    "clip-vit-l-336px": [336],
    "siglip-vit-so400m-384px": [384],

    "dinoclip-vit-l-336px": [336, 336],
    "dinosiglip-vit-so-224px": [224, 224],
    "dinosiglip-vit-so-384px": [384, 384],
}
VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = {
    "clip-vit-l": ["vit_large_patch14_clip_224.openai"],
    "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"],

    "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"],
    "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"],

    "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"],
    "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"],

    "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"],
    "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"],
    "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"],
}
TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = {
    "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"],
    "dinov2-vit-l": [None], "in1k-vit-l": [None],
    "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None],
    "dinoclip-vit-l-336px": [None, "quick_gelu"],
    "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None]
}

LLM_BACKBONE_TO_HF_PATH = {
    "llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf",
    "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",

    "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5",

    "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1",
    "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1",

    "phi-2-3b": "microsoft/phi-2",
}
LLM_BACKBONE_TO_HF_METACLASS = {
    "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama",
    "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama",

    "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral",

    "phi-2-3b": "phi",
}

VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys())
VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH)
# fmt: on


class PrismaticConfig(PretrainedConfig):
    model_type: str = "prismatic"
    is_composition: bool = False

    def __init__(
        self,
        vision_backbone_id: str = "siglip-vit-so400m",
        llm_backbone_id: str = "vicuna-v15-7b",
        arch_specifier: str = "no-align+gelu-mlp",
        use_fused_vision_backbone: Optional[bool] = None,
        image_resize_strategy: str = "letterbox",
        text_config: Optional[Dict[str, Any]] = None,
        llm_max_length: int = 2048,
        pad_token_id: int = 32000,
        pad_to_multiple_of: int = 64,
        output_projector_states: bool = False,
        **kwargs: str,
    ) -> None:
        if vision_backbone_id not in VALID_VISION_BACKBONES:
            raise ValueError(f"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }")

        if llm_backbone_id not in VALID_LLM_BACKBONES:
            raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }")

        # Set Prismatic Configuration Fields
        self.vision_backbone_id = vision_backbone_id
        self.llm_backbone_id = llm_backbone_id
        self.arch_specifier = arch_specifier
        self.output_projector_states = output_projector_states

        # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing
        self.use_fused_vision_backbone = (
            use_fused_vision_backbone
            if use_fused_vision_backbone is not None
            else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"])
        )

        self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id]
        self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id]
        self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id]
        self.image_resize_strategy = image_resize_strategy

        self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id]
        self.llm_max_length = llm_max_length
        self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of

        # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming!
        self.text_config = (
            CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config)
            if text_config is not None
            else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]()
        )

        # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well...
        super().__init__(pad_token_id=pad_token_id, **kwargs)


class OpenVLAConfig(PrismaticConfig):
    model_type: str = "openvla"

    def __init__(
        self,
        norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None,
        n_action_bins: int = 256,
        **kwargs: str,
    ) -> None:
        self.norm_stats, self.n_action_bins = norm_stats, n_action_bins

        super().__init__(**kwargs)
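
Usage sketch (not part of the committed file): assuming `configuration_prismatic.py` is importable from the working directory, the configuration matching the released `openvla-7b` backbones can be built from the mapping tables above.

```python
# Usage sketch -- assumes this module is importable locally
from configuration_prismatic import OpenVLAConfig

config = OpenVLAConfig(
    vision_backbone_id="dinosiglip-vit-so-224px",  # fused DINOv2 + SigLIP @ 224px
    llm_backbone_id="llama2-7b-pure",
)

# Derived fields come straight from the mapping tables in the module
print(config.timm_model_ids)
# ['vit_large_patch14_reg4_dinov2.lvd142m', 'vit_so400m_patch14_siglip_224']
print(config.use_fused_vision_backbone)  # True -- inferred from the "dinosiglip" prefix
```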
modeling_prismatic.py
ADDED
@@ -0,0 +1,562 @@
"""
modeling_prismatic.py

Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions, inheriting
from the default `transformers.PreTrainedModel`. Meant to be standalone and self-contained, while exactly replicating
the logic in `prismatic.models.vlms.prismatic.py`.

Note =>> for the time being, not adding the custom HF "docstring" formatting.

References [LLaVa, IDEFICS-2]:
    => https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py
    => https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py
"""

import logging
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union

import numpy as np
import timm
import tokenizers
import torch
import torch.nn as nn
import transformers
from timm.models.vision_transformer import LayerScale
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .configuration_prismatic import OpenVLAConfig, PrismaticConfig

# Get Logger
logger = logging.getLogger(__name__)


# === PyTorch/HuggingFace Default IGNORE_INDEX (for CrossEntropyLoss labels)
IGNORE_INDEX = -100


# === Utility Functions for Monkey-Patching ===
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, tuple) else result

    return wrapper


# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
#   =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
#   =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor


def ls_apply_patch(ls_module: LayerScale):
    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    del ls_module.gamma


# === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
class PrismaticVisionBackbone(nn.Module):
    def __init__(
        self,
        use_fused_vision_backbone: bool,
        image_sizes: List[int],
        timm_model_ids: List[str],
        timm_override_act_layers: List[Optional[str]],
    ) -> None:
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone

        # [Contract] Validate number of (fused) vision backbones, create "alpha" featurizer and instantiate
        #   =>> Note :: Monkey-Patch the `forward()` function of the backbone to ensure FSDP-compatibility
        #               Hardcodes `get_intermediate_layers` to return the **SECOND-TO-LAST** layer patches!
        assert len(timm_model_ids) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!"
        self.featurizer = timm.create_model(
            timm_model_ids[0],
            pretrained=False,
            num_classes=0,
            img_size=image_sizes[0],
            act_layer=timm_override_act_layers[0],
        )
        self.featurizer.forward = unpack_tuple(
            partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})
        )
        self.embed_dim = self.featurizer.embed_dim

        # If `use_fused_vision_backbone` =>> create "beta" featurizer
        if self.use_fused_vision_backbone:
            self.fused_featurizer = timm.create_model(
                timm_model_ids[1],
                pretrained=False,
                num_classes=0,
                img_size=image_sizes[1],
                act_layer=timm_override_act_layers[1],
            )
            self.fused_featurizer.forward = unpack_tuple(
                partial(self.fused_featurizer.get_intermediate_layers, n={len(self.fused_featurizer.blocks) - 2})
            )
            self.embed_dim += self.fused_featurizer.embed_dim

        # Patch `vision_backbone.featurizer` and `vision_backbone.fused_featurizer` with HF-Compatible LayerScale
        for module in self.featurizer.modules():
            if isinstance(module, LayerScale):
                ls_apply_patch(module)

        if self.use_fused_vision_backbone:
            for module in self.fused_featurizer.modules():
                if isinstance(module, LayerScale):
                    ls_apply_patch(module)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run image (`pixel_values`) through featurizer; if channel-stacked, then dispatch and sequence stack."""
        if not self.use_fused_vision_backbone:
            return self.featurizer(pixel_values)

        # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
        img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
        patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)

        return torch.cat([patches, patches_fused], dim=2)


# === Prismatic Projector (nn.Module) Definitions ===
class PrismaticProjector(nn.Module):
    def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.vision_dim, self.llm_dim = vision_dim, llm_dim

        # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
        if not self.use_fused_vision_backbone:
            self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
            self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()
        else:
            initial_projection_dim = 4 * vision_dim
            self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
            self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
            self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()
            self.act_fn2 = nn.GELU()

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        if not self.use_fused_vision_backbone:
            projected_features = self.fc1(img_patches)
            projected_features = self.act_fn1(projected_features)
            projected_features = self.fc2(projected_features)
        else:
            projected_features = self.fc1(img_patches)
            projected_features = self.act_fn1(projected_features)
            projected_features = self.fc2(projected_features)
            projected_features = self.act_fn2(projected_features)
            projected_features = self.fc3(projected_features)

        return projected_features


# === Main HF Class Definitions ===
@dataclass
class PrismaticCausalLMOutputWithPast(ModelOutput):
    """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Additions for VLMs
    projector_features: Optional[torch.FloatTensor] = None


class PrismaticPreTrainedModel(PreTrainedModel):
    config_class: PretrainedConfig = PrismaticConfig
    base_model_prefix: str = "model"
    supports_gradient_checkpointing: bool = True

    _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
    _skip_keys_device_placement: str = "past_key_values"
    _supports_flash_attn_2: bool = True

    def _init_weights(self, module: nn.Module) -> None:
        # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
        #   => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
        #      https://github.com/TRI-ML/prismatic-vlms
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self) -> bool:
        """Check whether the underlying LLM supports SDPA Attention."""
        return self.language_model._supports_sdpa


class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
    def __init__(self, config: PrismaticConfig) -> None:
        super().__init__(config)

        # [Validation] Lightweight validation of `config` fields + dependency versions
        if config.use_fused_vision_backbone is None:
            raise ValueError("Missing config field `use_fused_vision_backbone`")

        if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
            raise NotImplementedError(
                "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
                "if you urgently need support for latest TIMM versions."
            )

        if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
            logger.warning(
                f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
                f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
                f"there might be inference-time regressions due to dependency changes. If in doubt, please "
                f"use the above versions."
            )

        # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
        self.vision_backbone = PrismaticVisionBackbone(
            config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
        )

        # Create Multimodal Projector
        self.projector = PrismaticProjector(
            config.use_fused_vision_backbone,
            vision_dim=self.vision_backbone.embed_dim,
            llm_dim=config.text_config.hidden_size,
        )

        # Instantiate LLM Backbone
        self.language_model = AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        self.vocab_size = config.text_config.vocab_size
        self.pad_token_id = config.pad_token_id

        # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
        self.post_init()

    # === `PreTrainedModel` Boilerplate ===
    def get_input_embeddings(self) -> nn.Module:
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.language_model.set_output_embeddings(new_embeddings)

    def get_decoder(self) -> nn.Module:
        return self.language_model.get_decoder()

    def set_decoder(self, decoder: nn.Module) -> None:
        self.language_model.set_decoder(decoder)

    def tie_weights(self) -> None:
        self.language_model.tie_weights()  # Note: `Llama-2` and `Mistral` don't tie weights (no-op)

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

        # Update config/instance variables
        self.config.text_config.vocab_size = updated_embeddings.num_embeddings
        self.vocab_size = updated_embeddings.num_embeddings

        return updated_embeddings

    # === Core Prismatic VLM `forward()` Logic ===
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_projector_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
        """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_projector_features = output_projector_features if output_projector_features is not None else False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
        use_cache = use_cache and not self.training

        # Instantiate Placeholder for Projector Features
        projected_patch_embeddings = None

        # Note :: We only support forward passes with the following cases:
        #   => Cached Generation :: (input_ids.shape[1] == 1) and (past_key_values is not None)
        #   => Unimodal Forward :: (pixel_values is None)
        #   => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])

        # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_key_values` ===
        if input_ids.shape[1] == 1:
            assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
            assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
            assert labels is None, "Unexpected key `labels` provided during cached generation!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=None,
                position_ids=None,
                past_key_values=past_key_values,
                inputs_embeds=None,
                labels=None,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Unimodal Forward ===
        elif pixel_values is None:
            assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
            assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Multimodal Forward ===
        elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
            assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"

            # Visual Feature Extraction
            patch_features = self.vision_backbone(pixel_values)

            # Projection Logic =>> Update Attention Mask
            projected_patch_embeddings = self.projector(patch_features)
            projected_patch_attention_mask = None
            if attention_mask is not None:
                projected_patch_attention_mask = torch.full(
                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                    fill_value=True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

            # Get Input Embeddings (from Language Model Embeddings)
            input_embeddings = self.get_input_embeddings()(input_ids)

            # Build Multimodal Embeddings & Attention Mask =>> Prismatic defaults to inserting after <BOS> token (1:)
            multimodal_embeddings = torch.cat(
                [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
            )
            multimodal_attention_mask = None
            if attention_mask is not None:
                multimodal_attention_mask = torch.cat(
                    [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
                )

            # Build Labels (if specified) =>> Ignore Labels for Patch Embeddings
            multimodal_labels = None
            if labels is not None:
                projected_patch_labels = torch.full(
                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                    fill_value=IGNORE_INDEX,
                    dtype=labels.dtype,
                    device=labels.device,
                )
                multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)

            # Dispatch to Language Model
            language_model_output = self.language_model(
                input_ids=None,
                attention_mask=multimodal_attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=multimodal_embeddings,
                labels=multimodal_labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Otherwise =>> Assume Invalid! ===
        elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
            raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")

        else:
            raise ValueError(
                "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
                f"=> `input_ids` = {input_ids is not None}\n"
                f"=> `attention_mask` = {attention_mask is not None}\n"
                f"=> `pixel_values` = {pixel_values is not None}\n"
                f"=> `labels` = {labels is not None}\n"
                f"=> `inputs_embeds` = {inputs_embeds is not None}\n"
                f"=> `past_key_values` = {past_key_values is not None}\n"
                f"=> `use_cache` = {use_cache}"
            )

        # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
        if not return_dict:
            if output_projector_features and (projected_patch_embeddings is not None):
                return *language_model_output, projected_patch_embeddings

            return language_model_output

        return PrismaticCausalLMOutputWithPast(
            loss=language_model_output.loss,
            logits=language_model_output.logits,
            past_key_values=language_model_output.past_key_values,
            hidden_states=language_model_output.hidden_states,
            attentions=language_model_output.attentions,
            projector_features=projected_patch_embeddings,
        )

    # === GenerationMixin Methods ===
    def prepare_inputs_for_generation(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: str,
    ) -> Dict[str, torch.Tensor]:
        """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
        if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
            (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
        ):
            raise ValueError("Generation with batch size > 1 is not currently supported!")

        # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # Make sure `pixel_values` are preserved in `model_inputs`
        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
            }
        )

        return model_inputs

    # Defer to Language Model (all handle this differently, with different return types)
    def _reorder_cache(self, *args, **kwargs) -> Any:
        return self.language_model._reorder_cache(*args, **kwargs)


class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
    config_class: PretrainedConfig = OpenVLAConfig

    def __init__(self, config: OpenVLAConfig) -> None:
        super().__init__(config)
        self.norm_stats = config.norm_stats

        # Compute action bins
        self.bins = np.linspace(-1, 1, config.n_action_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0

        # Compute vocab size for de-tokenization -- revert added "multiple of"
        self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of

    def predict_action(
        self, input_ids: Optional[torch.LongTensor] = None, unnorm_key: Optional[str] = None, **kwargs: str
    ) -> np.ndarray:
        """Thin wrapper around .generate() that decodes predicted actions and un-normalizes them."""
        # If the special empty token ('') does not already appear after the colon (':') token in the prompt
        # (after "Out:" or "ASSISTANT:"), insert it to match the inputs seen at training time
        if not torch.all(input_ids[:, -1] == 29871):
            input_ids = torch.cat(
                (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
            )

        # Run VLA inference
        generated_ids = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)

        # Extract predicted action tokens and translate into (normalized) continuous actions
        predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :].cpu().numpy()
        discretized_actions = self.vocab_size - predicted_action_token_ids
        discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
        normalized_actions = self.bin_centers[discretized_actions]

        # Un-normalize actions
        action_norm_stats = self.get_action_stats(unnorm_key)
        mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
        action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
        actions = np.where(
            mask,
            0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
            normalized_actions,
        )

        return actions

    @staticmethod
    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
        if unnorm_key is None:
            assert len(norm_stats) == 1, (
                f"Your model was trained on more than one dataset; "
                f"please pass a `unnorm_key` from the following options to choose the statistics "
                f"used for un-normalizing actions: {norm_stats.keys()}"
            )
            unnorm_key = next(iter(norm_stats.keys()))

        assert unnorm_key in norm_stats, (
            f"The `unnorm_key` you chose is not in the set of available dataset statistics; "
            f"please choose from: {norm_stats.keys()}"
        )
        return unnorm_key

    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
        """Get the dimensionality of the policy's action space."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return len(self.norm_stats[unnorm_key]["action"]["q01"])

    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
        """Get all the logged statistics for the given dataset."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return self.norm_stats[unnorm_key]["action"]
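
To make the action de-tokenization in `OpenVLAForActionPrediction` easier to follow, here is a minimal standalone sketch of the token-to-bin arithmetic; the predicted token IDs are hypothetical, but the bin math matches `predict_action` above.

```python
import numpy as np

# 256 uniform bins over [-1, 1]; action tokens occupy the *top* of the vocabulary
n_action_bins = 256
bins = np.linspace(-1, 1, n_action_bins)
bin_centers = (bins[:-1] + bins[1:]) / 2.0

# Llama-2 vocab size, i.e., `text_config.vocab_size - pad_to_multiple_of`
vocab_size = 32000
predicted_token_ids = np.array([31872, 31900, 31744, 31871, 31873, 31870, 31745])  # hypothetical

# Mirror `predict_action`: reflect IDs off the vocab size, then index into bin centers
discretized = vocab_size - predicted_token_ids
discretized = np.clip(discretized - 1, a_min=0, a_max=bin_centers.shape[0] - 1)
normalized_actions = bin_centers[discretized]  # seven values in [-1, 1]
```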
processing_prismatic.py
ADDED
@@ -0,0 +1,257 @@
"""
processing_prismatic.py

HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration
specifies `siglip-224px+7b`.
"""

from typing import Any, ClassVar, List, Optional, Tuple, Union

import timm.data
import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType


# === Image Processing ===
def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
    """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
    (w, h), max_wh = image.size, max(image.size)
    horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
    padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)

    return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant")


class PrismaticImageProcessor(ImageProcessingMixin):
    model_input_names: ClassVar[List[str]] = ["pixel_values"]

    def __init__(
        self,
        use_fused_vision_backbone: bool = False,
        image_resize_strategy: str = "letterbox",
        input_sizes: Optional[List[Tuple[int, int, int]]] = None,
        interpolations: Optional[List[str]] = None,
        means: Optional[List[Tuple[float, float, float]]] = None,
        stds: Optional[List[Tuple[float, float, float]]] = None,
        **kwargs: str,
    ) -> None:
        """
        Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be
        created by TIMM, and edited to follow our custom `image_resize_strategy` logic.

        @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone
        @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >
        @param input_sizes: [TIMM :: `data_cfg`] Input image sizes as tuples (channels, width, height)
        @param interpolations: [TIMM :: `data_cfg`] Interpolations as strings (default: "bicubic")
        @param means: [TIMM :: `data_cfg`] Normalization means as float tuples (two entries if `fused_backbone`)
        @param stds: [TIMM :: `data_cfg`] Normalization stds as float tuples (two entries if `fused_backbone`)
        """
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.image_resize_strategy = image_resize_strategy

        # Handle `None` default values
        input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes
        interpolations = ["bicubic"] if interpolations is None else interpolations  # docstring default; avoids indexing `None`
        means = [(0.5, 0.5, 0.5)] if means is None else means
        stds = [(0.5, 0.5, 0.5)] if stds is None else stds

        # TIMM `data_cfg` Parameters
        self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds

        # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
        self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
        self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None

        for idx in range(len(input_sizes)):
            transform = timm.data.create_transform(
                input_size=self.input_sizes[idx],
                interpolation=self.interpolations[idx],
                mean=self.means[idx],
                std=self.stds[idx],
                crop_pct=1.0,           # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)
                crop_mode="center",     # Default crop mode -- no-op when `crop_pct == 1.0`
                is_training=False,      # No image augmentations when loading the transform!
            )

            # [Validation] Ensure appropriate transform structure, expected sizes
            if not (
                isinstance(transform, Compose)
                and (len(transform.transforms) == 4)
                and isinstance(transform.transforms[0], Resize)
                and isinstance(transform.transforms[1], CenterCrop)
                and isinstance(transform.transforms[2], ToTensor)
                and isinstance(transform.transforms[3], Normalize)
                and (transform.transforms[0].size == self.input_sizes[idx][-1])
                and (transform.transforms[1].size == self.input_sizes[idx][-2:])
            ):
                raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")

            # HF Image Processors *must* be JSON-serializable; as such, we cannot keep torchvision transforms as attributes.
            #   => Instead, we're going to parse the transform and call "torchvision.transforms.functional" (`TVF`)
            resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
            self.tvf_resize_params.append(
                {
                    "size": resize_t.size,
                    "interpolation": TVF.pil_modes_mapping[resize_t.interpolation],
                    "max_size": None,
                    "antialias": True,
                }
            )
            self.tvf_crop_params.append({"output_size": crop_t.size})
            self.tvf_normalize_params.append(
                {
                    "mean": norm_t.mean.float().numpy().tolist(),
                    "std": norm_t.std.float().numpy().tolist(),
                    "inplace": False,
                }
            )
            self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None

            # Handle Prismatic `image_resize_strategy`
            if self.image_resize_strategy == "resize-naive":
                self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
            elif self.image_resize_strategy == "letterbox":
                self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])
            elif self.image_resize_strategy == "resize-crop":
                pass
            else:
                raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")

        # Dispatch **kwargs to super()
        super().__init__(**kwargs)

    def apply_transform(self, img: Image.Image) -> torch.Tensor:
        """Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])."""
        if self.tvf_do_letterbox:
            img = letterbox_pad_transform(img, self.tvf_letterbox_fill)

        # [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side!
        imgs_t = []
        for idx in range(len(self.input_sizes)):
            img_idx = TVF.resize(img, **self.tvf_resize_params[idx])
            img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])
            img_idx_t = TVF.to_tensor(img_idx)
            img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])
            imgs_t.append(img_idx_t)

        # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0
        img_t = torch.vstack(imgs_t)

        return img_t

    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image]],
        return_tensors: Optional[Union[str, TensorType]] = None,
        **_: str,
    ) -> BatchFeature:
        """
        Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor`, we
        explicitly only handle PIL.Image.Image instances for simplicity.

        @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
        @param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray

        @return: Instance of `transformers :: BatchFeature` with a single key "pixel_values"
        """
        if not isinstance(images, list):
            images = [images]

        # Apply `self.apply_transform` to each image (returns a list of torch.Tensors); stack into a "batched" Tensor
        pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])

        # Return BatchFeature =>> note that for compatibility, the constructor expects Dict[str, np.ndarray], so we convert
        return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)

    def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
        return self.preprocess(images, **kwargs)


# === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===
#   =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
class PrismaticProcessor(ProcessorMixin):
    attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"]
    image_processor_class: str = "AutoImageProcessor"
    tokenizer_class: str = "AutoTokenizer"

    def __init__(
        self,
        image_processor: Optional[ImageProcessingMixin] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
    ) -> None:
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
        images: Union[Image.Image, List[Image.Image]],
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    ) -> BatchFeature:
        """
        Preprocess a given (batch) of text/images for a Prismatic VLM; forwards text to the underlying LLM's
        tokenizer, and forwards images to PrismaticImageProcessor.

        @param text: The (batch of) text to encode; must be a string or list of strings.
        @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
        @param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False >
        @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified
        @param max_length: Maximum length (in tokens) to truncate to
        @param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH)

        @return: BatchFeature with keys for `input_ids`, `attention_mask`, and `pixel_values`.
        """
        pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
        text_inputs = self.tokenizer(
            text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
        )

        # [Validate] Need same number of images and text inputs!
        if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:
            raise ValueError("Batch is malformed; expected same number of images and text inputs!")

        return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})

    # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
    def batch_decode(
        self,
        sequences: Union[List[int], List[List[int]], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs: str,
    ) -> List[str]:
        return self.tokenizer.batch_decode(
            sequences=sequences,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    def decode(
        self,
        token_ids: Union[int, List[int], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs: str,
    ) -> str:
        return self.tokenizer.decode(
            token_ids=token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def model_input_names(self) -> List[str]:
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names

        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
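
A short end-to-end sketch of using the processor standalone (the image path and prompt are illustrative; the `pixel_values` shape shown assumes the fused dual-backbone `openvla-7b` configuration, which channel-stacks two 3-channel tensors):

```python
from PIL import Image
from transformers import AutoProcessor

# Resolves to `PrismaticProcessor` via the `auto_map` entry in processor_config.json below
processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)

image = Image.open("frame.png")  # hypothetical camera frame
inputs = processor("In: What action should the robot take to pick up the cup?\nOut:", image)

print(inputs["input_ids"].shape)     # torch.Size([1, seq_len])
print(inputs["pixel_values"].shape)  # torch.Size([1, 6, 224, 224]) for the fused backbone
```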
processor_config.json
ADDED
@@ -0,0 +1,6 @@
{
  "auto_map": {
    "AutoProcessor": "processing_prismatic.PrismaticProcessor"
  },
  "processor_class": "PrismaticProcessor"
}