diff --git a/ALOHA.md b/ALOHA.md new file mode 100644 index 0000000000000000000000000000000000000000..c8e7853a8bdfefe16285827422d99258f42da4e2 --- /dev/null +++ b/ALOHA.md @@ -0,0 +1,157 @@ +# OpenVLA-OFT+ in Real-World ALOHA Robot Tasks + +## Relevant Files + +Evaluation +* `experiments/robot/aloha/`: ALOHA training and eval files + * `run_aloha_eval.py`: ALOHA eval script (CLIENT SIDE; see "SERVER SIDE" below) + * `aloha_utils.py`: ALOHA eval utils + * Other ALOHA robot environment files copied from the original [ALOHA GitHub repo](https://github.com/tonyzhaozh/aloha): + * `constants.py` + * `real_env.py` + * `robot_utils.py` +* `experiments/robot/`: General eval utils files + * `openvla_utils.py`: OpenVLA-specific eval utils + * `robot_utils.py`: Other eval utils +* `vla-scripts/deploy.py`: VLA server deploy script (SERVER SIDE) + +Note: Unlike the LIBERO evaluation setup, we use a server-client interface here. This is particularly useful if the user's machine which commands the robot does not have access to a local GPU with sufficient specs to run the fine-tuned VLA policies. + +Training +* `experiments/robot/aloha/`: ALOHA training and eval files + * `preprocess_split_aloha_data.py`: ALOHA data preprocessing script +* `vla-scripts/finetune.py`: VLA fine-tuning script + +## Setup + +Set up a conda environment for training policies and deploying them on the VLA server (see instructions in [SETUP.md](SETUP.md)). + +## Fine-Tuning on ALOHA Robot Data + +We assume that you have collected a set of expert demonstrations on the ALOHA robot already. + +First, use our `preprocess_split_aloha_data.py` script to preprocess the raw ALOHA dataset: downsize images from 480x640 to 256x256 and split into training and validation sets. 
Below are examples for the `put X into pot` task in our paper (which has 3 possible target objects, 1 per episode): + +```bash +python experiments/robot/aloha/preprocess_split_aloha_data.py \ + --dataset_path /scr/moojink/data/aloha1_raw/put_green_pepper_into_pot/ \ + --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \ + --percent_val 0.05 +python experiments/robot/aloha/preprocess_split_aloha_data.py \ + --dataset_path /scr/moojink/data/aloha1_raw/put_red_pepper_into_pot/ \ + --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \ + --percent_val 0.05 +python experiments/robot/aloha/preprocess_split_aloha_data.py \ + --dataset_path /scr/moojink/data/aloha1_raw/put_yellow_corn_into_pot/ \ + --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \ + --percent_val 0.05 +``` + +Then, convert the preprocessed ALOHA datasets into a single RLDS dataset that is compatible with OpenVLA fine-tuning. This process is the same as in the original OpenVLA repo. See instructions for converting to RLDS [here](https://github.com/moojink/rlds_dataset_builder) (a sample ALOHA preprocessed-to-RLDS conversion script is available [here](https://github.com/moojink/rlds_dataset_builder/blob/main/aloha1_put_X_into_pot_300_demos/aloha1_put_X_into_pot_300_demos_dataset_builder.py); this script converts the three preprocessed datasets above into one unified RLDS dataset, with train/val splits). + +After converting to RLDS, register the dataset (which, for the example task above, would be called `aloha1_put_X_into_pot_300_demos`) with our dataloader by adding an entry for it in `configs.py` ([here](prismatic/vla/datasets/rlds/oxe/configs.py#L680)), `transforms.py` ([here](prismatic/vla/datasets/rlds/oxe/transforms.py#L928)), and `mixtures.py` ([here](prismatic/vla/datasets/rlds/oxe/mixtures.py#L216)). For reference, in each of these files, there are sample entries for the ALOHA datasets that we used in our paper. 
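For reference, the train/val split in the preprocessing step above is made over whole episodes rather than individual timesteps. The snippet below is a minimal sketch of that idea, not the script's actual code: `split_episode_indices` and its `seed` argument are hypothetical names, and we assume the validation count is simply the integer part of `num_episodes * percent_val`.

```python
import random


def split_episode_indices(num_episodes, percent_val, seed=0):
    """Shuffle episode indices and split them into (train, val) lists.

    Hypothetical helper for illustration; each episode goes entirely to one split.
    """
    indices = list(range(num_episodes))
    random.Random(seed).shuffle(indices)  # local RNG so the split is reproducible
    num_val = int(num_episodes * percent_val)  # assumption: truncate toward zero
    return indices[num_val:], indices[:num_val]


# One 100-episode dataset with --percent_val 0.05
train_idx, val_idx = split_episode_indices(num_episodes=100, percent_val=0.05)
print(len(train_idx), len(val_idx))  # 95 5
```

Under these assumptions, each of the three 100-episode `put X into pot` datasets would contribute 95 training and 5 validation episodes.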
+ +Before fine-tuning, set the desired ALOHA action chunk size in [`prismatic/vla/constants.py`](prismatic/vla/constants.py) (see `NUM_ACTIONS_CHUNK` in `ALOHA_CONSTANTS`). We set it to 25 by default because we used a control frequency of 25 Hz in our ALOHA setup to reduce storage costs and training time (while still maintaining smoothness in the robot's motions). If you use 50 Hz, we recommend setting `NUM_ACTIONS_CHUNK` to `50`. In general, 1 second-long action chunks are a good default. Do NOT modify `ACTION_PROPRIO_NORMALIZATION_TYPE`: Since the ALOHA robot action space is absolute joint angles, we do not want to use a normalization scheme that clips outlier values (like the Q1-Q99 normalization we used with the relative end-effector pose actions for LIBERO), since that would prevent the model from outputting certain robot joint angles that are crucial for solving the task. + +Now begin fine-tuning! Below is a sample command to fine-tune OpenVLA using our OFT+ recipe on the `put X into pot` task above ("+" in "OFT+" means FiLM is included for enhanced language grounding). Replace `X` in the first line with the number of GPUs available to you. 
+ +```bash +torchrun --standalone --nnodes 1 --nproc-per-node X vla-scripts/finetune.py \ + --vla_path openvla/openvla-7b \ + --data_root_dir /PATH/TO/RLDS/DATASETS/DIR/ \ + --dataset_name aloha1_put_X_into_pot_300_demos \ + --run_root_dir /YOUR/CHECKPOINTS/AND/LOG/DIR/ \ + --use_l1_regression True \ + --use_diffusion False \ + --use_film True \ + --num_images_in_input 3 \ + --use_proprio True \ + --batch_size 4 \ + --learning_rate 5e-4 \ + --num_steps_before_decay 50000 \ + --max_steps 100005 \ + --use_val_set True \ + --val_freq 10000 \ + --save_freq 10000 \ + --save_latest_checkpoint_only False \ + --image_aug True \ + --lora_rank 32 \ + --wandb_entity "YOUR_WANDB_ENTITY" \ + --wandb_project "YOUR_WANDB_PROJECT" \ + --run_id_note parallel_dec--25_acts_chunk--continuous_acts--L1_regression--3rd_person_img--left_right_wrist_imgs--proprio_state--film +``` + +The above training command should reproduce our OpenVLA-OFT+ results on the `put X into pot` task if `X = 8` and the 100K step checkpoint is evaluated. It will fine-tune OpenVLA using 3 input images (1 third-person image + 2 wrist camera images). Note that we decay the learning rate after a certain point (50K steps in the command above) since doing so speeds up training convergence (in our experience, the training L1 loss drops sharply right after the decay). + +Best practices for fine-tuning: +* In general, we recommend fine-tuning until training L1 loss goes below 0.01 and starts to plateau. + * One way to achieve this is to fine-tune using our default learning rate of `5e-4` until the loss starts to decrease very slowly, and then decay the learning rate by 10x to `5e-5` (which should make the loss drop sharply) and train until the training L1 loss finally plateaus. +* Depending on your dataset size, you may need to adjust some hyperparameters. For example, if you use a large dataset with over 300 demos, you may need to decay the learning rate later and train for longer for best performance. 
Decaying too early can lead to a suboptimal policy. +* If your task does not require good language grounding (e.g., if there is only one language instruction), FiLM is not necessary; consider setting `--use_film False` to train fewer model parameters. +* Please be sure to test your policy with the same device/GPU used to train it! Otherwise, performance may drop substantially. You may be able to avoid the performance drop if you merge the LoRA weights into the base model on the downstream device used for testing (e.g., if you train on H100 and then merge on A100 before testing on A100). You can see our script [vla-scripts/merge_lora_weights_and_save.py](vla-scripts/merge_lora_weights_and_save.py) for merging the LoRA adapter into the base model offline. It's okay if you already merged LoRA weights into the base OpenVLA model during fine-tuning; you can always redownload the base model and merge again as long as you still have the LoRA adapter (`merge_lora_weights_and_save.py` will handle this for you). + +If you run into any issues, please open a new GitHub issue. 
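To illustrate the warning in this section about not modifying `ACTION_PROPRIO_NORMALIZATION_TYPE`, here is a small, self-contained sketch (not code from this repo) that compares quantile-based normalization with clipping against min/max bounds normalization on made-up joint angles. The helper names `normalize` and `unnormalize`, the data, and the choice of min/max bounds as the non-clipping alternative are all assumptions for illustration only.

```python
import statistics


def normalize(x, low, high, clip):
    """Map x from [low, high] to [-1, 1]; optionally clip to that range."""
    y = 2.0 * (x - low) / (high - low) - 1.0
    return max(-1.0, min(1.0, y)) if clip else y


def unnormalize(y, low, high):
    """Inverse of normalize (without clipping)."""
    return 0.5 * (y + 1.0) * (high - low) + low


# Made-up absolute joint angles (radians): most lie in [-0.5, 0.5],
# plus one rare angle of 1.4 that we pretend is needed to finish the task.
angles = [0.001 * i - 0.5 for i in range(1001)] + [1.4]

percentiles = statistics.quantiles(angles, n=100)  # 99 cut points
q01, q99 = percentiles[0], percentiles[-1]
lo, hi = min(angles), max(angles)

target = 1.4
via_quantiles = unnormalize(normalize(target, q01, q99, clip=True), q01, q99)
via_bounds = unnormalize(normalize(target, lo, hi, clip=False), lo, hi)

print(f"q1-q99 with clipping: {via_quantiles:.3f}")  # capped at q99, well below 1.4
print(f"min/max bounds:       {via_bounds:.3f}")     # recovers 1.4
```

In this toy setting, the clipped scheme can never reproduce the rare 1.4 rad angle, which is the failure mode the note above is meant to avoid for absolute joint-angle actions.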
+ +## Launching ALOHA Robot Evaluations + +In the primary conda environment (`openvla-oft`) which you will use to launch the VLA server, install a few packages for the server-client interface: + +```bash +conda activate openvla-oft +pip install uvicorn fastapi json-numpy +``` + +On the machine that you will use to command the robot, set up a second conda environment that will be used to run the robot environment, query the VLA server, and execute actions in the environment: + +```bash +# Create and activate client conda environment +conda create -n openvla-oft-aloha python=3.10 -y +conda activate openvla-oft-aloha + +# Install PyTorch +# Use a command specific to your machine: https://pytorch.org/get-started/locally/ +pip3 install torch torchvision torchaudio + +# Clone openvla-oft repo and pip install to download dependencies +git clone https://github.com/moojink/openvla-oft.git +cd openvla-oft +pip install -e . + +# Install packages needed for the ALOHA robot environment +pip install -r experiments/robot/aloha/requirements_aloha.txt +``` + +Launch the VLA server on the machine that has the GPU you will use to run model inference (using the `openvla-oft` conda environment). Below is a sample command for this (change as needed): + +```bash +python vla-scripts/deploy.py \ + --pretrained_checkpoint /PATH/TO/FINETUNED/MODEL/CHECKPOINT/DIR/ \ + --use_l1_regression True \ + --use_film True \ + --num_images_in_input 3 \ + --use_proprio True \ + --center_crop True \ + --unnorm_key aloha1_put_X_into_pot_300_demos +``` + +Then, run the ALOHA evaluation script. Specify the VLA server URL or IP address in the `vla_server_url` argument. Below is a sample command: + +```bash +python experiments/robot/aloha/run_aloha_eval.py \ + --center_crop True \ + --num_open_loop_steps 25 \ + --use_vla_server True \ + --vla_server_url \ + --num_rollouts_planned \ + --max_steps +``` + +If you run into any issues, please open a new GitHub issue. 
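The `--num_open_loop_steps 25` argument in the eval command above controls how many actions from each predicted chunk are executed before the policy is queried again. The snippet below sketches one way such open-loop chunk execution can work, with a fake policy standing in for the VLA server. `run_open_loop` and `fake_policy` are hypothetical names, and the actual eval script may be structured differently.

```python
from collections import deque


def run_open_loop(query_policy, num_steps, num_open_loop_steps):
    """Execute queued actions, re-querying the policy only when the queue is empty."""
    action_queue = deque()
    executed, num_queries = [], 0
    for t in range(num_steps):
        if not action_queue:
            chunk = query_policy(t)  # in a real setup, this would be a server round trip
            num_queries += 1
            action_queue.extend(chunk[:num_open_loop_steps])
        executed.append(action_queue.popleft())  # stand-in for env.step(action)
    return executed, num_queries


def fake_policy(t, chunk_size=25):
    """Hypothetical stand-in for the VLA server: returns labeled placeholder actions."""
    return [(t, i) for i in range(chunk_size)]


executed, num_queries = run_open_loop(fake_policy, num_steps=100, num_open_loop_steps=25)
print(num_queries)  # 4 queries for 100 steps with 25-step chunks
```

With a 25 Hz control frequency, this pattern would contact the server roughly once per second of robot motion, which keeps network latency off the per-step control path.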
+ +## Troubleshooting Tips + +* Tip #1: If you run into a ROS error such as `ImportError: /lib/x86_64-linux-gnu/libp11-kit.so.0: undefined symbol: ffi_type_pointer, version LIBFFI_BASE_7.0`, try running the following command in your client conda environment (`openvla-oft-aloha`): + + ``` + conda install -c conda-forge libffi + ``` diff --git a/LIBERO.md b/LIBERO.md new file mode 100644 index 0000000000000000000000000000000000000000..893a1215f4b3ba809cf9da377f7bd59a2f779a87 --- /dev/null +++ b/LIBERO.md @@ -0,0 +1,130 @@ +# OpenVLA-OFT in the LIBERO Simulation Benchmark + +## Relevant Files + +Evaluation +* `experiments/robot/libero/`: LIBERO eval files + * `run_libero_eval.py`: LIBERO eval script + * `libero_utils.py`: LIBERO eval utils +* `experiments/robot/`: General eval utils files + * `openvla_utils.py`: OpenVLA-specific eval utils + * `robot_utils.py`: Other eval utils + +Training +* `vla-scripts/finetune.py`: VLA fine-tuning script + + +## Setup + +Set up a conda environment (see instructions in [SETUP.md](SETUP.md)). + +Clone and install the [LIBERO repo](https://github.com/Lifelong-Robot-Learning/LIBERO) and required packages: + +```bash +git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git +pip install -e LIBERO +pip install -r experiments/robot/libero/libero_requirements.txt # From openvla-oft base dir +``` + +(Optional, if you plan to launch training) To download the [LIBERO datasets](https://huggingface.co/datasets/openvla/modified_libero_rlds) that we used in our fine-tuning +experiments, run the command below. This will download the LIBERO-Spatial, LIBERO-Object, LIBERO-Goal, +and LIBERO-10 datasets in RLDS data format (~10 GB total). You can use these to fine-tune OpenVLA or +train other methods. This step is optional since we provide pretrained OpenVLA-OFT checkpoints below. +Note that these are the same datasets used in the original OpenVLA project. 
If needed, see details on how to download the original non-RLDS datasets [here](https://github.com/openvla/openvla?tab=readme-ov-file#libero-setup). +```bash +git clone git@hf.co:datasets/openvla/modified_libero_rlds +``` + +## Launching LIBERO Evaluations + +We fine-tuned OpenVLA via LoRA (r=32) with our OFT recipe on four LIBERO task suites: LIBERO-Spatial, LIBERO-Object, LIBERO-Goal, and LIBERO-10 (also called LIBERO-Long). +In the initial version of our paper, we trained one checkpoint for each LIBERO task suite independently. In an updated version of the paper, we conducted an additional experiment in which we trained a single policy on all four task suites combined (results for this are available in the Additional Experiments section in the Appendix). Overall, the results for the task-specific policies and the combined policy are comparable: 97.1% vs. 96.8% average success rate across the four suites, respectively. + +Below are the four independently trained OpenVLA-OFT checkpoints for LIBERO: +* [moojink/openvla-7b-oft-finetuned-libero-spatial](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-spatial) +* [moojink/openvla-7b-oft-finetuned-libero-object](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-object) +* [moojink/openvla-7b-oft-finetuned-libero-goal](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-goal) +* [moojink/openvla-7b-oft-finetuned-libero-10](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-10) + +Below is the OpenVLA-OFT checkpoint trained on all four task suites combined: +* [moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10](https://huggingface.co/moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10) + +To start evaluations with one of the independently trained checkpoints, run one of the commands below. Each will automatically download the appropriate checkpoint listed above. 
You can set the `TRANSFORMERS_CACHE` and `HF_HOME` environment variables to change where the checkpoint files get cached. + +```bash +# Launch LIBERO-Spatial evals +python experiments/robot/libero/run_libero_eval.py \ + --pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-spatial \ + --task_suite_name libero_spatial + +# Launch LIBERO-Object evals +python experiments/robot/libero/run_libero_eval.py \ + --pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-object \ + --task_suite_name libero_object + +# Launch LIBERO-Goal evals +python experiments/robot/libero/run_libero_eval.py \ + --pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-goal \ + --task_suite_name libero_goal + +# Launch LIBERO-10 (LIBERO-Long) evals +python experiments/robot/libero/run_libero_eval.py \ + --pretrained_checkpoint moojink/openvla-7b-oft-finetuned-libero-10 \ + --task_suite_name libero_10 +``` + +To evaluate the policy trained on all four task suites together, simply swap out the `--pretrained_checkpoint` in the commands above with `moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10`. + +Notes: +* The evaluation script will run 500 trials by default (10 tasks x 50 episodes each). You can modify the number of + trials per task by setting `--num_trials_per_task`. You can also change the random seed via `--seed`. There are + other arguments in the script; we set them to the default values that work with the OpenVLA-OFT checkpoints above. +* **NOTE: Setting `--center_crop True` is important** because we fine-tuned OpenVLA with random crop augmentations + (we took a random crop with 90% area in every training sample, so at test time we simply take the center 90% crop). +* The evaluation script logs results locally. You can also log results in Weights & Biases + by setting `--use_wandb True` and specifying `--wandb_project ` and `--wandb_entity `. 
+* The results reported in our paper were obtained using **Python 3.10.14, PyTorch 2.2.0, and our + [custom transformers v4.40.1 fork](https://github.com/moojink/transformers-openvla-oft.git)** + on an **NVIDIA A100 GPU**, averaged over three random seeds. Please stick to these package versions if possible. + Note that results may vary slightly if you use a different GPU than the A100. If the discrepancy is large, + please post a GitHub issue, and we will look into it. + +## Fine-Tuning on LIBERO Datasets + +First, download the LIBERO datasets as described in the Setup section above: `libero_spatial_no_noops`, `libero_object_no_noops`, `libero_goal_no_noops`, `libero_10_no_noops`. (`"_no_noops"` stands for no no-op actions, i.e., training samples with near-zero actions are filtered out.) + +Then, launch the fine-tuning script with the OFT configuration below, replacing `X` in the first line with the number of GPUs. The command below launches fine-tuning on LIBERO-Spatial with the hyperparameters that we used in our paper. Here, batch size 8 per GPU will require ~62 GB VRAM, and batch size 1 per GPU will require ~25 GB VRAM. 
+ +```bash +torchrun --standalone --nnodes 1 --nproc-per-node X vla-scripts/finetune.py \ + --vla_path openvla/openvla-7b \ + --data_root_dir /PATH/TO/RLDS/DATASETS/DIR/ \ + --dataset_name libero_spatial_no_noops \ + --run_root_dir /YOUR/CHECKPOINTS/AND/LOG/DIR/ \ + --use_l1_regression True \ + --use_diffusion False \ + --use_film False \ + --num_images_in_input 2 \ + --use_proprio True \ + --batch_size 8 \ + --learning_rate 5e-4 \ + --num_steps_before_decay 100000 \ + --max_steps 150005 \ + --save_freq 10000 \ + --save_latest_checkpoint_only False \ + --image_aug True \ + --lora_rank 32 \ + --wandb_entity "YOUR_WANDB_ENTITY" \ + --wandb_project "YOUR_WANDB_PROJECT" \ + --run_id_note parallel_dec--8_acts_chunk--continuous_acts--L1_regression--3rd_person_img--wrist_img--proprio_state +``` + +The above training command should reproduce our OpenVLA-OFT results if `X = 8` and the 150K step checkpoint is evaluated. + +You can replace `libero_spatial_no_noops` with `libero_object_no_noops`, `libero_goal_no_noops`, or `libero_10_no_noops`. You can also modify other args — e.g., if you want to train with just one input image from the third-person camera and disable proprio state input, you can set `--num_images_in_input 1` and `--use_proprio False`. + +In general, we recommend fine-tuning until training L1 loss goes below 0.01 and starts to plateau (with the above configuration, it should reach ~0.006 L1 loss on LIBERO-Spatial after 150K gradient steps with 10x LR decay after 100K steps). However, for LIBERO-Goal only, we found that the 50K checkpoint (which was at ~0.02 L1 loss) performed best for unknown reasons. For all other task suites though, we found that the 150K checkpoint performed best. + +Please be sure to test your policy with the same device/GPU used to train it! Otherwise, performance may drop substantially. 
You may be able to avoid the performance drop if you merge the LoRA weights into the base model on the downstream device used for testing (e.g., if you train on H100 and then merge on A100 before testing on A100). You can see our script [vla-scripts/merge_lora_weights_and_save.py](vla-scripts/merge_lora_weights_and_save.py) for merging the LoRA adapter into the base model offline. It's okay if you already merged LoRA weights into the base OpenVLA model during fine-tuning; you can always redownload the base model and merge again as long as you still have the LoRA adapter (`merge_lora_weights_and_save.py` will handle this for you). + +If you run into any issues, please open a new GitHub issue. If you do not receive a response within 2 business days, please email Moo Jin Kim (moojink@cs.stanford.edu) to bring the issue to his attention. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2ee12646a4d85fafa9645b0445124d6fc345914f --- /dev/null +++ b/README.md @@ -0,0 +1,97 @@ +# Fine-Tuning Vision-Language-Action Models: Optimizing Speed and Success + +**Project website: https://openvla-oft.github.io/** + +**Paper: https://arxiv.org/abs/2502.19645** + +**Summary video: https://youtu.be/T3Zkkr_NTSA** + +## System Requirements + +Inference: +* 1 GPU with ~16 GB VRAM for LIBERO sim benchmark tasks +* 1 GPU with ~18 GB VRAM for ALOHA robot tasks + +Training: +* Between 1-8 GPUs with 27-80 GB, depending on the desired training setup (with default bfloat16 data type). See [this FAQ on our project website](https://openvla-oft.github.io/#train-compute) for details. + +## Quick Start + +First, set up a conda environment (see instructions in [SETUP.md](SETUP.md)). 
+ +Then, run the Python script below to download a pretrained OpenVLA-OFT checkpoint and run inference to generate an action chunk: + +```python +import pickle +from experiments.robot.libero.run_libero_eval import GenerateConfig +from experiments.robot.openvla_utils import get_action_head, get_processor, get_proprio_projector, get_vla, get_vla_action +from prismatic.vla.constants import NUM_ACTIONS_CHUNK, PROPRIO_DIM + +# Instantiate config (see class GenerateConfig in experiments/robot/libero/run_libero_eval.py for definitions) +cfg = GenerateConfig( + pretrained_checkpoint = "moojink/openvla-7b-oft-finetuned-libero-spatial", + use_l1_regression = True, + use_diffusion = False, + use_film = False, + num_images_in_input = 2, + use_proprio = True, + load_in_8bit = False, + load_in_4bit = False, + center_crop = True, + num_open_loop_steps = NUM_ACTIONS_CHUNK, + unnorm_key = "libero_spatial_no_noops", +) + +# Load OpenVLA-OFT policy and inputs processor +vla = get_vla(cfg) +processor = get_processor(cfg) + +# Load MLP action head to generate continuous actions (via L1 regression) +action_head = get_action_head(cfg, llm_dim=vla.llm_dim) + +# Load proprio projector to map proprio to language embedding space +proprio_projector = get_proprio_projector(cfg, llm_dim=vla.llm_dim, proprio_dim=PROPRIO_DIM) + +# Load sample observation: +# observation (dict): { +# "full_image": primary third-person image, +# "wrist_image": wrist-mounted camera image, +# "state": robot proprioceptive state, +# "task_description": task description, +# } +with open("experiments/robot/libero/sample_libero_spatial_observation.pkl", "rb") as file: + observation = pickle.load(file) + +# Generate robot action chunk (sequence of future actions) +actions = get_vla_action(cfg, vla, processor, observation, observation["task_description"], action_head, proprio_projector) +print("Generated action chunk:") +for act in actions: + print(act) +``` + +## Installation + +See [SETUP.md](SETUP.md) for instructions 
on setting up the conda environment. + +## Training and Evaluation + +See [LIBERO.md](LIBERO.md) for fine-tuning/evaluating on LIBERO simulation benchmark task suites. + +See [ALOHA.md](ALOHA.md) for fine-tuning/evaluating on real-world ALOHA robot tasks. + +## Support + +If you run into any issues, please open a new GitHub issue. If you do not receive a response within 2 business days, please email Moo Jin Kim (moojink@cs.stanford.edu) to bring the issue to his attention. + +## Citation + +If you use our code in your work, please cite [our paper](https://arxiv.org/abs/2502.19645): + +```bibtex +@article{kim2025fine, + title={Fine-Tuning Vision-Language-Action Models: Optimizing Speed and Success}, + author={Kim, Moo Jin and Finn, Chelsea and Liang, Percy}, + journal={arXiv preprint arXiv:2502.19645}, + year={2025} +} +``` diff --git a/SETUP.md b/SETUP.md new file mode 100644 index 0000000000000000000000000000000000000000..d4d7c72c79295d0d292c90328c343538c385b73b --- /dev/null +++ b/SETUP.md @@ -0,0 +1,24 @@ +# Setup Instructions + +## Set Up Conda Environment + +```bash +# Create and activate conda environment +conda create -n openvla-oft python=3.10 -y +conda activate openvla-oft + +# Install PyTorch +# Use a command specific to your machine: https://pytorch.org/get-started/locally/ +pip3 install torch torchvision torchaudio + +# Clone openvla-oft repo and pip install to download dependencies +git clone https://github.com/moojink/openvla-oft.git +cd openvla-oft +pip install -e . + +# Install Flash Attention 2 for training (https://github.com/Dao-AILab/flash-attention) +# =>> If you run into difficulty, try `pip cache remove flash_attn` first +pip install packaging ninja +ninja --version; echo $? 
# Verify Ninja --> should return exit code "0" +pip install "flash-attn==2.5.5" --no-build-isolation +``` \ No newline at end of file diff --git a/experiments/robot/aloha/aloha_utils.py b/experiments/robot/aloha/aloha_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b3dc5b8aae766a07c4032acd7371be0feff5bbeb --- /dev/null +++ b/experiments/robot/aloha/aloha_utils.py @@ -0,0 +1,85 @@ +"""Utils for evaluating policies in real-world ALOHA environments.""" + +import os + +import imageio +import numpy as np +from PIL import Image + +from experiments.robot.aloha.real_env import make_real_env +from experiments.robot.robot_utils import ( + DATE, + DATE_TIME, +) + + +def get_next_task_label(task_label): + """Prompt the user to input the next task.""" + if task_label == "": + user_input = "" + while user_input == "": + user_input = input("Enter the task name: ") + task_label = user_input + else: + user_input = input("Enter the task name (or leave blank to repeat the previous task): ") + if user_input == "": + pass # Do nothing -> Let task_label be the same + else: + task_label = user_input + print(f"Task: {task_label}") + return task_label + + +def get_aloha_env(): + """Initializes and returns the ALOHA environment.""" + env = make_real_env(init_node=True) + return env + + +def resize_image_for_preprocessing(img): + """ + Takes numpy array corresponding to a single image and resizes to 256x256, exactly as done + in the ALOHA data preprocessing script, which is used before converting the dataset to RLDS. 
+ """ + ALOHA_PREPROCESS_SIZE = 256 + img = np.array( + Image.fromarray(img).resize((ALOHA_PREPROCESS_SIZE, ALOHA_PREPROCESS_SIZE), resample=Image.BICUBIC) + ) # BICUBIC is default; specify explicitly to make it clear + return img + + +def get_aloha_image(obs): + """Extracts third-person image from observations and preprocesses it.""" + # obs: dm_env._environment.TimeStep + img = obs.observation["images"]["cam_high"] + img = resize_image_for_preprocessing(img) + return img + + +def get_aloha_wrist_images(obs): + """Extracts both wrist camera images from observations and preprocesses them.""" + # obs: dm_env._environment.TimeStep + left_wrist_img = obs.observation["images"]["cam_left_wrist"] + right_wrist_img = obs.observation["images"]["cam_right_wrist"] + left_wrist_img = resize_image_for_preprocessing(left_wrist_img) + right_wrist_img = resize_image_for_preprocessing(right_wrist_img) + return left_wrist_img, right_wrist_img + + +def save_rollout_video(rollout_images, idx, success, task_description, log_file=None, notes=None): + """Saves an MP4 replay of an episode.""" + rollout_dir = f"./rollouts/{DATE}" + os.makedirs(rollout_dir, exist_ok=True) + processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50] + filetag = f"{rollout_dir}/{DATE_TIME}--openvla_oft--episode={idx}--success={success}--task={processed_task_description}" + if notes is not None: + filetag += f"--{notes}" + mp4_path = f"{filetag}.mp4" + video_writer = imageio.get_writer(mp4_path, fps=25) + for img in rollout_images: + video_writer.append_data(img) + video_writer.close() + print(f"Saved rollout MP4 at path {mp4_path}") + if log_file is not None: + log_file.write(f"Saved rollout MP4 at path {mp4_path}\n") + return mp4_path diff --git a/experiments/robot/aloha/constants.py b/experiments/robot/aloha/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..5599e3590be470b6f57d01fe4c3aa0e7739e1353 --- /dev/null +++ 
b/experiments/robot/aloha/constants.py @@ -0,0 +1,100 @@ +### Task parameters + +DATA_DIR = '/scr2/moojink/data/aloha1/' +TASK_CONFIGS = { + # fold shorts + 'fold_shorts':{ + 'dataset_dir': DATA_DIR + '/fold_shorts', + 'num_episodes': 20, + 'episode_len': 1000, + 'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist'] + }, + # fold shirt + 'fold_shirt':{ + 'dataset_dir': DATA_DIR + '/fold_shirt', + 'num_episodes': 30, + 'episode_len': 1250, + 'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist'] + }, + # scoop X into bowl + 'scoop_raisins_into_bowl':{ + 'dataset_dir': DATA_DIR + '/scoop_raisins_into_bowl', + 'num_episodes': 15, + 'episode_len': 900, + 'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist'] + }, + 'scoop_almonds_and_green_M&Ms_into_bowl':{ + 'dataset_dir': DATA_DIR + '/scoop_almonds_and_green_M&Ms_into_bowl', + 'num_episodes': 15, + 'episode_len': 900, + 'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist'] + }, + 'scoop_pretzels_into_bowl':{ + 'dataset_dir': DATA_DIR + '/scoop_pretzels_into_bowl', + 'num_episodes': 15, + 'episode_len': 900, + 'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist'] + }, + # put X into pot + 'put_red_pepper_into_pot':{ + 'dataset_dir': DATA_DIR + '/put_red_pepper_into_pot', + 'num_episodes': 100, + 'episode_len': 400, + 'camera_names': ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'] + }, + 'put_yellow_corn_into_pot':{ + 'dataset_dir': DATA_DIR + '/put_yellow_corn_into_pot', + 'num_episodes': 100, + 'episode_len': 400, + 'camera_names': ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'] + }, + 'put_green_pepper_into_pot':{ + 'dataset_dir': DATA_DIR + '/put_green_pepper_into_pot', + 'num_episodes': 100, + 'episode_len': 400, + 'camera_names': ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'] + }, +} + +### ALOHA fixed constants +DT = 0.04 # 1 / 0.04 -> 25 Hz +JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", 
"wrist_angle", "wrist_rotate"] +START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239] + +# Left finger position limits (qpos[7]), right_finger = -1 * left_finger +MASTER_GRIPPER_POSITION_OPEN = 0.02417 +MASTER_GRIPPER_POSITION_CLOSE = 0.01244 +PUPPET_GRIPPER_POSITION_OPEN = 0.05800 +PUPPET_GRIPPER_POSITION_CLOSE = 0.01844 + +# Gripper joint limits (qpos[6]) +MASTER_GRIPPER_JOINT_OPEN = 0.3083 # For ALOHA 1 +MASTER_GRIPPER_JOINT_CLOSE = -0.6842 # For ALOHA 1 +# MASTER_GRIPPER_JOINT_OPEN = -0.8 # For ALOHA 2 +# MASTER_GRIPPER_JOINT_CLOSE = -1.65 # For ALOHA 2 +PUPPET_GRIPPER_JOINT_OPEN = 1.4910 +PUPPET_GRIPPER_JOINT_CLOSE = -0.6213 + +############################ Helper functions ############################ + +MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) +PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) +MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE +PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE +MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x)) + +MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) +PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) +MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE +PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - 
PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE +MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x)) + +MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) +PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + +MASTER_POS2JOINT = lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE +MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN((x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)) +PUPPET_POS2JOINT = lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE +PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN((x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)) + +MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE)/2 \ No newline at end of file diff --git a/experiments/robot/aloha/preprocess_split_aloha_data.py b/experiments/robot/aloha/preprocess_split_aloha_data.py new file mode 100644 index 0000000000000000000000000000000000000000..8de07f232a48d6713e2758f78fde7fe75e28bd37 --- /dev/null +++ b/experiments/robot/aloha/preprocess_split_aloha_data.py @@ -0,0 +1,260 @@ +""" +Preprocesses ALOHA dataset(s) and splits them into train/val sets. + +Preprocessing includes downsizing images from 480x640 to 256x256. +Splits happen at the episode level (not step level), which means that +an episode is treated as an atomic unit that entirely goes to either +the train set or val set. + +Original ALOHA data layout: + /PATH/TO/DATASET/dataset_name/ + - episode_0.hdf5 + - episode_1.hdf5 + - ... 
+ - episode_N.hdf5 + +Preprocessed data layout (after running this script): + /PATH/TO/PREPROCESSED_DATASETS/dataset_name/ + - train/ + - episode_0.hdf5 + - episode_1.hdf5 + - ... + - episode_M.hdf5 + - val/ + - episode_0.hdf5 + - episode_1.hdf5 + - ... + - episode_K.hdf5 + + where N > M > K + +Example usage: + # "put X into pot" task + python experiments/robot/aloha/preprocess_split_aloha_data.py \ + --dataset_path /scr/moojink/data/aloha1_raw/put_green_pepper_into_pot/ \ + --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \ + --percent_val 0.05 && \ + python experiments/robot/aloha/preprocess_split_aloha_data.py \ + --dataset_path /scr/moojink/data/aloha1_raw/put_red_pepper_into_pot/ \ + --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \ + --percent_val 0.05 && \ + python experiments/robot/aloha/preprocess_split_aloha_data.py \ + --dataset_path /scr/moojink/data/aloha1_raw/put_yellow_corn_into_pot/ \ + --out_base_dir /scr/moojink/data/aloha1_preprocessed/ \ + --percent_val 0.05 +""" + +import argparse +import glob +import os +import random + +import h5py +import numpy as np +from PIL import Image +from tqdm import tqdm + + +def load_hdf5(demo_path): + """Loads single episode.""" + if not os.path.isfile(demo_path): + print(f"Dataset does not exist at \n{demo_path}\n") + exit() + + print(f"Loading {demo_path}...") + with h5py.File(demo_path, "r") as root: + is_sim = root.attrs["sim"] + qpos = root["/observations/qpos"][()] + qvel = root["/observations/qvel"][()] + effort = root["/observations/effort"][()] + action = root["/action"][()] + image_dict = dict() + for cam_name in root["/observations/images/"].keys(): + image_dict[cam_name] = root[f"/observations/images/{cam_name}"][()] + print(f"Loading episode complete: {demo_path}") + + return qpos, qvel, effort, action, image_dict, is_sim + + +def load_and_preprocess_all_episodes(demo_paths, out_dataset_dir): + """ + Loads and preprocesses all episodes. 
+ Resizes all images in one episode before loading the next, to reduce memory usage. + """ + cam_names = ["cam_high", "cam_left_wrist", "cam_right_wrist"] + idx = 0 + for demo in tqdm(demo_paths): + qpos, qvel, effort, action, image_dict, is_sim = load_hdf5(demo) + # Save non-image info + episode_len = image_dict["cam_high"].shape[0] + # Resize all images + print("Resizing images in episode...") + for k in cam_names: + resized_images = [] + for i in range(episode_len): + resized_images.append( + np.array( + Image.fromarray(image_dict[k][i]).resize( + (args.img_resize_size, args.img_resize_size), resample=Image.BICUBIC + ) + ) # BICUBIC is default; specify explicitly to make it clear + ) + image_dict[k] = np.stack(resized_images) + print("Resizing images in episode complete!") + # Save preprocessed episode + data_dict = dict( + qpos=qpos, + qvel=qvel, + effort=effort, + action=action, + image_dict=image_dict, + is_sim=is_sim, + ) + save_new_hdf5(out_dataset_dir, data_dict, idx) + idx += 1 + + +def randomly_split(full_qpos, full_qvel, full_effort, full_action, full_image_dict, percent_val): + """Randomly splits dataset into train and validation sets.""" + # Create a list of episode indices + num_episodes_total = len(full_qpos) + indices = list(range(num_episodes_total)) + # Shuffle the episode indices + random.shuffle(indices) + # Create new lists using the shuffled indices + shuffled_qpos = [full_qpos[idx] for idx in indices] + shuffled_qvel = [full_qvel[idx] for idx in indices] + shuffled_effort = [full_effort[idx] for idx in indices] + shuffled_action = [full_action[idx] for idx in indices] + shuffled_image_dict = { + "cam_high": [], + "cam_left_wrist": [], + "cam_right_wrist": [], + } + for k in full_image_dict.keys(): + shuffled_image_dict[k] = [full_image_dict[k][idx] for idx in indices] + # Split into train and val sets + num_episodes_val = int(num_episodes_total * percent_val) + print(f"Total # steps: {num_episodes_total}; using {num_episodes_val} 
({percent_val:.2%}) for val set") + num_episodes_train = num_episodes_total - num_episodes_val + train_dict = dict( + qpos=shuffled_qpos[:num_episodes_train], + qvel=shuffled_qvel[:num_episodes_train], + effort=shuffled_effort[:num_episodes_train], + action=shuffled_action[:num_episodes_train], + image_dict=dict( + cam_high=shuffled_image_dict["cam_high"][:num_episodes_train], + cam_left_wrist=shuffled_image_dict["cam_left_wrist"][:num_episodes_train], + cam_right_wrist=shuffled_image_dict["cam_right_wrist"][:num_episodes_train], + ), + ) + val_dict = dict( + qpos=shuffled_qpos[num_episodes_train:], + qvel=shuffled_qvel[num_episodes_train:], + effort=shuffled_effort[num_episodes_train:], + action=shuffled_action[num_episodes_train:], + image_dict=dict( + cam_high=shuffled_image_dict["cam_high"][num_episodes_train:], + cam_left_wrist=shuffled_image_dict["cam_left_wrist"][num_episodes_train:], + cam_right_wrist=shuffled_image_dict["cam_right_wrist"][num_episodes_train:], + ), + ) + return train_dict, val_dict + + +def save_new_hdf5(out_dataset_dir, data_dict, episode_idx): + """Saves an HDF5 file for a new episode.""" + camera_names = data_dict["image_dict"].keys() + H, W, C = data_dict["image_dict"]["cam_high"][0].shape + out_path = os.path.join(out_dataset_dir, f"episode_{episode_idx}.hdf5") + # Save HDF5 with same structure as original demos (one HDF5 file per episode) + with h5py.File( + out_path, "w", rdcc_nbytes=1024**2 * 2 + ) as root: # Magic constant for rdcc_nbytes comes from ALOHA codebase + episode_len = data_dict["qpos"].shape[0] + root.attrs["sim"] = data_dict["is_sim"] + obs = root.create_group("observations") + _ = obs.create_dataset("qpos", (episode_len, 14)) + _ = obs.create_dataset("qvel", (episode_len, 14)) + _ = obs.create_dataset("effort", (episode_len, 14)) + root["/observations/qpos"][...] = data_dict["qpos"] + root["/observations/qvel"][...] = data_dict["qvel"] + root["/observations/effort"][...] 
= data_dict["effort"] + image = obs.create_group("images") + for cam_name in camera_names: + _ = image.create_dataset( + cam_name, + (episode_len, H, W, C), + dtype="uint8", + chunks=(1, H, W, C), + ) + root[f"/observations/images/{cam_name}"][...] = data_dict["image_dict"][cam_name] + _ = root.create_dataset("action", (episode_len, 14)) + root["/action"][...] = data_dict["action"] + # Compute and save *relative* actions as well + actions = data_dict["action"] + relative_actions = np.zeros_like(actions) + relative_actions[:-1] = actions[1:] - actions[:-1] # Relative actions are the changes in joint pos + relative_actions[-1] = relative_actions[-2] # Just copy the second-to-last action for the last action + _ = root.create_dataset("relative_action", (episode_len, 14)) + root["/relative_action"][...] = relative_actions + print(f"Saved dataset: {out_path}") + + +def main(args): + # Create directory to save preprocessed dataset (if it doesn't exist already) + os.makedirs(args.out_base_dir, exist_ok=True) + out_dataset_dir = os.path.join(args.out_base_dir, os.path.basename(args.dataset_path.rstrip("/"))) + os.makedirs(out_dataset_dir, exist_ok=True) + # Get list of filepaths of all episodes + all_demo_paths = glob.glob(os.path.join(args.dataset_path, "*.hdf5")) # List of HDF5 filepaths + all_demo_paths.sort() + # Create a list of episode indices + num_episodes_total = len(all_demo_paths) + indices = list(range(num_episodes_total)) + # Shuffle the episode indices + random.shuffle(indices) + # Split into train and val sets + num_episodes_val = int(num_episodes_total * args.percent_val) + print(f"Total # episodes: {num_episodes_total}; using {num_episodes_val} ({args.percent_val:.2%}) for val set") + num_episodes_train = num_episodes_total - num_episodes_val + train_indices = indices[:num_episodes_train] + val_indices = indices[num_episodes_train:] + train_demo_paths = [all_demo_paths[i] for i in train_indices] + val_demo_paths = [all_demo_paths[i] for i in val_indices] + 
# Preprocess all episodes and save the result + out_dataset_dir_train = os.path.join(out_dataset_dir, "train") + out_dataset_dir_val = os.path.join(out_dataset_dir, "val") + os.makedirs(out_dataset_dir_train, exist_ok=True) + os.makedirs(out_dataset_dir_val, exist_ok=True) + load_and_preprocess_all_episodes(train_demo_paths, out_dataset_dir_train) + load_and_preprocess_all_episodes(val_demo_paths, out_dataset_dir_val) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset_path", + required=True, + help="Path to raw ALOHA dataset directory. Example: /PATH/TO/USER/data/aloha_raw/put_green_pepper_into_pot/", + ) + parser.add_argument( + "--out_base_dir", + required=True, + help="Path to directory in which to save preprocessed dataset. Example: /PATH/TO/USER/data/aloha_preprocessed/", + ) + parser.add_argument( + "--percent_val", + type=float, + help="Fraction (between 0 and 1) of dataset to use as validation set (measured in episodes, not steps).", + default=0.05, + ) + parser.add_argument( + "--img_resize_size", + type=int, + help="Size to resize images to. 
Final images will be square (img_resize_size x img_resize_size pixels).", + default=256, + ) + args = parser.parse_args() + + main(args) diff --git a/experiments/robot/aloha/real_env.py b/experiments/robot/aloha/real_env.py new file mode 100644 index 0000000000000000000000000000000000000000..f3f6c8f54b9f84c49a992345973b44eda45aaa42 --- /dev/null +++ b/experiments/robot/aloha/real_env.py @@ -0,0 +1,213 @@ +import time +import numpy as np +import collections +import matplotlib.pyplot as plt +import dm_env + +from experiments.robot.aloha.constants import DT, START_ARM_POSE, MASTER_GRIPPER_JOINT_NORMALIZE_FN, PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN +from experiments.robot.aloha.constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN +from experiments.robot.aloha.constants import PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE +from experiments.robot.aloha.robot_utils import Recorder, ImageRecorder +from experiments.robot.aloha.robot_utils import setup_master_bot, setup_puppet_bot, move_arms, move_grippers +from interbotix_xs_modules.arm import InterbotixManipulatorXS +from interbotix_xs_msgs.msg import JointSingleCommand + +import IPython +e = IPython.embed + +class RealEnv: + """ + Environment for real robot bi-manual manipulation + Action space: [left_arm_qpos (6), # absolute joint position + left_gripper_positions (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_positions (1),] # normalized gripper position (0: close, 1: open) + + Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position + left_gripper_position (1), # normalized gripper position (0: close, 1: open) + right_arm_qpos (6), # absolute joint position + right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open) + "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad) + left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: 
closing) + right_arm_qvel (6), # absolute joint velocity (rad) + right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing) + "images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8' + "cam_low": (480x640x3), # h, w, c, dtype='uint8' + "cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8' + "cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8' + """ + + def __init__(self, init_node, setup_robots=True): + self.puppet_bot_left = InterbotixManipulatorXS(robot_model="vx300s", group_name="arm", gripper_name="gripper", + robot_name=f'puppet_left', init_node=init_node) + self.puppet_bot_right = InterbotixManipulatorXS(robot_model="vx300s", group_name="arm", gripper_name="gripper", + robot_name=f'puppet_right', init_node=False) + if setup_robots: + self.setup_robots() + + self.recorder_left = Recorder('left', init_node=False) + self.recorder_right = Recorder('right', init_node=False) + self.image_recorder = ImageRecorder(init_node=False) + self.gripper_command = JointSingleCommand(name="gripper") + + def setup_robots(self): + setup_puppet_bot(self.puppet_bot_left) + setup_puppet_bot(self.puppet_bot_right) + + def get_qpos(self): + left_qpos_raw = self.recorder_left.qpos + right_qpos_raw = self.recorder_right.qpos + left_arm_qpos = left_qpos_raw[:6] + right_arm_qpos = right_qpos_raw[:6] + left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])] # this is position not joint + right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])] # this is position not joint + return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos]) + + def get_qvel(self): + left_qvel_raw = self.recorder_left.qvel + right_qvel_raw = self.recorder_right.qvel + left_arm_qvel = left_qvel_raw[:6] + right_arm_qvel = right_qvel_raw[:6] + left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])] + right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])] + return 
np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel]) + + def get_effort(self): + left_effort_raw = self.recorder_left.effort + right_effort_raw = self.recorder_right.effort + left_robot_effort = left_effort_raw[:7] + right_robot_effort = right_effort_raw[:7] + return np.concatenate([left_robot_effort, right_robot_effort]) + + def get_images(self): + return self.image_recorder.get_images() + + def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized): + left_gripper_desired_joint = PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized) + self.gripper_command.cmd = left_gripper_desired_joint + self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command) + + right_gripper_desired_joint = PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(right_gripper_desired_pos_normalized) + self.gripper_command.cmd = right_gripper_desired_joint + self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command) + + def _reset_joints(self): + reset_position = START_ARM_POSE[:6] + move_arms([self.puppet_bot_left, self.puppet_bot_right], [reset_position, reset_position], move_time=1) + + def _reset_gripper(self): + """Set to position mode and do position resets: first open then close. 
Then change back to PWM mode""" + move_grippers([self.puppet_bot_left, self.puppet_bot_right], [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5) + move_grippers([self.puppet_bot_left, self.puppet_bot_right], [PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1) + + def _get_obs(self): + obs = collections.OrderedDict() + obs['qpos'] = self.get_qpos() + obs['qvel'] = self.get_qvel() + obs['effort'] = self.get_effort() + obs['images'] = self.get_images() + return obs + + def get_observation(self, t=0): + step_type = dm_env.StepType.FIRST if t == 0 else dm_env.StepType.MID + return dm_env.TimeStep( + step_type=step_type, + reward=self.get_reward(), + discount=None, + observation=self._get_obs() + ) + + def get_reward(self): + return 0 + + def reset(self, fake=False): + if not fake: + # Reboot puppet robot gripper motors + self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True) + self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True) + self._reset_joints() + self._reset_gripper() + return dm_env.TimeStep( + step_type=dm_env.StepType.FIRST, + reward=self.get_reward(), + discount=None, + observation=self._get_obs()) + + def step(self, action): + state_len = int(len(action) / 2) + left_action = action[:state_len] + right_action = action[state_len:] + self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False) + self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False) + self.set_gripper_pose(left_action[-1], right_action[-1]) + time.sleep(DT) + return dm_env.TimeStep( + step_type=dm_env.StepType.MID, + reward=self.get_reward(), + discount=None, + observation=self._get_obs()) + + +def get_action(master_bot_left, master_bot_right): + action = np.zeros(14) # 6 joint + 1 gripper, for two arms + # Arm actions + action[:6] = master_bot_left.dxl.joint_states.position[:6] + action[7:7+6] = master_bot_right.dxl.joint_states.position[:6] + # Gripper actions + action[6] = 
MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6]) + action[7+6] = MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6]) + + return action + + +def make_real_env(init_node, setup_robots=True): + env = RealEnv(init_node, setup_robots) + return env + + +def test_real_teleop(): + """ + Test bimanual teleoperation and show image observations onscreen. + It first reads joint poses from both master arms. + Then use it as actions to step the environment. + The environment returns full observations including images. + + An alternative approach is to have separate scripts for teleoperation and observation recording. + This script will result in higher fidelity (obs, action) pairs + """ + + onscreen_render = True + render_cam = 'cam_left_wrist' + + # source of data + master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper", + robot_name=f'master_left', init_node=True) + master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper", + robot_name=f'master_right', init_node=False) + setup_master_bot(master_bot_left) + setup_master_bot(master_bot_right) + + # setup the environment + env = make_real_env(init_node=False) + ts = env.reset(fake=True) + episode = [ts] + # setup visualization + if onscreen_render: + ax = plt.subplot() + plt_img = ax.imshow(ts.observation['images'][render_cam]) + plt.ion() + + for t in range(1000): + action = get_action(master_bot_left, master_bot_right) + ts = env.step(action) + episode.append(ts) + + if onscreen_render: + plt_img.set_data(ts.observation['images'][render_cam]) + plt.pause(DT) + else: + time.sleep(DT) + + +if __name__ == '__main__': + test_real_teleop() diff --git a/experiments/robot/aloha/requirements_aloha.txt b/experiments/robot/aloha/requirements_aloha.txt new file mode 100644 index 0000000000000000000000000000000000000000..c84c6d08c5b343b91ca2eed61bd1f029a7d61037 --- /dev/null +++ 
b/experiments/robot/aloha/requirements_aloha.txt @@ -0,0 +1,26 @@ +numpy<2 +draccus +torchvision +torch +pyquaternion +pyyaml +rospkg +pexpect +mujoco==2.3.7 +dm_control==1.0.14 +opencv-python +matplotlib +einops +packaging +h5py +traitlets +ipdb +IPython +modern_robotics +Pillow +termcolor +imageio[ffmpeg] +uvicorn +fastapi +requests +json_numpy diff --git a/experiments/robot/aloha/robot_utils.py b/experiments/robot/aloha/robot_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..82f080cc9e6ebe92b6b3447bfd9cc5d5e60b78a8 --- /dev/null +++ b/experiments/robot/aloha/robot_utils.py @@ -0,0 +1,187 @@ +import numpy as np +import time +from experiments.robot.aloha.constants import DT +from interbotix_xs_msgs.msg import JointSingleCommand + +import IPython +e = IPython.embed + +class ImageRecorder: + def __init__(self, init_node=True, is_debug=False): + from collections import deque + import rospy + from cv_bridge import CvBridge + from sensor_msgs.msg import Image + self.is_debug = is_debug + self.bridge = CvBridge() + self.camera_names = ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'] + if init_node: + rospy.init_node('image_recorder', anonymous=True) + for cam_name in self.camera_names: + setattr(self, f'{cam_name}_image', None) + setattr(self, f'{cam_name}_secs', None) + setattr(self, f'{cam_name}_nsecs', None) + if cam_name == 'cam_high': + callback_func = self.image_cb_cam_high + elif cam_name == 'cam_low': + callback_func = self.image_cb_cam_low + elif cam_name == 'cam_left_wrist': + callback_func = self.image_cb_cam_left_wrist + elif cam_name == 'cam_right_wrist': + callback_func = self.image_cb_cam_right_wrist + else: + raise NotImplementedError + rospy.Subscriber(f"/usb_{cam_name}/image_raw", Image, callback_func) + if self.is_debug: + setattr(self, f'{cam_name}_timestamps', deque(maxlen=50)) + time.sleep(0.5) + + def image_cb(self, cam_name, data): + setattr(self, f'{cam_name}_image', self.bridge.imgmsg_to_cv2(data, 
desired_encoding='passthrough')) + setattr(self, f'{cam_name}_secs', data.header.stamp.secs) + setattr(self, f'{cam_name}_nsecs', data.header.stamp.nsecs) + # cv2.imwrite('/home/tonyzhao/Desktop/sample.jpg', cv_image) + if self.is_debug: + getattr(self, f'{cam_name}_timestamps').append(data.header.stamp.secs + data.header.stamp.nsecs * 1e-9) + + def image_cb_cam_high(self, data): + cam_name = 'cam_high' + return self.image_cb(cam_name, data) + + def image_cb_cam_low(self, data): + cam_name = 'cam_low' + return self.image_cb(cam_name, data) + + def image_cb_cam_left_wrist(self, data): + cam_name = 'cam_left_wrist' + return self.image_cb(cam_name, data) + + def image_cb_cam_right_wrist(self, data): + cam_name = 'cam_right_wrist' + return self.image_cb(cam_name, data) + + def get_images(self): + image_dict = dict() + for cam_name in self.camera_names: + image_dict[cam_name] = getattr(self, f'{cam_name}_image') + return image_dict + + def print_diagnostics(self): + def dt_helper(l): + l = np.array(l) + diff = l[1:] - l[:-1] + return np.mean(diff) + for cam_name in self.camera_names: + image_freq = 1 / dt_helper(getattr(self, f'{cam_name}_timestamps')) + print(f'{cam_name} {image_freq=:.2f}') + print() + +class Recorder: + def __init__(self, side, init_node=True, is_debug=False): + from collections import deque + import rospy + from sensor_msgs.msg import JointState + from interbotix_xs_msgs.msg import JointGroupCommand, JointSingleCommand + + self.secs = None + self.nsecs = None + self.qpos = None + self.effort = None + self.arm_command = None + self.gripper_command = None + self.is_debug = is_debug + + if init_node: + rospy.init_node('recorder', anonymous=True) + rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb) + rospy.Subscriber(f"/puppet_{side}/commands/joint_group", JointGroupCommand, self.puppet_arm_commands_cb) + rospy.Subscriber(f"/puppet_{side}/commands/joint_single", JointSingleCommand, self.puppet_gripper_commands_cb) + if 
self.is_debug: + self.joint_timestamps = deque(maxlen=50) + self.arm_command_timestamps = deque(maxlen=50) + self.gripper_command_timestamps = deque(maxlen=50) + time.sleep(0.1) + + def puppet_state_cb(self, data): + self.qpos = data.position + self.qvel = data.velocity + self.effort = data.effort + self.data = data + if self.is_debug: + self.joint_timestamps.append(time.time()) + + def puppet_arm_commands_cb(self, data): + self.arm_command = data.cmd + if self.is_debug: + self.arm_command_timestamps.append(time.time()) + + def puppet_gripper_commands_cb(self, data): + self.gripper_command = data.cmd + if self.is_debug: + self.gripper_command_timestamps.append(time.time()) + + def print_diagnostics(self): + def dt_helper(l): + l = np.array(l) + diff = l[1:] - l[:-1] + return np.mean(diff) + + joint_freq = 1 / dt_helper(self.joint_timestamps) + arm_command_freq = 1 / dt_helper(self.arm_command_timestamps) + gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps) + + print(f'{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n') + +def get_arm_joint_positions(bot): + return bot.arm.core.joint_states.position[:6] + +def get_arm_gripper_positions(bot): + joint_position = bot.gripper.core.joint_states.position[6] + return joint_position + +def move_arms(bot_list, target_pose_list, move_time=1): + num_steps = int(move_time / DT) + curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list] + traj_list = [np.linspace(curr_pose, target_pose, num_steps) for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)] + for t in range(num_steps): + for bot_id, bot in enumerate(bot_list): + bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False) + time.sleep(DT) + +def move_grippers(bot_list, target_pose_list, move_time): + gripper_command = JointSingleCommand(name="gripper") + num_steps = int(move_time / DT) + curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list] + traj_list = [np.linspace(curr_pose, 
target_pose, num_steps) for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)] + for t in range(num_steps): + for bot_id, bot in enumerate(bot_list): + gripper_command.cmd = traj_list[bot_id][t] + bot.gripper.core.pub_single.publish(gripper_command) + time.sleep(DT) + +def setup_puppet_bot(bot): + bot.dxl.robot_reboot_motors("single", "gripper", True) + bot.dxl.robot_set_operating_modes("group", "arm", "position") + bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position") + torque_on(bot) + +def setup_master_bot(bot): + bot.dxl.robot_set_operating_modes("group", "arm", "pwm") + bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position") + torque_off(bot) + +def set_standard_pid_gains(bot): + bot.dxl.robot_set_motor_registers("group", "arm", 'Position_P_Gain', 800) + bot.dxl.robot_set_motor_registers("group", "arm", 'Position_I_Gain', 0) + +def set_low_pid_gains(bot): + bot.dxl.robot_set_motor_registers("group", "arm", 'Position_P_Gain', 100) + bot.dxl.robot_set_motor_registers("group", "arm", 'Position_I_Gain', 0) + +def torque_off(bot): + bot.dxl.robot_torque_enable("group", "arm", False) + bot.dxl.robot_torque_enable("single", "gripper", False) + +def torque_on(bot): + bot.dxl.robot_torque_enable("group", "arm", True) + bot.dxl.robot_torque_enable("single", "gripper", True) \ No newline at end of file diff --git a/experiments/robot/aloha/run_aloha_eval.py b/experiments/robot/aloha/run_aloha_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..520f5af9a284cead73d8cd027f821b5490354a45 --- /dev/null +++ b/experiments/robot/aloha/run_aloha_eval.py @@ -0,0 +1,385 @@ +""" +run_aloha_eval.py + +Evaluates a model in a real-world ALOHA environment. 
+""" + +import logging +import os +import socket +import sys +import time +from collections import deque +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Union + +import draccus +import tqdm + +# Append current directory so that interpreter can find experiments.robot +sys.path.append(".") +from experiments.robot.aloha.aloha_utils import ( + get_aloha_env, + get_aloha_image, + get_aloha_wrist_images, + get_next_task_label, + save_rollout_video, +) +from experiments.robot.openvla_utils import ( + get_action_from_server, + resize_image_for_policy, +) +from experiments.robot.robot_utils import ( + DATE_TIME, + get_image_resize_size, + set_seed_everywhere, +) + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler()], +) +logger = logging.getLogger(__name__) + + +@dataclass +class GenerateConfig: + # fmt: off + + ################################################################################################################# + # Model-specific parameters + ################################################################################################################# + model_family: str = "openvla" # Model family + + center_crop: bool = True # Center crop? 
(if trained w/ random crop image aug) + num_open_loop_steps: int = 25 # Number of actions to execute open-loop before requerying policy + + use_vla_server: bool = True # Whether to query remote VLA server for actions + vla_server_url: Union[str, Path] = "" # Remote VLA server URL (set to 127.0.0.1 if on same machine) + + ################################################################################################################# + # ALOHA environment-specific parameters + ################################################################################################################# + num_rollouts_planned: int = 50 # Number of test rollouts + max_steps: int = 1500 # Max number of steps per rollout + use_relative_actions: bool = False # Whether to use relative actions (delta joint angles) + + ################################################################################################################# + # Utils + ################################################################################################################# + run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging + local_log_dir: str = "./experiments/logs" # Local directory for eval logs + + seed: int = 7 # Random Seed (for reproducibility) + + # fmt: on + + +def validate_config(cfg: GenerateConfig) -> None: + """Validate configuration parameters.""" + assert cfg.use_vla_server, ( + "Must use VLA server (server-client interface) to query model and get actions! 
Please set --use_vla_server=True" + ) + + +def setup_logging(cfg: GenerateConfig): + """Set up logging to file.""" + # Create run ID + run_id = f"EVAL-{cfg.model_family}-{DATE_TIME}" + if cfg.run_id_note is not None: + run_id += f"--{cfg.run_id_note}" + + # Set up local logging + os.makedirs(cfg.local_log_dir, exist_ok=True) + local_log_filepath = os.path.join(cfg.local_log_dir, run_id + ".txt") + log_file = open(local_log_filepath, "w") + logger.info(f"Logging to local log file: {local_log_filepath}") + + return log_file, local_log_filepath, run_id + + +def log_message(message: str, log_file=None): + """Log a message to console and optionally to a log file.""" + print(message) + logger.info(message) + if log_file: + log_file.write(message + "\n") + log_file.flush() + + +def get_server_endpoint(cfg: GenerateConfig): + """Get the server endpoint for remote inference.""" + ip_address = socket.gethostbyname(cfg.vla_server_url) + return f"http://{ip_address}:8777/act" + + +def prepare_observation(obs, resize_size): + """Prepare observation for policy input.""" + # Get preprocessed images + img = get_aloha_image(obs) + left_wrist_img, right_wrist_img = get_aloha_wrist_images(obs) + + # Resize images to size expected by model + img_resized = resize_image_for_policy(img, resize_size) + left_wrist_img_resized = resize_image_for_policy(left_wrist_img, resize_size) + right_wrist_img_resized = resize_image_for_policy(right_wrist_img, resize_size) + + # Prepare observations dict + observation = { + "full_image": img_resized, + "left_wrist_image": left_wrist_img_resized, + "right_wrist_image": right_wrist_img_resized, + "state": obs.observation["qpos"], + } + + return observation, img_resized, left_wrist_img_resized, right_wrist_img_resized + + +def run_episode( + cfg: GenerateConfig, + env, + task_description: str, + server_endpoint: str, + resize_size, + log_file=None, +): + """Run a single episode in the ALOHA environment.""" + # Define control frequency + 
STEP_DURATION_IN_SEC = 1.0 / 25.0 + + # Reset environment + obs = env.reset() + + # Initialize action queue + action_queue = deque(maxlen=cfg.num_open_loop_steps) + + # Setup + t = 0 + curr_state = None + replay_images = [] + replay_images_resized = [] + replay_images_left_wrist_resized = [] + replay_images_right_wrist_resized = [] + + log_message("Prepare the scene, and then press Enter to begin...", log_file) + input() + + # Reset environment again to fetch first timestep observation + obs = env.reset() + + # Fetch initial robot state (but sleep first so that robot stops moving) + time.sleep(2) + curr_state = env.get_qpos() + + episode_start_time = time.time() + total_model_query_time = 0.0 + + try: + while t < cfg.max_steps: + # Get step start time (used to compute how much to sleep between steps) + step_start_time = time.time() + + # Get observation + obs = env.get_observation(t=t) + + # Save raw high camera image for replay video + replay_images.append(obs.observation["images"]["cam_high"]) + + # If action queue is empty, requery model + if len(action_queue) == 0: + # Prepare observation + observation, img_resized, left_wrist_resized, right_wrist_resized = prepare_observation(obs, resize_size) + observation["instruction"] = task_description + + # Save processed images for replay + replay_images_resized.append(img_resized) + replay_images_left_wrist_resized.append(left_wrist_resized) + replay_images_right_wrist_resized.append(right_wrist_resized) + + # Query model to get action + log_message("Requerying model...", log_file) + model_query_start_time = time.time() + actions = get_action_from_server(observation, server_endpoint) + actions = actions[: cfg.num_open_loop_steps] + total_model_query_time += time.time() - model_query_start_time + action_queue.extend(actions) + + # Get action from queue + action = action_queue.popleft() + log_message("-----------------------------------------------------", log_file) + log_message(f"t: {t}", log_file) + 
log_message(f"action: {action}", log_file) + + # Execute action in environment + if cfg.use_relative_actions: + # Get absolute joint angles from relative action + rel_action = action + target_state = curr_state + rel_action + obs = env.step(target_state.tolist()) + # Update current state (assume it is the commanded target state) + curr_state = target_state + else: + obs = env.step(action.tolist()) + t += 1 + + # Sleep until next timestep + step_elapsed_time = time.time() - step_start_time + if step_elapsed_time < STEP_DURATION_IN_SEC: + time_to_sleep = STEP_DURATION_IN_SEC - step_elapsed_time + log_message(f"Sleeping {time_to_sleep} sec...", log_file) + time.sleep(time_to_sleep) + + except (KeyboardInterrupt, Exception) as e: + if isinstance(e, KeyboardInterrupt): + log_message("\nCaught KeyboardInterrupt: Terminating episode early.", log_file) + else: + log_message(f"\nCaught exception: {e}", log_file) + + episode_end_time = time.time() + + # Get success feedback from user + user_input = input("Success? 
Enter 'y' or 'n': ") + success = user_input.lower() == "y" + + # Calculate episode statistics + episode_stats = { + "success": success, + "total_steps": t, + "model_query_time": total_model_query_time, + "episode_duration": episode_end_time - episode_start_time, + } + + return ( + episode_stats, + replay_images, + replay_images_resized, + replay_images_left_wrist_resized, + replay_images_right_wrist_resized, + ) + + +def save_episode_videos( + replay_images, + replay_images_resized, + replay_images_left_wrist, + replay_images_right_wrist, + episode_idx, + success, + task_description, + log_file=None, +): + """Save videos of the episode from different camera angles.""" + # Save main replay video + save_rollout_video(replay_images, episode_idx, success=success, task_description=task_description, log_file=log_file) + + # Save processed view videos + save_rollout_video( + replay_images_resized, + episode_idx, + success=success, + task_description=task_description, + log_file=log_file, + notes="resized", + ) + save_rollout_video( + replay_images_left_wrist, + episode_idx, + success=success, + task_description=task_description, + log_file=log_file, + notes="left_wrist_resized", + ) + save_rollout_video( + replay_images_right_wrist, + episode_idx, + success=success, + task_description=task_description, + log_file=log_file, + notes="right_wrist_resized", + ) + + +@draccus.wrap() +def eval_aloha(cfg: GenerateConfig) -> float: + """Main function to evaluate a trained policy in a real-world ALOHA environment.""" + # Validate configuration + validate_config(cfg) + + # Set random seed + set_seed_everywhere(cfg.seed) + + # Setup logging + log_file, local_log_filepath, run_id = setup_logging(cfg) + + # Get expected image dimensions + resize_size = get_image_resize_size(cfg) + + # Get ALOHA environment + env = get_aloha_env() + + # Get server endpoint for remote inference + server_endpoint = get_server_endpoint(cfg) + + # Initialize task description +
task_description = "" + + # Start evaluation + num_rollouts_completed, total_successes = 0, 0 + + for episode_idx in tqdm.tqdm(range(cfg.num_rollouts_planned)): + # Get task description from user + task_description = get_next_task_label(task_description) + log_message(f"\nTask: {task_description}", log_file) + + log_message(f"Starting episode {num_rollouts_completed + 1}...", log_file) + + # Run episode + episode_stats, replay_images, replay_images_resized, replay_images_left_wrist, replay_images_right_wrist = ( + run_episode(cfg, env, task_description, server_endpoint, resize_size, log_file) + ) + + # Update counters + num_rollouts_completed += 1 + if episode_stats["success"]: + total_successes += 1 + + # Save videos + save_episode_videos( + replay_images, + replay_images_resized, + replay_images_left_wrist, + replay_images_right_wrist, + num_rollouts_completed, + episode_stats["success"], + task_description, + log_file, + ) + + # Log results + log_message(f"Success: {episode_stats['success']}", log_file) + log_message(f"# episodes completed so far: {num_rollouts_completed}", log_file) + log_message(f"# successes: {total_successes} ({total_successes / num_rollouts_completed * 100:.1f}%)", log_file) + log_message(f"Total model query time: {episode_stats['model_query_time']:.2f} sec", log_file) + log_message(f"Total episode elapsed time: {episode_stats['episode_duration']:.2f} sec", log_file) + + # Calculate final success rate + final_success_rate = float(total_successes) / float(num_rollouts_completed) if num_rollouts_completed > 0 else 0 + + # Log final results + log_message("\nFinal results:", log_file) + log_message(f"Total episodes: {num_rollouts_completed}", log_file) + log_message(f"Total successes: {total_successes}", log_file) + log_message(f"Overall success rate: {final_success_rate:.4f} ({final_success_rate * 100:.1f}%)", log_file) + + # Close log file + if log_file: + log_file.close() + + return final_success_rate + + +if __name__ == "__main__": + 
eval_aloha() diff --git a/experiments/robot/libero/libero_requirements.txt b/experiments/robot/libero/libero_requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ca0c0467efa33b91ad76d66976da480c1e5f572f --- /dev/null +++ b/experiments/robot/libero/libero_requirements.txt @@ -0,0 +1,6 @@ +imageio[ffmpeg] +robosuite==1.4.1 +bddl +easydict +cloudpickle +gym diff --git a/experiments/robot/libero/libero_utils.py b/experiments/robot/libero/libero_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8e2c23458cbf48a9d78bf61e7cf0cbd5e196d635 --- /dev/null +++ b/experiments/robot/libero/libero_utils.py @@ -0,0 +1,87 @@ +"""Utils for evaluating policies in LIBERO simulation environments.""" + +import math +import os + +import imageio +import numpy as np +import tensorflow as tf +from libero.libero import get_libero_path +from libero.libero.envs import OffScreenRenderEnv + +from experiments.robot.robot_utils import ( + DATE, + DATE_TIME, +) + + +def get_libero_env(task, model_family, resolution=256): + """Initializes and returns the LIBERO environment, along with the task description.""" + task_description = task.language + task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file) + env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution} + env = OffScreenRenderEnv(**env_args) + env.seed(0) # IMPORTANT: seed seems to affect object positions even when using fixed initial state + return env, task_description + + +def get_libero_dummy_action(model_family: str): + """Get dummy/no-op action, used to roll out the simulation while the robot does nothing.""" + return [0, 0, 0, 0, 0, 0, -1] + + +def get_libero_image(obs): + """Extracts third-person image from observations and preprocesses it.""" + img = obs["agentview_image"] + img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing + return img + + +def 
get_libero_wrist_image(obs): + """Extracts wrist camera image from observations and preprocesses it.""" + img = obs["robot0_eye_in_hand_image"] + img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing + return img + + +def save_rollout_video(rollout_images, idx, success, task_description, log_file=None): + """Saves an MP4 replay of an episode.""" + rollout_dir = f"./rollouts/{DATE}" + os.makedirs(rollout_dir, exist_ok=True) + processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50] + mp4_path = f"{rollout_dir}/{DATE_TIME}--openvla_oft--episode={idx}--success={success}--task={processed_task_description}.mp4" + video_writer = imageio.get_writer(mp4_path, fps=30) + for img in rollout_images: + video_writer.append_data(img) + video_writer.close() + print(f"Saved rollout MP4 at path {mp4_path}") + if log_file is not None: + log_file.write(f"Saved rollout MP4 at path {mp4_path}\n") + return mp4_path + + +def quat2axisangle(quat): + """ + Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55 + + Converts quaternion to axis-angle format. + Returns a unit vector direction scaled by its angle in radians. 
+ + Args: + quat (np.array): (x,y,z,w) vec4 float angles + + Returns: + np.array: (ax,ay,az) axis-angle exponential coordinates + """ + # clip quaternion + if quat[3] > 1.0: + quat[3] = 1.0 + elif quat[3] < -1.0: + quat[3] = -1.0 + + den = np.sqrt(1.0 - quat[3] * quat[3]) + if math.isclose(den, 0.0): + # This is (close to) a zero degree rotation, immediately return + return np.zeros(3) + + return (quat[:3] * 2.0 * math.acos(quat[3])) / den diff --git a/experiments/robot/libero/regenerate_libero_dataset.py b/experiments/robot/libero/regenerate_libero_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0966413ee197c6f6a2da4d03310943a84584dbb7 --- /dev/null +++ b/experiments/robot/libero/regenerate_libero_dataset.py @@ -0,0 +1,249 @@ +""" +Regenerates a LIBERO dataset (HDF5 files) by replaying demonstrations in the environments. + +Notes: + - We save image observations at 256x256px resolution (instead of 128x128). + - We filter out transitions with "no-op" (zero) actions that do not change the robot's state. + - We filter out unsuccessful demonstrations. + - In the LIBERO HDF5 data -> RLDS data conversion (not shown here), we rotate the images by + 180 degrees because we observe that the environments return images that are upside down + on our platform. 
+ +Usage: + python experiments/robot/libero/regenerate_libero_dataset.py \ + --libero_task_suite [ libero_spatial | libero_object | libero_goal | libero_10 ] \ + --libero_raw_data_dir \ + --libero_target_dir + + Example (LIBERO-Spatial): + python experiments/robot/libero/regenerate_libero_dataset.py \ + --libero_task_suite libero_spatial \ + --libero_raw_data_dir ./LIBERO/libero/datasets/libero_spatial \ + --libero_target_dir ./LIBERO/libero/datasets/libero_spatial_no_noops + +""" + +import argparse +import json +import os +import time + +import h5py +import numpy as np +import robosuite.utils.transform_utils as T +import tqdm +from libero.libero import benchmark + +from experiments.robot.libero.libero_utils import ( + get_libero_dummy_action, + get_libero_env, +) + + +IMAGE_RESOLUTION = 256 + + +def is_noop(action, prev_action=None, threshold=1e-4): + """ + Returns whether an action is a no-op action. + + A no-op action satisfies two criteria: + (1) All action dimensions, except for the last one (gripper action), are near zero. + (2) The gripper action is equal to the previous timestep's gripper action. + + Explanation of (2): + Naively filtering out actions with just criterion (1) is not good because you will + remove actions where the robot is staying still but opening/closing its gripper. + So you also need to consider the current state (by checking the previous timestep's + gripper action as a proxy) to determine whether the action really is a no-op. 
+ """ + # Special case: Previous action is None if this is the first action in the episode + # Then we only care about criterion (1) + if prev_action is None: + return np.linalg.norm(action[:-1]) < threshold + + # Normal case: Check both criteria (1) and (2) + gripper_action = action[-1] + prev_gripper_action = prev_action[-1] + return np.linalg.norm(action[:-1]) < threshold and gripper_action == prev_gripper_action + + +def main(args): + print(f"Regenerating {args.libero_task_suite} dataset!") + + # Create target directory + if os.path.isdir(args.libero_target_dir): + user_input = input(f"Target directory already exists at path: {args.libero_target_dir}\nEnter 'y' to overwrite the directory, or anything else to exit: ") + if user_input != 'y': + exit() + os.makedirs(args.libero_target_dir, exist_ok=True) + + # Prepare JSON file to record success/failure and initial states per episode + metainfo_json_dict = {} + metainfo_json_out_path = f"./experiments/robot/libero/{args.libero_task_suite}_metainfo.json" + with open(metainfo_json_out_path, "w") as f: + # Just test that we can write to this file (we overwrite it later) + json.dump(metainfo_json_dict, f) + + # Get task suite + benchmark_dict = benchmark.get_benchmark_dict() + task_suite = benchmark_dict[args.libero_task_suite]() + num_tasks_in_suite = task_suite.n_tasks + + # Setup + num_replays = 0 + num_success = 0 + num_noops = 0 + + for task_id in tqdm.tqdm(range(num_tasks_in_suite)): + # Get task in suite + task = task_suite.get_task(task_id) + env, task_description = get_libero_env(task, "llava", resolution=IMAGE_RESOLUTION) + + # Get dataset for task + orig_data_path = os.path.join(args.libero_raw_data_dir, f"{task.name}_demo.hdf5") + assert os.path.exists(orig_data_path), f"Cannot find raw data file {orig_data_path}."
+ orig_data_file = h5py.File(orig_data_path, "r") + orig_data = orig_data_file["data"] + + # Create new HDF5 file for regenerated demos + new_data_path = os.path.join(args.libero_target_dir, f"{task.name}_demo.hdf5") + new_data_file = h5py.File(new_data_path, "w") + grp = new_data_file.create_group("data") + + for i in range(len(orig_data.keys())): + # Get demo data + demo_data = orig_data[f"demo_{i}"] + orig_actions = demo_data["actions"][()] + orig_states = demo_data["states"][()] + + # Reset environment, set initial state, and wait a few steps for environment to settle + env.reset() + env.set_init_state(orig_states[0]) + for _ in range(10): + obs, reward, done, info = env.step(get_libero_dummy_action("llava")) + + # Set up new data lists + states = [] + actions = [] + ee_states = [] + gripper_states = [] + joint_states = [] + robot_states = [] + agentview_images = [] + eye_in_hand_images = [] + + # Replay original demo actions in environment and record observations + for _, action in enumerate(orig_actions): + # Skip transitions with no-op actions + prev_action = actions[-1] if len(actions) > 0 else None + if is_noop(action, prev_action): + print(f"\tSkipping no-op action: {action}") + num_noops += 1 + continue + + if states == []: + # In the first timestep, since we're using the original initial state to initialize the environment, + # copy the initial state (first state in episode) over from the original HDF5 to the new one + states.append(orig_states[0]) + robot_states.append(demo_data["robot_states"][0]) + else: + # For all other timesteps, get state from environment and record it + states.append(env.sim.get_state().flatten()) + robot_states.append( + np.concatenate([obs["robot0_gripper_qpos"], obs["robot0_eef_pos"], obs["robot0_eef_quat"]]) + ) + + # Record original action (from demo) + actions.append(action) + + # Record data returned by environment + if "robot0_gripper_qpos" in obs: + gripper_states.append(obs["robot0_gripper_qpos"]) + 
joint_states.append(obs["robot0_joint_pos"]) + ee_states.append( + np.hstack( + ( + obs["robot0_eef_pos"], + T.quat2axisangle(obs["robot0_eef_quat"]), + ) + ) + ) + agentview_images.append(obs["agentview_image"]) + eye_in_hand_images.append(obs["robot0_eye_in_hand_image"]) + + # Execute demo action in environment + obs, reward, done, info = env.step(action.tolist()) + + # At end of episode, save replayed trajectories to new HDF5 files (only keep successes) + if done: + dones = np.zeros(len(actions)).astype(np.uint8) + dones[-1] = 1 + rewards = np.zeros(len(actions)).astype(np.uint8) + rewards[-1] = 1 + assert len(actions) == len(agentview_images) + + ep_data_grp = grp.create_group(f"demo_{i}") + obs_grp = ep_data_grp.create_group("obs") + obs_grp.create_dataset("gripper_states", data=np.stack(gripper_states, axis=0)) + obs_grp.create_dataset("joint_states", data=np.stack(joint_states, axis=0)) + obs_grp.create_dataset("ee_states", data=np.stack(ee_states, axis=0)) + obs_grp.create_dataset("ee_pos", data=np.stack(ee_states, axis=0)[:, :3]) + obs_grp.create_dataset("ee_ori", data=np.stack(ee_states, axis=0)[:, 3:]) + obs_grp.create_dataset("agentview_rgb", data=np.stack(agentview_images, axis=0)) + obs_grp.create_dataset("eye_in_hand_rgb", data=np.stack(eye_in_hand_images, axis=0)) + ep_data_grp.create_dataset("actions", data=actions) + ep_data_grp.create_dataset("states", data=np.stack(states)) + ep_data_grp.create_dataset("robot_states", data=np.stack(robot_states, axis=0)) + ep_data_grp.create_dataset("rewards", data=rewards) + ep_data_grp.create_dataset("dones", data=dones) + + num_success += 1 + + num_replays += 1 + + # Record success/failure and initial environment state in metainfo dict + task_key = task_description.replace(" ", "_") + episode_key = f"demo_{i}" + if task_key not in metainfo_json_dict: + metainfo_json_dict[task_key] = {} + if episode_key not in metainfo_json_dict[task_key]: + metainfo_json_dict[task_key][episode_key] = {} +
metainfo_json_dict[task_key][episode_key]["success"] = bool(done) + metainfo_json_dict[task_key][episode_key]["initial_state"] = orig_states[0].tolist() + + # Write metainfo dict to JSON file + # (We repeatedly overwrite, rather than doing this once at the end, just in case the script crashes midway) + with open(metainfo_json_out_path, "w") as f: + json.dump(metainfo_json_dict, f, indent=2) + + # Count total number of successful replays so far + print( + f"Total # episodes replayed: {num_replays}, Total # successes: {num_success} ({num_success / num_replays * 100:.1f} %)" + ) + + # Report total number of no-op actions filtered out so far + print(f" Total # no-op actions filtered out: {num_noops}") + + # Close HDF5 files + orig_data_file.close() + new_data_file.close() + print(f"Saved regenerated demos for task '{task_description}' at: {new_data_path}") + + print(f"Dataset regeneration complete! Saved new dataset at: {args.libero_target_dir}") + print(f"Saved metainfo JSON at: {metainfo_json_out_path}") + + +if __name__ == "__main__": + # Parse command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument("--libero_task_suite", type=str, choices=["libero_spatial", "libero_object", "libero_goal", "libero_10", "libero_90"], + help="LIBERO task suite. Example: libero_spatial", required=True) + parser.add_argument("--libero_raw_data_dir", type=str, + help="Path to directory containing raw HDF5 dataset. Example: ./LIBERO/libero/datasets/libero_spatial", required=True) + parser.add_argument("--libero_target_dir", type=str, + help="Path to regenerated dataset directory. 
Example: ./LIBERO/libero/datasets/libero_spatial_no_noops", required=True) + args = parser.parse_args() + + # Start data regeneration + main(args) diff --git a/experiments/robot/libero/run_libero_eval.py b/experiments/robot/libero/run_libero_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf27f56c5464808481d56af0898475541e10817 --- /dev/null +++ b/experiments/robot/libero/run_libero_eval.py @@ -0,0 +1,531 @@ +""" +run_libero_eval.py + +Evaluates a trained policy in a LIBERO simulation benchmark task suite. +""" + +import json +import logging +import os +import sys +from collections import deque +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Optional, Union + +import draccus +import numpy as np +import tqdm +from libero.libero import benchmark + +import wandb + +# Append current directory so that interpreter can find experiments.robot +sys.path.append("../..") +from experiments.robot.libero.libero_utils import ( + get_libero_dummy_action, + get_libero_env, + get_libero_image, + get_libero_wrist_image, + quat2axisangle, + save_rollout_video, +) +from experiments.robot.openvla_utils import ( + get_action_head, + get_noisy_action_projector, + get_processor, + get_proprio_projector, + resize_image_for_policy, +) +from experiments.robot.robot_utils import ( + DATE_TIME, + get_action, + get_image_resize_size, + get_model, + invert_gripper_action, + normalize_gripper_action, + set_seed_everywhere, +) +from prismatic.vla.constants import NUM_ACTIONS_CHUNK + + +# Define task suite constants +class TaskSuite(str, Enum): + LIBERO_SPATIAL = "libero_spatial" + LIBERO_OBJECT = "libero_object" + LIBERO_GOAL = "libero_goal" + LIBERO_10 = "libero_10" + LIBERO_90 = "libero_90" + + +# Define max steps for each task suite +TASK_MAX_STEPS = { + TaskSuite.LIBERO_SPATIAL: 220, # longest training demo has 193 steps + TaskSuite.LIBERO_OBJECT: 280, # longest training demo has 254 steps + 
TaskSuite.LIBERO_GOAL: 300, # longest training demo has 270 steps + TaskSuite.LIBERO_10: 520, # longest training demo has 505 steps + TaskSuite.LIBERO_90: 400, # longest training demo has 373 steps +} + + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler()], +) +logger = logging.getLogger(__name__) + + +@dataclass +class GenerateConfig: + # fmt: off + + ################################################################################################################# + # Model-specific parameters + ################################################################################################################# + model_family: str = "openvla" # Model family + pretrained_checkpoint: Union[str, Path] = "" # Pretrained checkpoint path + + use_l1_regression: bool = True # If True, uses continuous action head with L1 regression objective + use_diffusion: bool = False # If True, uses continuous action head with diffusion modeling objective (DDIM) + num_diffusion_steps_train: int = 50 # (When `use_diffusion==True`) Number of diffusion steps used for training + num_diffusion_steps_inference: int = 50 # (When `use_diffusion==True`) Number of diffusion steps used for inference + use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features + num_images_in_input: int = 2 # Number of images in the VLA input + use_proprio: bool = True # Whether to include proprio state in input + + center_crop: bool = True # Center crop? (if trained w/ random crop image aug) + num_open_loop_steps: int = 8 # Number of actions to execute open-loop before requerying policy + + lora_rank: int = 32 # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!)
+ + unnorm_key: Union[str, Path] = "" # Action un-normalization key + + load_in_8bit: bool = False # (For OpenVLA only) Load with 8-bit quantization + load_in_4bit: bool = False # (For OpenVLA only) Load with 4-bit quantization + + ################################################################################################################# + # LIBERO environment-specific parameters + ################################################################################################################# + task_suite_name: str = TaskSuite.LIBERO_SPATIAL # Task suite + num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize in sim + num_trials_per_task: int = 50 # Number of rollouts per task + initial_states_path: str = "DEFAULT" # "DEFAULT", or path to initial states JSON file + env_img_res: int = 256 # Resolution for environment images (not policy input resolution) + + ################################################################################################################# + # Utils + ################################################################################################################# + run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging + local_log_dir: str = "./experiments/logs" # Local directory for eval logs + + use_wandb: bool = False # Whether to also log results in Weights & Biases + wandb_entity: str = "your-wandb-entity" # Name of WandB entity + wandb_project: str = "your-wandb-project" # Name of WandB project + + seed: int = 7 # Random Seed (for reproducibility) + + # fmt: on + + +def validate_config(cfg: GenerateConfig) -> None: + """Validate configuration parameters.""" + assert cfg.pretrained_checkpoint is not None, "pretrained_checkpoint must not be None!" + + if "image_aug" in str(cfg.pretrained_checkpoint): + assert cfg.center_crop, "Expecting `center_crop==True` because model was trained with image augmentations!" 
+ + assert not (cfg.load_in_8bit and cfg.load_in_4bit), "Cannot use both 8-bit and 4-bit quantization!" + + # Validate task suite + assert cfg.task_suite_name in [suite.value for suite in TaskSuite], f"Invalid task suite: {cfg.task_suite_name}" + + +def initialize_model(cfg: GenerateConfig): + """Initialize model and associated components.""" + # Load model + model = get_model(cfg) + + # Load proprio projector if needed + proprio_projector = None + if cfg.use_proprio: + proprio_projector = get_proprio_projector( + cfg, + model.llm_dim, + proprio_dim=8, # 8-dimensional proprio for LIBERO + ) + + # Load action head if needed + action_head = None + if cfg.use_l1_regression or cfg.use_diffusion: + action_head = get_action_head(cfg, model.llm_dim) + + # Load noisy action projector if using diffusion + noisy_action_projector = None + if cfg.use_diffusion: + noisy_action_projector = get_noisy_action_projector(cfg, model.llm_dim) + + # Get OpenVLA processor if needed + processor = None + if cfg.model_family == "openvla": + processor = get_processor(cfg) + check_unnorm_key(cfg, model) + + return model, action_head, proprio_projector, noisy_action_projector, processor + + +def check_unnorm_key(cfg: GenerateConfig, model) -> None: + """Check that the model contains the action un-normalization key.""" + # Initialize unnorm_key + unnorm_key = cfg.task_suite_name + + # In some cases, the key must be manually modified (e.g. after training on a modified version of the dataset + # with the suffix "_no_noops" in the dataset name) + if unnorm_key not in model.norm_stats and f"{unnorm_key}_no_noops" in model.norm_stats: + unnorm_key = f"{unnorm_key}_no_noops" + + assert unnorm_key in model.norm_stats, f"Action un-norm key {unnorm_key} not found in VLA `norm_stats`!" 
+ + # Set the unnorm_key in cfg + cfg.unnorm_key = unnorm_key + + +def setup_logging(cfg: GenerateConfig): + """Set up logging to file and optionally to wandb.""" + # Create run ID + run_id = f"EVAL-{cfg.task_suite_name}-{cfg.model_family}-{DATE_TIME}" + if cfg.run_id_note is not None: + run_id += f"--{cfg.run_id_note}" + + # Set up local logging + os.makedirs(cfg.local_log_dir, exist_ok=True) + local_log_filepath = os.path.join(cfg.local_log_dir, run_id + ".txt") + log_file = open(local_log_filepath, "w") + logger.info(f"Logging to local log file: {local_log_filepath}") + + # Initialize Weights & Biases logging if enabled + if cfg.use_wandb: + wandb.init( + entity=cfg.wandb_entity, + project=cfg.wandb_project, + name=run_id, + ) + + return log_file, local_log_filepath, run_id + + +def log_message(message: str, log_file=None): + """Log a message to console and optionally to a log file.""" + logger.info(message) + if log_file: + log_file.write(message + "\n") + log_file.flush() + + +def load_initial_states(cfg: GenerateConfig, task_suite, task_id: int, log_file=None): + """Load initial states for the given task.""" + # Get default initial states + initial_states = task_suite.get_task_init_states(task_id) + + # If using custom initial states, load them from file + if cfg.initial_states_path != "DEFAULT": + with open(cfg.initial_states_path, "r") as f: + all_initial_states = json.load(f) + log_message(f"Using initial states from {cfg.initial_states_path}", log_file) + return initial_states, all_initial_states + else: + log_message("Using default initial states", log_file) + return initial_states, None + + +def prepare_observation(obs, resize_size): + """Prepare observation for policy input.""" + # Get preprocessed images + img = get_libero_image(obs) + wrist_img = get_libero_wrist_image(obs) + + # Resize images to size expected by model + img_resized = resize_image_for_policy(img, resize_size) + wrist_img_resized = resize_image_for_policy(wrist_img, resize_size) + + # 
Prepare observations dict + observation = { + "full_image": img_resized, + "wrist_image": wrist_img_resized, + "state": np.concatenate( + (obs["robot0_eef_pos"], quat2axisangle(obs["robot0_eef_quat"]), obs["robot0_gripper_qpos"]) + ), + } + + return observation, img # Return both processed observation and original image for replay + + +def process_action(action, model_family): + """Process action before sending to environment.""" + # Normalize gripper action [0,1] -> [-1,+1] because the environment expects the latter + action = normalize_gripper_action(action, binarize=True) + + # [OpenVLA] The dataloader flips the sign of the gripper action to align with other datasets + # (0 = close, 1 = open), so flip it back (-1 = open, +1 = close) before executing the action + if model_family == "openvla": + action = invert_gripper_action(action) + + return action + + +def run_episode( + cfg: GenerateConfig, + env, + task_description: str, + model, + resize_size, + processor=None, + action_head=None, + proprio_projector=None, + noisy_action_projector=None, + initial_state=None, + log_file=None, +): + """Run a single episode in the environment.""" + # Reset environment + env.reset() + + # Set initial state if provided + if initial_state is not None: + obs = env.set_init_state(initial_state) + else: + obs = env.get_observation() + + # Initialize action queue + if cfg.num_open_loop_steps != NUM_ACTIONS_CHUNK: + print(f"WARNING: cfg.num_open_loop_steps ({cfg.num_open_loop_steps}) does not match the NUM_ACTIONS_CHUNK " + f"({NUM_ACTIONS_CHUNK}) constant defined in prismatic.vla.constants! 
For best performance (in terms of " + "both speed and success rate), we recommend executing the full action chunk.") + action_queue = deque(maxlen=cfg.num_open_loop_steps) + + # Setup + t = 0 + replay_images = [] + max_steps = TASK_MAX_STEPS[cfg.task_suite_name] + + # Run episode + success = False + try: + while t < max_steps + cfg.num_steps_wait: + # Do nothing for the first few timesteps to let objects stabilize + if t < cfg.num_steps_wait: + obs, reward, done, info = env.step(get_libero_dummy_action(cfg.model_family)) + t += 1 + continue + + # Prepare observation + observation, img = prepare_observation(obs, resize_size) + replay_images.append(img) + + # If action queue is empty, requery model + if len(action_queue) == 0: + # Query model to get action + actions = get_action( + cfg, + model, + observation, + task_description, + processor=processor, + action_head=action_head, + proprio_projector=proprio_projector, + noisy_action_projector=noisy_action_projector, + use_film=cfg.use_film, + ) + action_queue.extend(actions) + + # Get action from queue + action = action_queue.popleft() + + # Process action + action = process_action(action, cfg.model_family) + + # Execute action in environment + obs, reward, done, info = env.step(action.tolist()) + if done: + success = True + break + t += 1 + + except Exception as e: + log_message(f"Episode error: {e}", log_file) + + return success, replay_images + + +def run_task( + cfg: GenerateConfig, + task_suite, + task_id: int, + model, + resize_size, + processor=None, + action_head=None, + proprio_projector=None, + noisy_action_projector=None, + total_episodes=0, + total_successes=0, + log_file=None, +): + """Run evaluation for a single task.""" + # Get task + task = task_suite.get_task(task_id) + + # Get initial states + initial_states, all_initial_states = load_initial_states(cfg, task_suite, task_id, log_file) + + # Initialize environment and get task description + env, task_description = get_libero_env(task, 
cfg.model_family, resolution=cfg.env_img_res) + + # Start episodes + task_episodes, task_successes = 0, 0 + for episode_idx in tqdm.tqdm(range(cfg.num_trials_per_task)): + log_message(f"\nTask: {task_description}", log_file) + + # Handle initial state + if cfg.initial_states_path == "DEFAULT": + # Use default initial state + initial_state = initial_states[episode_idx] + else: + # Get keys for fetching initial episode state from JSON + initial_states_task_key = task_description.replace(" ", "_") + episode_key = f"demo_{episode_idx}" + + # Skip episode if expert demonstration failed to complete the task + if not all_initial_states[initial_states_task_key][episode_key]["success"]: + log_message(f"Skipping task {task_id} episode {episode_idx} due to failed expert demo!", log_file) + continue + + # Get initial state + initial_state = np.array(all_initial_states[initial_states_task_key][episode_key]["initial_state"]) + + log_message(f"Starting episode {task_episodes + 1}...", log_file) + + # Run episode + success, replay_images = run_episode( + cfg, + env, + task_description, + model, + resize_size, + processor, + action_head, + proprio_projector, + noisy_action_projector, + initial_state, + log_file, + ) + + # Update counters + task_episodes += 1 + total_episodes += 1 + if success: + task_successes += 1 + total_successes += 1 + + # Save replay video + save_rollout_video( + replay_images, total_episodes, success=success, task_description=task_description, log_file=log_file + ) + + # Log results + log_message(f"Success: {success}", log_file) + log_message(f"# episodes completed so far: {total_episodes}", log_file) + log_message(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)", log_file) + + # Log task results + task_success_rate = float(task_successes) / float(task_episodes) if task_episodes > 0 else 0 + total_success_rate = float(total_successes) / float(total_episodes) if total_episodes > 0 else 0 + + log_message(f"Current task success 
rate: {task_success_rate}", log_file) + log_message(f"Current total success rate: {total_success_rate}", log_file) + + # Log to wandb if enabled + if cfg.use_wandb: + wandb.log( + { + f"success_rate/{task_description}": task_success_rate, + f"num_episodes/{task_description}": task_episodes, + } + ) + + return total_episodes, total_successes + + +@draccus.wrap() +def eval_libero(cfg: GenerateConfig) -> float: + """Main function to evaluate a trained policy on LIBERO benchmark tasks.""" + # Validate configuration + validate_config(cfg) + + # Set random seed + set_seed_everywhere(cfg.seed) + + # Initialize model and components + model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg) + + # Get expected image dimensions + resize_size = get_image_resize_size(cfg) + + # Setup logging + log_file, local_log_filepath, run_id = setup_logging(cfg) + + # Initialize LIBERO task suite + benchmark_dict = benchmark.get_benchmark_dict() + task_suite = benchmark_dict[cfg.task_suite_name]() + num_tasks = task_suite.n_tasks + + log_message(f"Task suite: {cfg.task_suite_name}", log_file) + + # Start evaluation + total_episodes, total_successes = 0, 0 + for task_id in tqdm.tqdm(range(num_tasks)): + total_episodes, total_successes = run_task( + cfg, + task_suite, + task_id, + model, + resize_size, + processor, + action_head, + proprio_projector, + noisy_action_projector, + total_episodes, + total_successes, + log_file, + ) + + # Calculate final success rate + final_success_rate = float(total_successes) / float(total_episodes) if total_episodes > 0 else 0 + + # Log final results + log_message("Final results:", log_file) + log_message(f"Total episodes: {total_episodes}", log_file) + log_message(f"Total successes: {total_successes}", log_file) + log_message(f"Overall success rate: {final_success_rate:.4f} ({final_success_rate * 100:.1f}%)", log_file) + + # Log to wandb if enabled + if cfg.use_wandb: + wandb.log( + { + "success_rate/total": 
final_success_rate, + "num_episodes/total": total_episodes, + } + ) + wandb.save(local_log_filepath) + + # Close log file + if log_file: + log_file.close() + + return final_success_rate + + +if __name__ == "__main__": + eval_libero() diff --git a/experiments/robot/openvla_utils.py b/experiments/robot/openvla_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c14c9dcfd3d10bbf80191dd64120e8746e72bb6a --- /dev/null +++ b/experiments/robot/openvla_utils.py @@ -0,0 +1,818 @@ +"""Utils for evaluating OpenVLA or fine-tuned OpenVLA policies.""" + +import filecmp +import json +import os +import shutil +import time +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import json_numpy +import numpy as np +import requests +import tensorflow as tf +import torch +from huggingface_hub import HfApi, hf_hub_download +from PIL import Image +from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor + +# Apply JSON numpy patch for serialization +json_numpy.patch() + +from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig +from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction +from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor +from prismatic.models.action_heads import DiffusionActionHead, L1RegressionActionHead +from prismatic.models.film_vit_wrapper import FiLMedPrismaticVisionBackbone +from prismatic.models.projectors import NoisyActionProjector, ProprioProjector +from prismatic.vla.constants import ( + ACTION_DIM, + ACTION_PROPRIO_NORMALIZATION_TYPE, +) +from prismatic.vla.datasets.rlds.utils.data_utils import NormalizationType + +# Initialize important constants +DATE = time.strftime("%Y_%m_%d") +DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S") +DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +OPENVLA_IMAGE_SIZE = 224 # Standard 
image size expected by OpenVLA + +# Configure NumPy print settings +np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)}) + + +def model_is_on_hf_hub(model_path: str) -> bool: + """Checks whether a model path points to a model on Hugging Face Hub.""" + # If the API call below runs without error, the model is on the hub + try: + HfApi().model_info(model_path) + return True + except Exception: + return False + + +def update_auto_map(pretrained_checkpoint: str) -> None: + """ + Update the AutoMap configuration in the checkpoint config.json file. + + This loads the config.json file inside the checkpoint directory and overwrites + the AutoConfig and AutoModelForVision2Seq fields to use OpenVLA-specific classes. + + Args: + pretrained_checkpoint: Path to the checkpoint directory + """ + if not os.path.isdir(pretrained_checkpoint): + return + + config_path = os.path.join(pretrained_checkpoint, "config.json") + if not os.path.exists(config_path): + print(f"Warning: No config.json found at {config_path}") + return + + # Create timestamped backup + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_path = os.path.join(pretrained_checkpoint, f"config.json.back.{timestamp}") + shutil.copy2(config_path, backup_path) + print(f"Created backup of original config at: {os.path.abspath(backup_path)}") + + # Read and update the config + with open(config_path, "r") as f: + config = json.load(f) + + config["auto_map"] = { + "AutoConfig": "configuration_prismatic.OpenVLAConfig", + "AutoModelForVision2Seq": "modeling_prismatic.OpenVLAForActionPrediction", + } + + # Write back the updated config + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + + print(f"Updated config.json at: {os.path.abspath(config_path)}") + print("Changes made:") + print(' - Set AutoConfig to "configuration_prismatic.OpenVLAConfig"') + print(' - Set AutoModelForVision2Seq to "modeling_prismatic.OpenVLAForActionPrediction"') + + +def check_identical_files(path1: 
Union[str, Path], path2: Union[str, Path]) -> bool: + """ + Check if two files are identical in content. + + Args: + path1: Path to the first file + path2: Path to the second file + + Returns: + bool: True if files are identical, False otherwise + """ + path1, path2 = Path(path1), Path(path2) + + # First check if file sizes match + if path1.stat().st_size != path2.stat().st_size: + return False + + # Check if contents match + return filecmp.cmp(path1, path2, shallow=False) + + +def _handle_file_sync(curr_filepath: str, checkpoint_filepath: str, file_type: str) -> None: + """ + Handle syncing of files between current directory and checkpoint. + + Creates backups if files exist but differ, and copies current versions to checkpoint. + + Args: + curr_filepath: Path to the current file version + checkpoint_filepath: Path where the file should be in the checkpoint + file_type: Description of the file type for logging + """ + if os.path.exists(checkpoint_filepath): + # Check if existing files are identical + match = check_identical_files(curr_filepath, checkpoint_filepath) + + if not match: + print( + "\n------------------------------------------------------------------------------------------------\n" + f"Found mismatch between:\n" + f"Current: {curr_filepath}\n" + f"Checkpoint: {checkpoint_filepath}\n" + ) + + # Create timestamped backup + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_path = f"{checkpoint_filepath}.back.{timestamp}" + shutil.copy2(checkpoint_filepath, backup_path) + print(f"Created backup of original checkpoint file at: {os.path.abspath(backup_path)}") + + # Copy current version to checkpoint directory + shutil.copy2(curr_filepath, checkpoint_filepath) + print(f"Copied current version to checkpoint at: {os.path.abspath(checkpoint_filepath)}") + print( + f"Changes complete. 
The checkpoint will now use the current version of {file_type}" + "\n------------------------------------------------------------------------------------------------\n" + ) + else: + # If file doesn't exist in checkpoint directory, copy it + shutil.copy2(curr_filepath, checkpoint_filepath) + print( + "\n------------------------------------------------------------------------------------------------\n" + f"No {file_type} found in checkpoint directory.\n" + f"Copied current version from: {curr_filepath}\n" + f"To checkpoint location: {os.path.abspath(checkpoint_filepath)}" + "\n------------------------------------------------------------------------------------------------\n" + ) + + +def check_model_logic_mismatch(pretrained_checkpoint: str) -> None: + """ + Check and sync model logic files between current code and checkpoint. + + Handles the relationship between current and checkpoint versions of both + modeling_prismatic.py and configuration_prismatic.py: + - If checkpoint file exists and differs: creates backup and copies current version + - If checkpoint file doesn't exist: copies current version + + Args: + pretrained_checkpoint: Path to the checkpoint directory + """ + if not os.path.isdir(pretrained_checkpoint): + return + + # Find current files + curr_files = {"modeling_prismatic.py": None, "configuration_prismatic.py": None} + + for root, _, files in os.walk("./prismatic/"): + for filename in curr_files.keys(): + if filename in files and curr_files[filename] is None: + curr_files[filename] = os.path.join(root, filename) + + # Check and handle each file + for filename, curr_filepath in curr_files.items(): + if curr_filepath is None: + print(f"WARNING: `{filename}` is not found anywhere in the current directory.") + continue + + checkpoint_filepath = os.path.join(pretrained_checkpoint, filename) + _handle_file_sync(curr_filepath, checkpoint_filepath, filename) + + +def find_checkpoint_file(pretrained_checkpoint: str, file_pattern: str) -> str: + """ + Find a 
specific checkpoint file matching a pattern. + + Args: + pretrained_checkpoint: Path to the checkpoint directory + file_pattern: String pattern to match in filenames + + Returns: + str: Path to the matching checkpoint file + + Raises: + AssertionError: If no files or multiple files match the pattern + """ + assert os.path.isdir(pretrained_checkpoint), f"Checkpoint path must be a directory: {pretrained_checkpoint}" + + checkpoint_files = [] + for filename in os.listdir(pretrained_checkpoint): + if file_pattern in filename and "checkpoint" in filename: + full_path = os.path.join(pretrained_checkpoint, filename) + checkpoint_files.append(full_path) + + assert len(checkpoint_files) == 1, ( + f"Expected exactly 1 {file_pattern} checkpoint but found {len(checkpoint_files)} in directory: {pretrained_checkpoint}" + ) + + return checkpoint_files[0] + + +def load_component_state_dict(checkpoint_path: str) -> Dict[str, torch.Tensor]: + """ + Load a component's state dict from checkpoint and handle DDP prefix if present. + + Args: + checkpoint_path: Path to the checkpoint file + + Returns: + Dict: The processed state dictionary for loading + """ + state_dict = torch.load(checkpoint_path, weights_only=True) + + # If the component was trained with DDP, elements in the state dict have prefix "module." which we must remove + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith("module."): + new_state_dict[k[7:]] = v + else: + new_state_dict[k] = v + + return new_state_dict + + +def get_vla(cfg: Any) -> torch.nn.Module: + """ + Load and initialize the VLA model from checkpoint. 
+ + Args: + cfg: Configuration object + + Returns: + torch.nn.Module: The initialized VLA model + """ + print("Instantiating pretrained VLA policy...") + + # If loading a locally stored pretrained checkpoint, check whether config or model files + # need to be synced so that any changes the user makes to the VLA modeling code will + # actually go into effect + # If loading a pretrained checkpoint from Hugging Face Hub, we just assume that the policy + # will be used as is, with its original modeling logic + if not model_is_on_hf_hub(cfg.pretrained_checkpoint): + # Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub) + AutoConfig.register("openvla", OpenVLAConfig) + AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor) + AutoProcessor.register(OpenVLAConfig, PrismaticProcessor) + AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction) + + # Update config.json and sync model files + update_auto_map(cfg.pretrained_checkpoint) + check_model_logic_mismatch(cfg.pretrained_checkpoint) + + # Load the model + vla = AutoModelForVision2Seq.from_pretrained( + cfg.pretrained_checkpoint, + # attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, + load_in_8bit=cfg.load_in_8bit, + load_in_4bit=cfg.load_in_4bit, + low_cpu_mem_usage=True, + trust_remote_code=True, + ) + + # If using FiLM, wrap the vision backbone to allow for infusion of language inputs + if cfg.use_film: + vla = _apply_film_to_vla(vla, cfg) + + # Set number of images in model input + vla.vision_backbone.set_num_images_in_input(cfg.num_images_in_input) + + vla.eval() + + # Move model to device if not using quantization + if not cfg.load_in_8bit and not cfg.load_in_4bit: + vla = vla.to(DEVICE) + + # Load dataset stats for action normalization + _load_dataset_stats(vla, cfg.pretrained_checkpoint) + + return vla + + +def _apply_film_to_vla(vla: torch.nn.Module, cfg: Any) -> torch.nn.Module: + """ + Apply FiLM (Feature-wise Linear Modulation) 
to the VLA vision backbone. + + Args: + vla: The VLA model + cfg: Configuration object with model parameters + + Returns: + torch.nn.Module: VLA model with FiLM applied + """ + from peft import LoraConfig, get_peft_model + + # Apply LoRA configuration + lora_config = LoraConfig( + r=cfg.lora_rank, + lora_alpha=min(cfg.lora_rank, 16), + lora_dropout=0.0, + target_modules="all-linear", + init_lora_weights="gaussian", + ) + vla = get_peft_model(vla, lora_config) + + # Create and apply FiLMed vision backbone + new_vision_backbone = FiLMedPrismaticVisionBackbone( + vision_backbone=vla.vision_backbone, llm_dim=vla.llm_dim, + ) + vla.model.vision_backbone = new_vision_backbone + + # Load vision backbone checkpoint + checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "vision_backbone") + state_dict = torch.load(checkpoint_path, weights_only=True) + vla.model.vision_backbone.load_state_dict(state_dict) + + # Use the model component instead of wrapper and convert to bfloat16 + vla = vla.model + vla.vision_backbone = vla.vision_backbone.to(torch.bfloat16) + + return vla + + +def _load_dataset_stats(vla: torch.nn.Module, checkpoint_path: str) -> None: + """ + Load dataset statistics used during training for action normalization. + + Args: + vla: The VLA model + checkpoint_path: Path to the checkpoint directory + """ + if model_is_on_hf_hub(checkpoint_path): + # Download dataset stats directly from HF Hub + dataset_statistics_path = hf_hub_download( + repo_id=checkpoint_path, + filename="dataset_statistics.json", + ) + else: + dataset_statistics_path = os.path.join(checkpoint_path, "dataset_statistics.json") + if os.path.isfile(dataset_statistics_path): + with open(dataset_statistics_path, "r") as f: + norm_stats = json.load(f) + vla.norm_stats = norm_stats + else: + print( + "WARNING: No local dataset_statistics.json file found for current checkpoint.\n" + "You can ignore this if you are loading the base VLA (i.e. not fine-tuned) checkpoint." 
+ " Otherwise, you may run into errors when trying to call `predict_action()` due to an absent `unnorm_key`." + ) + + +def get_processor(cfg: Any) -> AutoProcessor: + """ + Get the VLA model's Hugging Face processor. + + Args: + cfg: Configuration object with model parameters + + Returns: + AutoProcessor: The model's processor + """ + return AutoProcessor.from_pretrained(cfg.pretrained_checkpoint, trust_remote_code=True) + + +def get_proprio_projector(cfg: Any, llm_dim: int, proprio_dim: int) -> ProprioProjector: + """ + Get proprioception projector for the VLA model. + + Args: + cfg: Configuration object with model parameters + llm_dim: Dimension of the language model + proprio_dim: Dimension of proprioception data + + Returns: + ProprioProjector: The initialized proprio projector + """ + # Initialize projector and move to device + proprio_projector = ProprioProjector( + llm_dim=llm_dim, + proprio_dim=proprio_dim, + ).to(DEVICE) + proprio_projector = proprio_projector.to(torch.bfloat16).to(DEVICE) + proprio_projector.eval() + + # Find and load checkpoint (may be on Hugging Face Hub or stored locally) + if model_is_on_hf_hub(cfg.pretrained_checkpoint): + model_path_to_proprio_projector_name = { + "moojink/openvla-7b-oft-finetuned-libero-spatial": "proprio_projector--150000_checkpoint.pt", + "moojink/openvla-7b-oft-finetuned-libero-object": "proprio_projector--150000_checkpoint.pt", + "moojink/openvla-7b-oft-finetuned-libero-goal": "proprio_projector--50000_checkpoint.pt", + "moojink/openvla-7b-oft-finetuned-libero-10": "proprio_projector--150000_checkpoint.pt", + "moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10": "proprio_projector--300000_checkpoint.pt", + } + if cfg.pretrained_checkpoint not in model_path_to_proprio_projector_name.keys(): + raise ValueError("Unsupported HF Hub pretrained checkpoint found!") + # Download proprio projector directly from HF Hub + proprio_projector_path = hf_hub_download( + repo_id=cfg.pretrained_checkpoint, 
filename=model_path_to_proprio_projector_name[cfg.pretrained_checkpoint] + ) + state_dict = load_component_state_dict(proprio_projector_path) + proprio_projector.load_state_dict(state_dict) + else: + checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "proprio_projector") + state_dict = load_component_state_dict(checkpoint_path) + proprio_projector.load_state_dict(state_dict) + + return proprio_projector + + +def get_noisy_action_projector(cfg: Any, llm_dim: int) -> NoisyActionProjector: + """ + Get noisy action projector for diffusion-based action prediction. + + Args: + cfg: Configuration object with model parameters + llm_dim: Dimension of the language model + + Returns: + NoisyActionProjector: The initialized noisy action projector + """ + # Initialize projector and move to device + noisy_action_projector = NoisyActionProjector( + llm_dim=llm_dim, + ).to(DEVICE) + noisy_action_projector = noisy_action_projector.to(torch.bfloat16).to(DEVICE) + noisy_action_projector.eval() + + # Find and load checkpoint + checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "noisy_action_projector") + state_dict = load_component_state_dict(checkpoint_path) + noisy_action_projector.load_state_dict(state_dict) + + return noisy_action_projector + + +def get_action_head(cfg: Any, llm_dim: int) -> Union[L1RegressionActionHead, DiffusionActionHead]: + """ + Get action head for continuous value prediction. + + Args: + cfg: Configuration object with model parameters + llm_dim: Dimension of the language model + + Returns: + Union[L1RegressionActionHead, DiffusionActionHead]: The initialized action head + + Raises: + AssertionError: If both L1 regression and diffusion are specified + """ + assert not (cfg.use_l1_regression and cfg.use_diffusion), "Cannot use both L1 regression and diffusion action head!" 
+ + # Initialize appropriate action head based on configuration + if cfg.use_l1_regression: + action_head = L1RegressionActionHead(input_dim=llm_dim, hidden_dim=llm_dim, action_dim=ACTION_DIM) + elif cfg.use_diffusion: + action_head = DiffusionActionHead( + input_dim=llm_dim, hidden_dim=llm_dim, action_dim=ACTION_DIM, num_diffusion_steps_train=cfg.num_diffusion_steps_train + ) + # Set number of diffusion steps for inference + action_head.noise_scheduler.set_timesteps(cfg.num_diffusion_steps_inference) + else: + raise ValueError("Either use_l1_regression or use_diffusion must be True") + + action_head = action_head.to(torch.bfloat16).to(DEVICE) + action_head.eval() + + # Find and load checkpoint (may be on Hugging Face Hub or stored locally) + if model_is_on_hf_hub(cfg.pretrained_checkpoint): + model_path_to_action_head_name = { + "moojink/openvla-7b-oft-finetuned-libero-spatial": "action_head--150000_checkpoint.pt", + "moojink/openvla-7b-oft-finetuned-libero-object": "action_head--150000_checkpoint.pt", + "moojink/openvla-7b-oft-finetuned-libero-goal": "action_head--50000_checkpoint.pt", + "moojink/openvla-7b-oft-finetuned-libero-10": "action_head--150000_checkpoint.pt", + "moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10": "action_head--300000_checkpoint.pt", + } + if cfg.pretrained_checkpoint not in model_path_to_action_head_name.keys(): + raise ValueError("Unsupported HF Hub pretrained checkpoint found!") + # Download action head directly from HF Hub + action_head_path = hf_hub_download( + repo_id=cfg.pretrained_checkpoint, filename=model_path_to_action_head_name[cfg.pretrained_checkpoint] + ) + state_dict = load_component_state_dict(action_head_path) + action_head.load_state_dict(state_dict) + else: + checkpoint_path = find_checkpoint_file(cfg.pretrained_checkpoint, "action_head") + state_dict = load_component_state_dict(checkpoint_path) + action_head.load_state_dict(state_dict) + + return action_head + + +def resize_image_for_policy(img: 
np.ndarray, resize_size: Union[int, Tuple[int, int]]) -> np.ndarray: + """ + Resize an image to match the policy's expected input size. + + Uses the same resizing scheme as in the training data pipeline for distribution matching. + + Args: + img: Numpy array containing the image + resize_size: Target size as int (square) or (height, width) tuple + + Returns: + np.ndarray: The resized image + """ + assert isinstance(resize_size, int) or isinstance(resize_size, tuple) + if isinstance(resize_size, int): + resize_size = (resize_size, resize_size) + + # Resize using the same pipeline as in RLDS dataset builder + img = tf.image.encode_jpeg(img) # Encode as JPEG + img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8) # Decode back + img = tf.image.resize(img, resize_size, method="lanczos3", antialias=True) + img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8) + + return img.numpy() + + +def crop_and_resize(image: tf.Tensor, crop_scale: float, batch_size: int) -> tf.Tensor: + """ + Center-crop an image and resize it back to original dimensions. + + Uses the same logic as in the training data pipeline for distribution matching. 
+ + Args: + image: TF Tensor of shape (batch_size, H, W, C) or (H, W, C) with values in [0,1] + crop_scale: Area of center crop relative to original image + batch_size: Batch size + + Returns: + tf.Tensor: The cropped and resized image + """ + # Handle 3D inputs by adding batch dimension if needed + assert image.shape.ndims in (3, 4), "Image must be 3D or 4D tensor" + expanded_dims = False + if image.shape.ndims == 3: + image = tf.expand_dims(image, axis=0) + expanded_dims = True + + # Calculate crop dimensions (note: we use sqrt(crop_scale) for h/w) + new_heights = tf.reshape(tf.clip_by_value(tf.sqrt(crop_scale), 0, 1), shape=(batch_size,)) + new_widths = tf.reshape(tf.clip_by_value(tf.sqrt(crop_scale), 0, 1), shape=(batch_size,)) + + # Create bounding box for the crop + height_offsets = (1 - new_heights) / 2 + width_offsets = (1 - new_widths) / 2 + bounding_boxes = tf.stack( + [ + height_offsets, + width_offsets, + height_offsets + new_heights, + width_offsets + new_widths, + ], + axis=1, + ) + + # Apply crop and resize + image = tf.image.crop_and_resize( + image, bounding_boxes, tf.range(batch_size), (OPENVLA_IMAGE_SIZE, OPENVLA_IMAGE_SIZE) + ) + + # Remove batch dimension if it was added + if expanded_dims: + image = image[0] + + return image + + +def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image: + """ + Center crop an image to match training data distribution. 
+ + Args: + image: Input image (PIL or numpy array) + + Returns: + Image.Image: Cropped PIL Image + """ + batch_size = 1 + crop_scale = 0.9 + + # Convert to TF Tensor if needed + if not isinstance(image, tf.Tensor): + image = tf.convert_to_tensor(np.array(image)) + + orig_dtype = image.dtype + + # Convert to float32 in range [0,1] + image = tf.image.convert_image_dtype(image, tf.float32) + + # Apply center crop and resize + image = crop_and_resize(image, crop_scale, batch_size) + + # Convert back to original data type + image = tf.clip_by_value(image, 0, 1) + image = tf.image.convert_image_dtype(image, orig_dtype, saturate=True) + + # Convert to PIL Image + return Image.fromarray(image.numpy()).convert("RGB") + + +def check_image_format(image: Any) -> None: + """ + Validate input image format. + + Args: + image: Image to check + + Raises: + AssertionError: If image format is invalid + """ + is_numpy_array = isinstance(image, np.ndarray) + has_correct_shape = len(image.shape) == 3 and image.shape[-1] == 3 + has_correct_dtype = image.dtype == np.uint8 + + assert is_numpy_array and has_correct_shape and has_correct_dtype, ( + "Incorrect image format detected! Make sure that the input image is a " + "numpy array with shape (H, W, 3) and dtype np.uint8!" + ) + + +def normalize_proprio(proprio: np.ndarray, norm_stats: Dict[str, Any]) -> np.ndarray: + """ + Normalize proprioception data to match training distribution. 
+ + Args: + proprio: Raw proprioception data + norm_stats: Normalization statistics + + Returns: + np.ndarray: Normalized proprioception data + """ + if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS: + mask = norm_stats.get("mask", np.ones_like(norm_stats["min"], dtype=bool)) + proprio_high, proprio_low = np.array(norm_stats["max"]), np.array(norm_stats["min"]) + elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99: + mask = norm_stats.get("mask", np.ones_like(norm_stats["q01"], dtype=bool)) + proprio_high, proprio_low = np.array(norm_stats["q99"]), np.array(norm_stats["q01"]) + else: + raise ValueError("Unsupported action/proprio normalization type detected!") + + normalized_proprio = np.clip( + np.where( + mask, + 2 * (proprio - proprio_low) / (proprio_high - proprio_low + 1e-8) - 1, + proprio, + ), + a_min=-1.0, + a_max=1.0, + ) + + return normalized_proprio + + +def prepare_images_for_vla(images: List[np.ndarray], cfg: Any) -> List[Image.Image]: + """ + Prepare images for VLA input by resizing and cropping as needed. 
+ + Args: + images: List of input images as numpy arrays + cfg: Configuration object with parameters + + Returns: + List[Image.Image]: Processed images ready for the model + """ + processed_images = [] + + for image in images: + # Validate format + check_image_format(image) + + # Resize if needed + if image.shape != (OPENVLA_IMAGE_SIZE, OPENVLA_IMAGE_SIZE, 3): + image = resize_image_for_policy(image, OPENVLA_IMAGE_SIZE) + + # Convert to PIL image + pil_image = Image.fromarray(image).convert("RGB") + + # Apply center crop if configured + if cfg.center_crop: + pil_image = center_crop_image(pil_image) + + processed_images.append(pil_image) + + return processed_images + + +def get_vla_action( + cfg: Any, + vla: torch.nn.Module, + processor: Any, + obs: Dict[str, Any], + task_label: str, + action_head: Optional[torch.nn.Module] = None, + proprio_projector: Optional[torch.nn.Module] = None, + noisy_action_projector: Optional[torch.nn.Module] = None, + use_film: bool = False, +) -> List[np.ndarray]: + """ + Generate action predictions with the VLA policy. 
+ + Args: + cfg: Configuration object with parameters + vla: The VLA model + processor: Model processor for inputs + obs: Observation dictionary + task_label: Text description of the task + action_head: Optional action head for continuous actions + proprio_projector: Optional proprioception projector + noisy_action_projector: Optional noisy action projector for diffusion + use_film: Whether to use FiLM + + Returns: + List[np.ndarray]: Predicted actions + """ + with torch.inference_mode(): + + # Collect all input images + all_images = [obs["full_image"]] + if cfg.num_images_in_input > 1: + all_images.extend([obs[k] for k in obs.keys() if "wrist" in k]) + + # Process images + all_images = prepare_images_for_vla(all_images, cfg) + + # Extract primary image and additional images + primary_image = all_images.pop(0) + + # Build VLA prompt + prompt = f"In: What action should the robot take to {task_label.lower()}?\nOut:" + + # Process primary image + inputs = processor(prompt, primary_image).to(DEVICE, dtype=torch.bfloat16) + + # Process additional wrist images if any + if all_images: + all_wrist_inputs = [ + processor(prompt, image_wrist).to(DEVICE, dtype=torch.bfloat16) for image_wrist in all_images + ] + # Concatenate all images + primary_pixel_values = inputs["pixel_values"] + all_wrist_pixel_values = [wrist_inputs["pixel_values"] for wrist_inputs in all_wrist_inputs] + inputs["pixel_values"] = torch.cat([primary_pixel_values] + all_wrist_pixel_values, dim=1) + + # Process proprioception data if used + proprio = None + if cfg.use_proprio: + proprio = obs["state"] + proprio_norm_stats = vla.norm_stats[cfg.unnorm_key]["proprio"] + obs["state"] = normalize_proprio(proprio, proprio_norm_stats) + proprio = obs["state"] + + # Generate action + if action_head is None: + # Standard VLA output (single-image inputs, discrete actions) + action, _ = vla.predict_action(**inputs, unnorm_key=cfg.unnorm_key, do_sample=False) + else: + # Custom action head for continuous actions + 
action, _ = vla.predict_action( + **inputs, + unnorm_key=cfg.unnorm_key, + do_sample=False, + proprio=proprio, + proprio_projector=proprio_projector, + noisy_action_projector=noisy_action_projector, + action_head=action_head, + use_film=use_film, + ) + + # Return action chunk as list of actions + return [action[i] for i in range(len(action))] + + +def get_action_from_server( + observation: Dict[str, Any], server_endpoint: str = "http://0.0.0.0:8777/act" +) -> Dict[str, Any]: + """ + Get VLA action from remote inference server. + + Args: + observation: Observation data to send to server + server_endpoint: URL of the inference server + + Returns: + Dict[str, Any]: Action response from server + """ + response = requests.post( + server_endpoint, + json=observation, + ) + return response.json() diff --git a/experiments/robot/robot_utils.py b/experiments/robot/robot_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..64559e99040752be0f58d07675f8f621023499dd --- /dev/null +++ b/experiments/robot/robot_utils.py @@ -0,0 +1,199 @@ +"""Utils for evaluating robot policies in various environments.""" + +import os +import random +import time +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch + +from experiments.robot.openvla_utils import ( + get_vla, + get_vla_action, +) + +# Initialize important constants +ACTION_DIM = 7 +DATE = time.strftime("%Y_%m_%d") +DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S") +DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + +# Configure NumPy print settings +np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)}) + +# Initialize system prompt for OpenVLA v0.1 +OPENVLA_V01_SYSTEM_PROMPT = ( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." 
+) + +# Model image size configuration +MODEL_IMAGE_SIZES = { + "openvla": 224, + # Add other models as needed +} + + +def set_seed_everywhere(seed: int) -> None: + """ + Set random seed for all random number generators for reproducibility. + + Args: + seed: The random seed to use + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ["PYTHONHASHSEED"] = str(seed) + + +def get_model(cfg: Any, wrap_diffusion_policy_for_droid: bool = False) -> torch.nn.Module: + """ + Load and initialize model for evaluation based on configuration. + + Args: + cfg: Configuration object with model parameters + wrap_diffusion_policy_for_droid: Whether to wrap diffusion policy for DROID + + Returns: + torch.nn.Module: The loaded model + + Raises: + ValueError: If model family is not supported + """ + if cfg.model_family == "openvla": + model = get_vla(cfg) + else: + raise ValueError(f"Unsupported model family: {cfg.model_family}") + + print(f"Loaded model: {type(model)}") + return model + + +def get_image_resize_size(cfg: Any) -> Union[int, tuple]: + """ + Get image resize dimensions for a specific model. + + If returned value is an int, the resized image will be a square. + If returned value is a tuple, the resized image will be a rectangle. 
+ + Args: + cfg: Configuration object with model parameters + + Returns: + Union[int, tuple]: Image resize dimensions + + Raises: + ValueError: If model family is not supported + """ + if cfg.model_family not in MODEL_IMAGE_SIZES: + raise ValueError(f"Unsupported model family: {cfg.model_family}") + + return MODEL_IMAGE_SIZES[cfg.model_family] + + +def get_action( + cfg: Any, + model: torch.nn.Module, + obs: Dict[str, Any], + task_label: str, + processor: Optional[Any] = None, + action_head: Optional[torch.nn.Module] = None, + proprio_projector: Optional[torch.nn.Module] = None, + noisy_action_projector: Optional[torch.nn.Module] = None, + use_film: bool = False, +) -> Union[List[np.ndarray], np.ndarray]: + """ + Query the model to get action predictions. + + Args: + cfg: Configuration object with model parameters + model: The loaded model + obs: Observation dictionary + task_label: Text description of the task + processor: Model processor for inputs + action_head: Optional action head for continuous actions + proprio_projector: Optional proprioception projector + noisy_action_projector: Optional noisy action projector for diffusion + use_film: Whether to use FiLM + + Returns: + Union[List[np.ndarray], np.ndarray]: Predicted actions + + Raises: + ValueError: If model family is not supported + """ + with torch.no_grad(): + if cfg.model_family == "openvla": + action = get_vla_action( + cfg=cfg, + vla=model, + processor=processor, + obs=obs, + task_label=task_label, + action_head=action_head, + proprio_projector=proprio_projector, + noisy_action_projector=noisy_action_projector, + use_film=use_film, + ) + else: + raise ValueError(f"Unsupported model family: {cfg.model_family}") + + return action + + +def normalize_gripper_action(action: np.ndarray, binarize: bool = True) -> np.ndarray: + """ + Normalize gripper action from [0,1] to [-1,+1] range. + + This is necessary for some environments because the dataset wrapper + standardizes gripper actions to [0,1]. 
Note that unlike the other action + dimensions, the gripper action is not normalized to [-1,+1] by default. + + Normalization formula: y = 2 * (x - orig_low) / (orig_high - orig_low) - 1 + + Args: + action: Action array with gripper action in the last dimension + binarize: Whether to binarize gripper action to -1 or +1 + + Returns: + np.ndarray: Action array with normalized gripper action + """ + # Create a copy to avoid modifying the original + normalized_action = action.copy() + + # Normalize the last action dimension to [-1,+1] + orig_low, orig_high = 0.0, 1.0 + normalized_action[..., -1] = 2 * (normalized_action[..., -1] - orig_low) / (orig_high - orig_low) - 1 + + if binarize: + # Binarize to -1 or +1 (np.sign maps an exact 0, i.e. an input of 0.5, to 0) + normalized_action[..., -1] = np.sign(normalized_action[..., -1]) + + return normalized_action + + +def invert_gripper_action(action: np.ndarray) -> np.ndarray: + """ + Flip the sign of the gripper action (last dimension of action vector). + + This is necessary for environments where -1 = open, +1 = close, since + the RLDS dataloader aligns gripper actions such that 0 = close, 1 = open. 
+ + Args: + action: Action array with gripper action in the last dimension + + Returns: + np.ndarray: Action array with inverted gripper action + """ + # Create a copy to avoid modifying the original + inverted_action = action.copy() + + # Invert the gripper action + inverted_action[..., -1] *= -1.0 + + return inverted_action diff --git a/finetune_color_object.sh b/finetune_color_object.sh new file mode 100644 index 0000000000000000000000000000000000000000..efd4465cb8051705446c77abfad6c9b1c42bc900 --- /dev/null +++ b/finetune_color_object.sh @@ -0,0 +1,87 @@ +#!/bin/bash +#SBATCH --account=nvr_lpr_rvp +#SBATCH --qos=normal +#SBATCH --partition=batch_long +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=56 +#SBATCH --time=24:00:00 +#SBATCH --job-name=openvla_oft_color_object +#SBATCH --output=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_color_object_%j.out +#SBATCH --error=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_color_object_%j.err +#SBATCH --comment=fact_off + +set -euo pipefail + +WORKSPACE=/lustre/fsw/portfolios/nvr/users/jtremblay +OPENVLA_DIR=$WORKSPACE/yu/openvla-oft +CONDA_ENV=$WORKSPACE/conda_envs/openvla-oft +PYTHON=$CONDA_ENV/bin/python +ACCELERATE=$CONDA_ENV/bin/accelerate + +DATA_ROOT=$WORKSPACE/yu/conflict_maniskill/demo_conflict/color_object/300/huggingface_data/color_object/conflict +RLDS_OUTPUT=$WORKSPACE/yu/rlds_color_object +RUN_DIR=$WORKSPACE/yu/openvla-oft/runs/color_object + +export HF_HOME=$WORKSPACE/hugging_face +export HF_TOKEN="${HF_TOKEN:-}" +export TOKENIZERS_PARALLELISM=false +export WANDB_MODE=disabled + +mkdir -p "$RLDS_OUTPUT" "$RUN_DIR" "$WORKSPACE/yu/logs" + +cd "$OPENVLA_DIR" + +# ── Step 1: Build RLDS TFRecords from parquet data (if not already done) ── +if [ ! -d "$RLDS_OUTPUT/conflict_maniskill" ]; then + echo "============================================================" + echo " Building RLDS dataset from parquet..." 
+ echo "============================================================" + $PYTHON prismatic/vla/datasets/rlds/conflict_maniskill_dataset_builder.py \ + --data_root "$DATA_ROOT" \ + --output_dir "$RLDS_OUTPUT" +else + echo "RLDS dataset already exists at $RLDS_OUTPUT/conflict_maniskill, skipping build." +fi + +# ── Step 2: Fine-tune ──────────────────────────────────────────────────── +echo "============================================================" +echo " Fine-tuning OpenVLA-OFT on color_object conflict data..." +echo "============================================================" + +# Resume support: watcher sets RESUME_STEP and RESUME_CHKPT before resubmitting. +RESUME_ARGS="" +if [ -n "${RESUME_STEP:-}" ] && [ -n "${RESUME_CHKPT:-}" ]; then + echo "Resuming from checkpoint: $RESUME_CHKPT (step $RESUME_STEP)" + RESUME_ARGS="--resume true --resume_step $RESUME_STEP" + VLA_PATH_ARG="$RESUME_CHKPT" +else + VLA_PATH_ARG="openvla/openvla-7b" +fi + +$ACCELERATE launch \ + --mixed_precision bf16 \ + --num_processes 4 \ + --num_machines 1 \ + vla-scripts/finetune.py \ + --vla_path "$VLA_PATH_ARG" \ + --data_root_dir "$RLDS_OUTPUT" \ + --dataset_name conflict_maniskill \ + --run_root_dir "$RUN_DIR" \ + --use_l1_regression true \ + --use_film false \ + --num_images_in_input 2 \ + --use_proprio true \ + --batch_size 8 \ + --learning_rate 5e-4 \ + --max_steps 40000 \ + --save_freq 5000 \ + --save_latest_checkpoint_only false \ + --image_aug true \ + --use_lora true \ + --lora_rank 32 \ + --merge_lora_during_training true \ + --wandb_entity disabled \ + --wandb_project disabled \ + $RESUME_ARGS diff --git a/finetune_color_object_node.sh b/finetune_color_object_node.sh new file mode 100644 index 0000000000000000000000000000000000000000..9ce88ced1ffc8b68f266a717abe6939d44271c3e --- /dev/null +++ b/finetune_color_object_node.sh @@ -0,0 +1,106 @@ +#!/bin/bash +#SBATCH --account=nvr_lpr_rvp +#SBATCH --qos=normal +#SBATCH --partition=batch_long +#SBATCH 
--nodes=1 +#SBATCH --gpus-per-node=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=56 +#SBATCH --time=24:00:00 +#SBATCH --job-name=openvla_oft_color_object +#SBATCH --output=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_color_object_%j.out +#SBATCH --error=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_color_object_%j.err +#SBATCH --comment=fact_off + +set -euo pipefail + +WORKSPACE=/lustre/fsw/portfolios/nvr/users/jtremblay +OPENVLA_DIR=$WORKSPACE/yu/openvla-oft +CONDA_ENV=$WORKSPACE/conda_envs/openvla-oft +PYTHON=$CONDA_ENV/bin/python +ACCELERATE=$CONDA_ENV/bin/accelerate + +DATA_ROOT=$WORKSPACE/yu/conflict_maniskill/demo_conflict/color_object/300/huggingface_data/color_object/conflict +RLDS_OUTPUT=$WORKSPACE/yu/rlds_color_object +RUN_DIR=$WORKSPACE/yu/openvla-oft/runs/color_object +MAX_STEPS=50000 + +export HF_HOME=$WORKSPACE/hugging_face +export HF_TOKEN="${HF_TOKEN:-}" +export TOKENIZERS_PARALLELISM=false +export WANDB_MODE=disabled +export PYTHONPATH=$OPENVLA_DIR:${PYTHONPATH:-} +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +mkdir -p "$RLDS_OUTPUT" "$RUN_DIR" "$WORKSPACE/yu/logs" + +cd "$OPENVLA_DIR" + +# ── Step 1: Build RLDS TFRecords (once) ───────────────────────────────── +if [ ! -d "$RLDS_OUTPUT/conflict_maniskill" ]; then + echo "============================================================" + echo " Building RLDS dataset from parquet..." + echo "============================================================" + $PYTHON prismatic/vla/datasets/rlds/conflict_maniskill_dataset_builder.py \ + --data_root "$DATA_ROOT" \ + --output_dir "$RLDS_OUTPUT" +else + echo "RLDS dataset already exists, skipping build." 
+fi + +# ── Step 2: Detect latest checkpoint ───────────────────────────────────── +LATEST_CKPT="" +LATEST_STEP=0 + +for d in "$RUN_DIR"/*_chkpt; do + [ -d "$d" ] || continue + step=$(basename "$d" | grep -oP '\d+(?=_chkpt)') + if [ -n "$step" ] && [ "$step" -gt "$LATEST_STEP" ]; then + LATEST_STEP=$step + LATEST_CKPT=$d + fi +done + +# ── Step 3: Fine-tune (fresh or resumed) ───────────────────────────────── +if [ "$LATEST_STEP" -ge "$MAX_STEPS" ]; then + echo "Already reached max_steps=$MAX_STEPS (latest checkpoint: step $LATEST_STEP). Nothing to do." + exit 0 +fi + +if [ -n "$LATEST_CKPT" ]; then + echo "============================================================" + echo " Resuming from step $LATEST_STEP: $LATEST_CKPT" + echo "============================================================" + RESUME_ARGS="--resume true --resume_step $LATEST_STEP --vla_path $LATEST_CKPT" +else + echo "============================================================" + echo " Starting fresh fine-tune from openvla/openvla-7b" + echo "============================================================" + RESUME_ARGS="--vla_path openvla/openvla-7b" +fi + +$ACCELERATE launch \ + --mixed_precision bf16 \ + --num_processes 4 \ + --num_machines 1 \ + vla-scripts/finetune.py \ + $RESUME_ARGS \ + --data_root_dir "$RLDS_OUTPUT" \ + --dataset_name conflict_maniskill \ + --run_root_dir "$RUN_DIR" \ + --use_l1_regression true \ + --use_film false \ + --num_images_in_input 2 \ + --use_proprio true \ + --batch_size 2 \ + --grad_accumulation_steps 4 \ + --learning_rate 5e-4 \ + --max_steps $MAX_STEPS \ + --save_freq 5000 \ + --save_latest_checkpoint_only false \ + --image_aug true \ + --use_lora true \ + --lora_rank 32 \ + --merge_lora_during_training true \ + --wandb_entity disabled \ + --wandb_project disabled diff --git a/finetune_color_size_node.sh b/finetune_color_size_node.sh new file mode 100644 index 0000000000000000000000000000000000000000..a338648a41b1861d466a4ce3233e42431aedd95e --- /dev/null +++ 
b/finetune_color_size_node.sh @@ -0,0 +1,106 @@ +#!/bin/bash +#SBATCH --account=nvr_lpr_rvp +#SBATCH --qos=normal +#SBATCH --partition=batch_long +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=56 +#SBATCH --time=24:00:00 +#SBATCH --job-name=openvla_oft_color_size +#SBATCH --output=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_color_size_%j.out +#SBATCH --error=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_color_size_%j.err +#SBATCH --comment=fact_off + +set -euo pipefail + +WORKSPACE=/lustre/fsw/portfolios/nvr/users/jtremblay +OPENVLA_DIR=$WORKSPACE/yu/openvla-oft +CONDA_ENV=$WORKSPACE/conda_envs/openvla-oft +PYTHON=$CONDA_ENV/bin/python +ACCELERATE=$CONDA_ENV/bin/accelerate + +DATA_ROOT=$WORKSPACE/yu/conflict_maniskill/demo_conflict/color_size/300/huggingface_data/color_size/conflict +RLDS_OUTPUT=$WORKSPACE/yu/rlds_color_size +RUN_DIR=$WORKSPACE/yu/openvla-oft/runs/color_size +MAX_STEPS=40000 + +export HF_HOME=$WORKSPACE/hugging_face +export HF_TOKEN="${HF_TOKEN:-}" +export TOKENIZERS_PARALLELISM=false +export WANDB_MODE=disabled +export PYTHONPATH=$OPENVLA_DIR:${PYTHONPATH:-} +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +mkdir -p "$RLDS_OUTPUT" "$RUN_DIR" "$WORKSPACE/yu/logs" + +cd "$OPENVLA_DIR" + +# ── Step 1: Build RLDS TFRecords (once) ───────────────────────────────── +if [ ! -d "$RLDS_OUTPUT/conflict_maniskill" ]; then + echo "============================================================" + echo " Building RLDS dataset from parquet for color_size..." + echo "============================================================" + $PYTHON prismatic/vla/datasets/rlds/conflict_maniskill_dataset_builder.py \ + --data_root "$DATA_ROOT" \ + --output_dir "$RLDS_OUTPUT" +else + echo "RLDS dataset already exists, skipping build." 
+fi + +# ── Step 2: Detect latest checkpoint ───────────────────────────────────── +LATEST_CKPT="" +LATEST_STEP=0 + +for d in "$RUN_DIR"/*_chkpt; do + [ -d "$d" ] || continue + step=$(basename "$d" | grep -oP '\d+(?=_chkpt)') + if [ -n "$step" ] && [ "$step" -gt "$LATEST_STEP" ]; then + LATEST_STEP=$step + LATEST_CKPT=$d + fi +done + +# ── Step 3: Fine-tune (fresh or resumed) ───────────────────────────────── +if [ "$LATEST_STEP" -ge "$MAX_STEPS" ]; then + echo "Already reached max_steps=$MAX_STEPS (latest checkpoint: step $LATEST_STEP). Nothing to do." + exit 0 +fi + +if [ -n "$LATEST_CKPT" ]; then + echo "============================================================" + echo " Resuming from step $LATEST_STEP: $LATEST_CKPT" + echo "============================================================" + RESUME_ARGS="--resume true --resume_step $LATEST_STEP --vla_path $LATEST_CKPT" +else + echo "============================================================" + echo " Starting fresh fine-tune from openvla/openvla-7b" + echo "============================================================" + RESUME_ARGS="--vla_path openvla/openvla-7b" +fi + +$ACCELERATE launch \ + --mixed_precision bf16 \ + --num_processes 4 \ + --num_machines 1 \ + vla-scripts/finetune.py \ + $RESUME_ARGS \ + --data_root_dir "$RLDS_OUTPUT" \ + --dataset_name conflict_maniskill \ + --run_root_dir "$RUN_DIR" \ + --use_l1_regression true \ + --use_film false \ + --num_images_in_input 2 \ + --use_proprio true \ + --batch_size 2 \ + --grad_accumulation_steps 4 \ + --learning_rate 5e-4 \ + --max_steps $MAX_STEPS \ + --save_freq 5000 \ + --save_latest_checkpoint_only false \ + --image_aug true \ + --use_lora true \ + --lora_rank 32 \ + --merge_lora_during_training true \ + --wandb_entity disabled \ + --wandb_project disabled diff --git a/finetune_color_spatial_node.sh b/finetune_color_spatial_node.sh new file mode 100644 index 0000000000000000000000000000000000000000..e4c6a3fdce59ad9652fecff2355c4210d4d9f26a --- /dev/null 
+++ b/finetune_color_spatial_node.sh @@ -0,0 +1,106 @@ +#!/bin/bash +#SBATCH --account=nvr_lpr_rvp +#SBATCH --qos=normal +#SBATCH --partition=batch_long +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=56 +#SBATCH --time=24:00:00 +#SBATCH --job-name=openvla_oft_color_spatial +#SBATCH --output=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_color_spatial_%j.out +#SBATCH --error=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_color_spatial_%j.err +#SBATCH --comment=fact_off + +set -euo pipefail + +WORKSPACE=/lustre/fsw/portfolios/nvr/users/jtremblay +OPENVLA_DIR=$WORKSPACE/yu/openvla-oft +CONDA_ENV=$WORKSPACE/conda_envs/openvla-oft +PYTHON=$CONDA_ENV/bin/python +ACCELERATE=$CONDA_ENV/bin/accelerate + +DATA_ROOT=$WORKSPACE/yu/conflict_maniskill/demo_conflict/color_spatial/300/huggingface_data/color_spatial/conflict +RLDS_OUTPUT=$WORKSPACE/yu/rlds_color_spatial +RUN_DIR=$WORKSPACE/yu/openvla-oft/runs/color_spatial +MAX_STEPS=40000 + +export HF_HOME=$WORKSPACE/hugging_face +export HF_TOKEN="${HF_TOKEN:-}" +export TOKENIZERS_PARALLELISM=false +export WANDB_MODE=disabled +export PYTHONPATH=$OPENVLA_DIR:${PYTHONPATH:-} +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +mkdir -p "$RLDS_OUTPUT" "$RUN_DIR" "$WORKSPACE/yu/logs" + +cd "$OPENVLA_DIR" + +# ── Step 1: Build RLDS TFRecords (once) ───────────────────────────────── +if [ ! -d "$RLDS_OUTPUT/conflict_maniskill" ]; then + echo "============================================================" + echo " Building RLDS dataset from parquet for color_spatial..." + echo "============================================================" + $PYTHON prismatic/vla/datasets/rlds/conflict_maniskill_dataset_builder.py \ + --data_root "$DATA_ROOT" \ + --output_dir "$RLDS_OUTPUT" +else + echo "RLDS dataset already exists, skipping build." 
+fi + +# ── Step 2: Detect latest checkpoint ───────────────────────────────────── +LATEST_CKPT="" +LATEST_STEP=0 + +for d in "$RUN_DIR"/*_chkpt; do + [ -d "$d" ] || continue + step=$(basename "$d" | grep -oP '\d+(?=_chkpt)') + if [ -n "$step" ] && [ "$step" -gt "$LATEST_STEP" ]; then + LATEST_STEP=$step + LATEST_CKPT=$d + fi +done + +# ── Step 3: Fine-tune (fresh or resumed) ───────────────────────────────── +if [ "$LATEST_STEP" -ge "$MAX_STEPS" ]; then + echo "Already reached max_steps=$MAX_STEPS (latest checkpoint: step $LATEST_STEP). Nothing to do." + exit 0 +fi + +if [ -n "$LATEST_CKPT" ]; then + echo "============================================================" + echo " Resuming from step $LATEST_STEP: $LATEST_CKPT" + echo "============================================================" + RESUME_ARGS="--resume true --resume_step $LATEST_STEP --vla_path $LATEST_CKPT" +else + echo "============================================================" + echo " Starting fresh fine-tune from openvla/openvla-7b" + echo "============================================================" + RESUME_ARGS="--vla_path openvla/openvla-7b" +fi + +$ACCELERATE launch \ + --mixed_precision bf16 \ + --num_processes 4 \ + --num_machines 1 \ + vla-scripts/finetune.py \ + $RESUME_ARGS \ + --data_root_dir "$RLDS_OUTPUT" \ + --dataset_name conflict_maniskill \ + --run_root_dir "$RUN_DIR" \ + --use_l1_regression true \ + --use_film false \ + --num_images_in_input 2 \ + --use_proprio true \ + --batch_size 2 \ + --grad_accumulation_steps 4 \ + --learning_rate 5e-4 \ + --max_steps $MAX_STEPS \ + --save_freq 5000 \ + --save_latest_checkpoint_only false \ + --image_aug true \ + --use_lora true \ + --lora_rank 32 \ + --merge_lora_during_training true \ + --wandb_entity disabled \ + --wandb_project disabled diff --git a/finetune_size_object_node.sh b/finetune_size_object_node.sh new file mode 100644 index 0000000000000000000000000000000000000000..12ffe31dba32d0080c117866520edfb849f4cfaf --- /dev/null +++ 
b/finetune_size_object_node.sh @@ -0,0 +1,106 @@ +#!/bin/bash +#SBATCH --account=nvr_lpr_rvp +#SBATCH --qos=normal +#SBATCH --partition=batch_long +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=56 +#SBATCH --time=24:00:00 +#SBATCH --job-name=openvla_oft_size_object +#SBATCH --output=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_size_object_%j.out +#SBATCH --error=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_size_object_%j.err +#SBATCH --comment=fact_off + +set -euo pipefail + +WORKSPACE=/lustre/fsw/portfolios/nvr/users/jtremblay +OPENVLA_DIR=$WORKSPACE/yu/openvla-oft +CONDA_ENV=$WORKSPACE/conda_envs/openvla-oft +PYTHON=$CONDA_ENV/bin/python +ACCELERATE=$CONDA_ENV/bin/accelerate + +DATA_ROOT=$WORKSPACE/yu/conflict_maniskill/demo_conflict/size_object/300/huggingface_data/size_object/conflict +RLDS_OUTPUT=$WORKSPACE/yu/rlds_size_object +RUN_DIR=$WORKSPACE/yu/openvla-oft/runs/size_object +MAX_STEPS=40000 + +export HF_HOME=$WORKSPACE/hugging_face +export HF_TOKEN="${HF_TOKEN:-}" +export TOKENIZERS_PARALLELISM=false +export WANDB_MODE=disabled +export PYTHONPATH=$OPENVLA_DIR:${PYTHONPATH:-} +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +mkdir -p "$RLDS_OUTPUT" "$RUN_DIR" "$WORKSPACE/yu/logs" + +cd "$OPENVLA_DIR" + +# ── Step 1: Build RLDS TFRecords (once) ───────────────────────────────── +if [ ! -d "$RLDS_OUTPUT/conflict_maniskill" ]; then + echo "============================================================" + echo " Building RLDS dataset from parquet for size_object..." + echo "============================================================" + $PYTHON prismatic/vla/datasets/rlds/conflict_maniskill_dataset_builder.py \ + --data_root "$DATA_ROOT" \ + --output_dir "$RLDS_OUTPUT" +else + echo "RLDS dataset already exists, skipping build." 
+fi + +# ── Step 2: Detect latest checkpoint ───────────────────────────────────── +LATEST_CKPT="" +LATEST_STEP=0 + +for d in "$RUN_DIR"/*_chkpt; do + [ -d "$d" ] || continue + step=$(basename "$d" | grep -oP '\d+(?=_chkpt)') + if [ -n "$step" ] && [ "$step" -gt "$LATEST_STEP" ]; then + LATEST_STEP=$step + LATEST_CKPT=$d + fi +done + +# ── Step 3: Fine-tune (fresh or resumed) ───────────────────────────────── +if [ "$LATEST_STEP" -ge "$MAX_STEPS" ]; then + echo "Already reached max_steps=$MAX_STEPS (latest checkpoint: step $LATEST_STEP). Nothing to do." + exit 0 +fi + +if [ -n "$LATEST_CKPT" ]; then + echo "============================================================" + echo " Resuming from step $LATEST_STEP: $LATEST_CKPT" + echo "============================================================" + RESUME_ARGS="--resume true --resume_step $LATEST_STEP --vla_path $LATEST_CKPT" +else + echo "============================================================" + echo " Starting fresh fine-tune from openvla/openvla-7b" + echo "============================================================" + RESUME_ARGS="--vla_path openvla/openvla-7b" +fi + +$ACCELERATE launch \ + --mixed_precision bf16 \ + --num_processes 4 \ + --num_machines 1 \ + vla-scripts/finetune.py \ + $RESUME_ARGS \ + --data_root_dir "$RLDS_OUTPUT" \ + --dataset_name conflict_maniskill \ + --run_root_dir "$RUN_DIR" \ + --use_l1_regression true \ + --use_film false \ + --num_images_in_input 2 \ + --use_proprio true \ + --batch_size 2 \ + --grad_accumulation_steps 4 \ + --learning_rate 5e-4 \ + --max_steps $MAX_STEPS \ + --save_freq 5000 \ + --save_latest_checkpoint_only false \ + --image_aug true \ + --use_lora true \ + --lora_rank 32 \ + --merge_lora_during_training true \ + --wandb_entity disabled \ + --wandb_project disabled diff --git a/finetune_spatial_object_node.sh b/finetune_spatial_object_node.sh new file mode 100644 index 0000000000000000000000000000000000000000..49a1151c24ce492876028a1ed04f5e2c7c0e5078 --- 
/dev/null +++ b/finetune_spatial_object_node.sh @@ -0,0 +1,106 @@ +#!/bin/bash +#SBATCH --account=nvr_lpr_rvp +#SBATCH --qos=normal +#SBATCH --partition=batch_long +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=56 +#SBATCH --time=24:00:00 +#SBATCH --job-name=openvla_oft_spatial_object +#SBATCH --output=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_spatial_object_%j.out +#SBATCH --error=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_spatial_object_%j.err +#SBATCH --comment=fact_off + +set -euo pipefail + +WORKSPACE=/lustre/fsw/portfolios/nvr/users/jtremblay +OPENVLA_DIR=$WORKSPACE/yu/openvla-oft +CONDA_ENV=$WORKSPACE/conda_envs/openvla-oft +PYTHON=$CONDA_ENV/bin/python +ACCELERATE=$CONDA_ENV/bin/accelerate + +DATA_ROOT=$WORKSPACE/yu/conflict_maniskill/demo_conflict/spatial_object/300/huggingface_data/spatial_object/conflict +RLDS_OUTPUT=$WORKSPACE/yu/rlds_spatial_object +RUN_DIR=$WORKSPACE/yu/openvla-oft/runs/spatial_object +MAX_STEPS=40000 + +export HF_HOME=$WORKSPACE/hugging_face +export HF_TOKEN="${HF_TOKEN:-}" +export TOKENIZERS_PARALLELISM=false +export WANDB_MODE=disabled +export PYTHONPATH=$OPENVLA_DIR:${PYTHONPATH:-} +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +mkdir -p "$RLDS_OUTPUT" "$RUN_DIR" "$WORKSPACE/yu/logs" + +cd "$OPENVLA_DIR" + +# ── Step 1: Build RLDS TFRecords (once) ───────────────────────────────── +if [ ! -d "$RLDS_OUTPUT/conflict_maniskill" ]; then + echo "============================================================" + echo " Building RLDS dataset from parquet for spatial_object..." + echo "============================================================" + $PYTHON prismatic/vla/datasets/rlds/conflict_maniskill_dataset_builder.py \ + --data_root "$DATA_ROOT" \ + --output_dir "$RLDS_OUTPUT" +else + echo "RLDS dataset already exists, skipping build." 
+fi + +# ── Step 2: Detect latest checkpoint ───────────────────────────────────── +LATEST_CKPT="" +LATEST_STEP=0 + +for d in "$RUN_DIR"/*_chkpt; do + [ -d "$d" ] || continue + step=$(basename "$d" | grep -oP '\d+(?=_chkpt)') + if [ -n "$step" ] && [ "$step" -gt "$LATEST_STEP" ]; then + LATEST_STEP=$step + LATEST_CKPT=$d + fi +done + +# ── Step 3: Fine-tune (fresh or resumed) ───────────────────────────────── +if [ "$LATEST_STEP" -ge "$MAX_STEPS" ]; then + echo "Already reached max_steps=$MAX_STEPS (latest checkpoint: step $LATEST_STEP). Nothing to do." + exit 0 +fi + +if [ -n "$LATEST_CKPT" ]; then + echo "============================================================" + echo " Resuming from step $LATEST_STEP: $LATEST_CKPT" + echo "============================================================" + RESUME_ARGS="--resume true --resume_step $LATEST_STEP --vla_path $LATEST_CKPT" +else + echo "============================================================" + echo " Starting fresh fine-tune from openvla/openvla-7b" + echo "============================================================" + RESUME_ARGS="--vla_path openvla/openvla-7b" +fi + +$ACCELERATE launch \ + --mixed_precision bf16 \ + --num_processes 4 \ + --num_machines 1 \ + vla-scripts/finetune.py \ + $RESUME_ARGS \ + --data_root_dir "$RLDS_OUTPUT" \ + --dataset_name conflict_maniskill \ + --run_root_dir "$RUN_DIR" \ + --use_l1_regression true \ + --use_film false \ + --num_images_in_input 2 \ + --use_proprio true \ + --batch_size 2 \ + --grad_accumulation_steps 4 \ + --learning_rate 5e-4 \ + --max_steps $MAX_STEPS \ + --save_freq 5000 \ + --save_latest_checkpoint_only false \ + --image_aug true \ + --use_lora true \ + --lora_rank 32 \ + --merge_lora_during_training true \ + --wandb_entity disabled \ + --wandb_project disabled diff --git a/finetune_spatial_size_node.sh b/finetune_spatial_size_node.sh new file mode 100644 index 0000000000000000000000000000000000000000..c012d2618197eb18329598ace3ffd16087649d4f --- /dev/null 
+++ b/finetune_spatial_size_node.sh @@ -0,0 +1,106 @@ +#!/bin/bash +#SBATCH --account=nvr_lpr_rvp +#SBATCH --qos=normal +#SBATCH --partition=batch_long +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=56 +#SBATCH --time=24:00:00 +#SBATCH --job-name=openvla_oft_spatial_size +#SBATCH --output=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_spatial_size_%j.out +#SBATCH --error=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_spatial_size_%j.err +#SBATCH --comment=fact_off + +set -euo pipefail + +WORKSPACE=/lustre/fsw/portfolios/nvr/users/jtremblay +OPENVLA_DIR=$WORKSPACE/yu/openvla-oft +CONDA_ENV=$WORKSPACE/conda_envs/openvla-oft +PYTHON=$CONDA_ENV/bin/python +ACCELERATE=$CONDA_ENV/bin/accelerate + +DATA_ROOT=$WORKSPACE/yu/conflict_maniskill/demo_conflict/spatial_size/300/huggingface_data/spatial_size/conflict +RLDS_OUTPUT=$WORKSPACE/yu/rlds_spatial_size +RUN_DIR=$WORKSPACE/yu/openvla-oft/runs/spatial_size +MAX_STEPS=40000 + +export HF_HOME=$WORKSPACE/hugging_face +export HF_TOKEN="${HF_TOKEN:-}" +export TOKENIZERS_PARALLELISM=false +export WANDB_MODE=disabled +export PYTHONPATH=$OPENVLA_DIR:${PYTHONPATH:-} +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +mkdir -p "$RLDS_OUTPUT" "$RUN_DIR" "$WORKSPACE/yu/logs" + +cd "$OPENVLA_DIR" + +# ── Step 1: Build RLDS TFRecords (once) ───────────────────────────────── +if [ ! -d "$RLDS_OUTPUT/conflict_maniskill" ]; then + echo "============================================================" + echo " Building RLDS dataset from parquet for spatial_size..." + echo "============================================================" + $PYTHON prismatic/vla/datasets/rlds/conflict_maniskill_dataset_builder.py \ + --data_root "$DATA_ROOT" \ + --output_dir "$RLDS_OUTPUT" +else + echo "RLDS dataset already exists, skipping build." 
+fi + +# ── Step 2: Detect latest checkpoint ───────────────────────────────────── +LATEST_CKPT="" +LATEST_STEP=0 + +for d in "$RUN_DIR"/*_chkpt; do + [ -d "$d" ] || continue + step=$(basename "$d" | grep -oP '\d+(?=_chkpt)') + if [ -n "$step" ] && [ "$step" -gt "$LATEST_STEP" ]; then + LATEST_STEP=$step + LATEST_CKPT=$d + fi +done + +# ── Step 3: Fine-tune (fresh or resumed) ───────────────────────────────── +if [ "$LATEST_STEP" -ge "$MAX_STEPS" ]; then + echo "Already reached max_steps=$MAX_STEPS (latest checkpoint: step $LATEST_STEP). Nothing to do." + exit 0 +fi + +if [ -n "$LATEST_CKPT" ]; then + echo "============================================================" + echo " Resuming from step $LATEST_STEP: $LATEST_CKPT" + echo "============================================================" + RESUME_ARGS="--resume true --resume_step $LATEST_STEP --vla_path $LATEST_CKPT" +else + echo "============================================================" + echo " Starting fresh fine-tune from openvla/openvla-7b" + echo "============================================================" + RESUME_ARGS="--vla_path openvla/openvla-7b" +fi + +$ACCELERATE launch \ + --mixed_precision bf16 \ + --num_processes 4 \ + --num_machines 1 \ + vla-scripts/finetune.py \ + $RESUME_ARGS \ + --data_root_dir "$RLDS_OUTPUT" \ + --dataset_name conflict_maniskill \ + --run_root_dir "$RUN_DIR" \ + --use_l1_regression true \ + --use_film false \ + --num_images_in_input 2 \ + --use_proprio true \ + --batch_size 2 \ + --grad_accumulation_steps 4 \ + --learning_rate 5e-4 \ + --max_steps $MAX_STEPS \ + --save_freq 5000 \ + --save_latest_checkpoint_only false \ + --image_aug true \ + --use_lora true \ + --lora_rank 32 \ + --merge_lora_during_training true \ + --wandb_entity disabled \ + --wandb_project disabled diff --git a/finetune_verb_color_node.sh b/finetune_verb_color_node.sh new file mode 100644 index 0000000000000000000000000000000000000000..6376a75814d79aed45f0c6e1808377c26899de51 --- /dev/null +++ 
b/finetune_verb_color_node.sh @@ -0,0 +1,106 @@ +#!/bin/bash +#SBATCH --account=nvr_lpr_rvp +#SBATCH --qos=normal +#SBATCH --partition=batch_long +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=56 +#SBATCH --time=24:00:00 +#SBATCH --job-name=openvla_oft_verb_color +#SBATCH --output=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_verb_color_%j.out +#SBATCH --error=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_verb_color_%j.err +#SBATCH --comment=fact_off + +set -euo pipefail + +WORKSPACE=/lustre/fsw/portfolios/nvr/users/jtremblay +OPENVLA_DIR=$WORKSPACE/yu/openvla-oft +CONDA_ENV=$WORKSPACE/conda_envs/openvla-oft +PYTHON=$CONDA_ENV/bin/python +ACCELERATE=$CONDA_ENV/bin/accelerate + +DATA_ROOT=$WORKSPACE/yu/conflict_maniskill/demo_conflict/verb_color/300/huggingface_data/verb_color/conflict +RLDS_OUTPUT=$WORKSPACE/yu/rlds_verb_color +RUN_DIR=$WORKSPACE/yu/openvla-oft/runs/verb_color +MAX_STEPS=40000 + +export HF_HOME=$WORKSPACE/hugging_face +export HF_TOKEN="${HF_TOKEN:-}" +export TOKENIZERS_PARALLELISM=false +export WANDB_MODE=disabled +export PYTHONPATH=$OPENVLA_DIR:${PYTHONPATH:-} +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +mkdir -p "$RLDS_OUTPUT" "$RUN_DIR" "$WORKSPACE/yu/logs" + +cd "$OPENVLA_DIR" + +# ── Step 1: Build RLDS TFRecords (once) ───────────────────────────────── +if [ ! -d "$RLDS_OUTPUT/conflict_maniskill" ]; then + echo "============================================================" + echo " Building RLDS dataset from parquet for verb_color..." + echo "============================================================" + $PYTHON prismatic/vla/datasets/rlds/conflict_maniskill_dataset_builder.py \ + --data_root "$DATA_ROOT" \ + --output_dir "$RLDS_OUTPUT" +else + echo "RLDS dataset already exists, skipping build." 
+fi + +# ── Step 2: Detect latest checkpoint ───────────────────────────────────── +LATEST_CKPT="" +LATEST_STEP=0 + +for d in "$RUN_DIR"/*_chkpt; do + [ -d "$d" ] || continue + step=$(basename "$d" | grep -oP '\d+(?=_chkpt)') + if [ -n "$step" ] && [ "$step" -gt "$LATEST_STEP" ]; then + LATEST_STEP=$step + LATEST_CKPT=$d + fi +done + +# ── Step 3: Fine-tune (fresh or resumed) ───────────────────────────────── +if [ "$LATEST_STEP" -ge "$MAX_STEPS" ]; then + echo "Already reached max_steps=$MAX_STEPS (latest checkpoint: step $LATEST_STEP). Nothing to do." + exit 0 +fi + +if [ -n "$LATEST_CKPT" ]; then + echo "============================================================" + echo " Resuming from step $LATEST_STEP: $LATEST_CKPT" + echo "============================================================" + RESUME_ARGS="--resume true --resume_step $LATEST_STEP --vla_path $LATEST_CKPT" +else + echo "============================================================" + echo " Starting fresh fine-tune from openvla/openvla-7b" + echo "============================================================" + RESUME_ARGS="--vla_path openvla/openvla-7b" +fi + +$ACCELERATE launch \ + --mixed_precision bf16 \ + --num_processes 4 \ + --num_machines 1 \ + vla-scripts/finetune.py \ + $RESUME_ARGS \ + --data_root_dir "$RLDS_OUTPUT" \ + --dataset_name conflict_maniskill \ + --run_root_dir "$RUN_DIR" \ + --use_l1_regression true \ + --use_film false \ + --num_images_in_input 2 \ + --use_proprio true \ + --batch_size 2 \ + --grad_accumulation_steps 4 \ + --learning_rate 5e-4 \ + --max_steps $MAX_STEPS \ + --save_freq 5000 \ + --save_latest_checkpoint_only false \ + --image_aug true \ + --use_lora true \ + --lora_rank 32 \ + --merge_lora_during_training true \ + --wandb_entity disabled \ + --wandb_project disabled diff --git a/finetune_verb_object_node.sh b/finetune_verb_object_node.sh new file mode 100644 index 0000000000000000000000000000000000000000..3201f0c37c30343c698f3f4fae7a07c8b2870968 --- /dev/null +++ 
b/finetune_verb_object_node.sh @@ -0,0 +1,106 @@ +#!/bin/bash +#SBATCH --account=nvr_lpr_rvp +#SBATCH --qos=normal +#SBATCH --partition=batch_long +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=56 +#SBATCH --time=24:00:00 +#SBATCH --job-name=openvla_oft_verb_object +#SBATCH --output=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_verb_object_%j.out +#SBATCH --error=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_verb_object_%j.err +#SBATCH --comment=fact_off + +set -euo pipefail + +WORKSPACE=/lustre/fsw/portfolios/nvr/users/jtremblay +OPENVLA_DIR=$WORKSPACE/yu/openvla-oft +CONDA_ENV=$WORKSPACE/conda_envs/openvla-oft +PYTHON=$CONDA_ENV/bin/python +ACCELERATE=$CONDA_ENV/bin/accelerate + +DATA_ROOT=$WORKSPACE/yu/conflict_maniskill/demo_conflict/verb_object/300/huggingface_data/verb_object/conflict +RLDS_OUTPUT=$WORKSPACE/yu/rlds_verb_object +RUN_DIR=$WORKSPACE/yu/openvla-oft/runs/verb_object +MAX_STEPS=40000 + +export HF_HOME=$WORKSPACE/hugging_face +export HF_TOKEN="${HF_TOKEN:-}" +export TOKENIZERS_PARALLELISM=false +export WANDB_MODE=disabled +export PYTHONPATH=$OPENVLA_DIR:${PYTHONPATH:-} +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +mkdir -p "$RLDS_OUTPUT" "$RUN_DIR" "$WORKSPACE/yu/logs" + +cd "$OPENVLA_DIR" + +# ── Step 1: Build RLDS TFRecords (once) ───────────────────────────────── +if [ ! -d "$RLDS_OUTPUT/conflict_maniskill" ]; then + echo "============================================================" + echo " Building RLDS dataset from parquet for verb_object..." + echo "============================================================" + $PYTHON prismatic/vla/datasets/rlds/conflict_maniskill_dataset_builder.py \ + --data_root "$DATA_ROOT" \ + --output_dir "$RLDS_OUTPUT" +else + echo "RLDS dataset already exists, skipping build." 
+fi + +# ── Step 2: Detect latest checkpoint ───────────────────────────────────── +LATEST_CKPT="" +LATEST_STEP=0 + +for d in "$RUN_DIR"/*_chkpt; do + [ -d "$d" ] || continue + step=$(basename "$d" | grep -oP '\d+(?=_chkpt)') + if [ -n "$step" ] && [ "$step" -gt "$LATEST_STEP" ]; then + LATEST_STEP=$step + LATEST_CKPT=$d + fi +done + +# ── Step 3: Fine-tune (fresh or resumed) ───────────────────────────────── +if [ "$LATEST_STEP" -ge "$MAX_STEPS" ]; then + echo "Already reached max_steps=$MAX_STEPS (latest checkpoint: step $LATEST_STEP). Nothing to do." + exit 0 +fi + +if [ -n "$LATEST_CKPT" ]; then + echo "============================================================" + echo " Resuming from step $LATEST_STEP: $LATEST_CKPT" + echo "============================================================" + RESUME_ARGS="--resume true --resume_step $LATEST_STEP --vla_path $LATEST_CKPT" +else + echo "============================================================" + echo " Starting fresh fine-tune from openvla/openvla-7b" + echo "============================================================" + RESUME_ARGS="--vla_path openvla/openvla-7b" +fi + +$ACCELERATE launch \ + --mixed_precision bf16 \ + --num_processes 4 \ + --num_machines 1 \ + vla-scripts/finetune.py \ + $RESUME_ARGS \ + --data_root_dir "$RLDS_OUTPUT" \ + --dataset_name conflict_maniskill \ + --run_root_dir "$RUN_DIR" \ + --use_l1_regression true \ + --use_film false \ + --num_images_in_input 2 \ + --use_proprio true \ + --batch_size 2 \ + --grad_accumulation_steps 4 \ + --learning_rate 5e-4 \ + --max_steps $MAX_STEPS \ + --save_freq 5000 \ + --save_latest_checkpoint_only false \ + --image_aug true \ + --use_lora true \ + --lora_rank 32 \ + --merge_lora_during_training true \ + --wandb_entity disabled \ + --wandb_project disabled diff --git a/finetune_verb_size_node.sh b/finetune_verb_size_node.sh new file mode 100644 index 0000000000000000000000000000000000000000..ed094ffa5c817491f9fcc55af6078ff8b95053e5 --- /dev/null +++ 
b/finetune_verb_size_node.sh @@ -0,0 +1,106 @@ +#!/bin/bash +#SBATCH --account=nvr_lpr_rvp +#SBATCH --qos=normal +#SBATCH --partition=batch_long +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=56 +#SBATCH --time=24:00:00 +#SBATCH --job-name=openvla_oft_verb_size +#SBATCH --output=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_verb_size_%j.out +#SBATCH --error=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_verb_size_%j.err +#SBATCH --comment=fact_off + +set -euo pipefail + +WORKSPACE=/lustre/fsw/portfolios/nvr/users/jtremblay +OPENVLA_DIR=$WORKSPACE/yu/openvla-oft +CONDA_ENV=$WORKSPACE/conda_envs/openvla-oft +PYTHON=$CONDA_ENV/bin/python +ACCELERATE=$CONDA_ENV/bin/accelerate + +DATA_ROOT=$WORKSPACE/yu/conflict_maniskill/demo_conflict/verb_size/300/huggingface_data/verb_size/conflict +RLDS_OUTPUT=$WORKSPACE/yu/rlds_verb_size +RUN_DIR=$WORKSPACE/yu/openvla-oft/runs/verb_size +MAX_STEPS=40000 + +export HF_HOME=$WORKSPACE/hugging_face +export HF_TOKEN="${HF_TOKEN:-}" +export TOKENIZERS_PARALLELISM=false +export WANDB_MODE=disabled +export PYTHONPATH=$OPENVLA_DIR:${PYTHONPATH:-} +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +mkdir -p "$RLDS_OUTPUT" "$RUN_DIR" "$WORKSPACE/yu/logs" + +cd "$OPENVLA_DIR" + +# ── Step 1: Build RLDS TFRecords (once) ───────────────────────────────── +if [ ! -d "$RLDS_OUTPUT/conflict_maniskill" ]; then + echo "============================================================" + echo " Building RLDS dataset from parquet for verb_size..." + echo "============================================================" + $PYTHON prismatic/vla/datasets/rlds/conflict_maniskill_dataset_builder.py \ + --data_root "$DATA_ROOT" \ + --output_dir "$RLDS_OUTPUT" +else + echo "RLDS dataset already exists, skipping build." 
+fi + +# ── Step 2: Detect latest checkpoint ───────────────────────────────────── +LATEST_CKPT="" +LATEST_STEP=0 + +for d in "$RUN_DIR"/*_chkpt; do + [ -d "$d" ] || continue + step=$(basename "$d" | grep -oP '\d+(?=_chkpt)') + if [ -n "$step" ] && [ "$step" -gt "$LATEST_STEP" ]; then + LATEST_STEP=$step + LATEST_CKPT=$d + fi +done + +# ── Step 3: Fine-tune (fresh or resumed) ───────────────────────────────── +if [ "$LATEST_STEP" -ge "$MAX_STEPS" ]; then + echo "Already reached max_steps=$MAX_STEPS (latest checkpoint: step $LATEST_STEP). Nothing to do." + exit 0 +fi + +if [ -n "$LATEST_CKPT" ]; then + echo "============================================================" + echo " Resuming from step $LATEST_STEP: $LATEST_CKPT" + echo "============================================================" + RESUME_ARGS="--resume true --resume_step $LATEST_STEP --vla_path $LATEST_CKPT" +else + echo "============================================================" + echo " Starting fresh fine-tune from openvla/openvla-7b" + echo "============================================================" + RESUME_ARGS="--vla_path openvla/openvla-7b" +fi + +$ACCELERATE launch \ + --mixed_precision bf16 \ + --num_processes 4 \ + --num_machines 1 \ + vla-scripts/finetune.py \ + $RESUME_ARGS \ + --data_root_dir "$RLDS_OUTPUT" \ + --dataset_name conflict_maniskill \ + --run_root_dir "$RUN_DIR" \ + --use_l1_regression true \ + --use_film false \ + --num_images_in_input 2 \ + --use_proprio true \ + --batch_size 2 \ + --grad_accumulation_steps 4 \ + --learning_rate 5e-4 \ + --max_steps $MAX_STEPS \ + --save_freq 5000 \ + --save_latest_checkpoint_only false \ + --image_aug true \ + --use_lora true \ + --lora_rank 32 \ + --merge_lora_during_training true \ + --wandb_entity disabled \ + --wandb_project disabled diff --git a/finetune_verb_spatial_node.sh b/finetune_verb_spatial_node.sh new file mode 100644 index 0000000000000000000000000000000000000000..47227cf435e238492b584b9dfe5a655bbfebfd7e --- /dev/null 
+++ b/finetune_verb_spatial_node.sh @@ -0,0 +1,106 @@ +#!/bin/bash +#SBATCH --account=nvr_lpr_rvp +#SBATCH --qos=normal +#SBATCH --partition=batch_long +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=56 +#SBATCH --time=24:00:00 +#SBATCH --job-name=openvla_oft_verb_spatial +#SBATCH --output=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_verb_spatial_%j.out +#SBATCH --error=/lustre/fsw/portfolios/nvr/users/jtremblay/yu/logs/openvla_oft_verb_spatial_%j.err +#SBATCH --comment=fact_off + +set -euo pipefail + +WORKSPACE=/lustre/fsw/portfolios/nvr/users/jtremblay +OPENVLA_DIR=$WORKSPACE/yu/openvla-oft +CONDA_ENV=$WORKSPACE/conda_envs/openvla-oft +PYTHON=$CONDA_ENV/bin/python +ACCELERATE=$CONDA_ENV/bin/accelerate + +DATA_ROOT=$WORKSPACE/yu/conflict_maniskill/demo_conflict/verb_spatial/300/huggingface_data/verb_spatial/conflict +RLDS_OUTPUT=$WORKSPACE/yu/rlds_verb_spatial +RUN_DIR=$WORKSPACE/yu/openvla-oft/runs/verb_spatial +MAX_STEPS=40000 + +export HF_HOME=$WORKSPACE/hugging_face +export HF_TOKEN="${HF_TOKEN:-}" +export TOKENIZERS_PARALLELISM=false +export WANDB_MODE=disabled +export PYTHONPATH=$OPENVLA_DIR:${PYTHONPATH:-} +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +mkdir -p "$RLDS_OUTPUT" "$RUN_DIR" "$WORKSPACE/yu/logs" + +cd "$OPENVLA_DIR" + +# ── Step 1: Build RLDS TFRecords (once) ───────────────────────────────── +if [ ! -d "$RLDS_OUTPUT/conflict_maniskill" ]; then + echo "============================================================" + echo " Building RLDS dataset from parquet for verb_spatial..." + echo "============================================================" + $PYTHON prismatic/vla/datasets/rlds/conflict_maniskill_dataset_builder.py \ + --data_root "$DATA_ROOT" \ + --output_dir "$RLDS_OUTPUT" +else + echo "RLDS dataset already exists, skipping build." 
+fi + +# ── Step 2: Detect latest checkpoint ───────────────────────────────────── +LATEST_CKPT="" +LATEST_STEP=0 + +for d in "$RUN_DIR"/*_chkpt; do + [ -d "$d" ] || continue + step=$(basename "$d" | grep -oP '\d+(?=_chkpt)') + if [ -n "$step" ] && [ "$step" -gt "$LATEST_STEP" ]; then + LATEST_STEP=$step + LATEST_CKPT=$d + fi +done + +# ── Step 3: Fine-tune (fresh or resumed) ───────────────────────────────── +if [ "$LATEST_STEP" -ge "$MAX_STEPS" ]; then + echo "Already reached max_steps=$MAX_STEPS (latest checkpoint: step $LATEST_STEP). Nothing to do." + exit 0 +fi + +if [ -n "$LATEST_CKPT" ]; then + echo "============================================================" + echo " Resuming from step $LATEST_STEP: $LATEST_CKPT" + echo "============================================================" + RESUME_ARGS="--resume true --resume_step $LATEST_STEP --vla_path $LATEST_CKPT" +else + echo "============================================================" + echo " Starting fresh fine-tune from openvla/openvla-7b" + echo "============================================================" + RESUME_ARGS="--vla_path openvla/openvla-7b" +fi + +$ACCELERATE launch \ + --mixed_precision bf16 \ + --num_processes 4 \ + --num_machines 1 \ + vla-scripts/finetune.py \ + $RESUME_ARGS \ + --data_root_dir "$RLDS_OUTPUT" \ + --dataset_name conflict_maniskill \ + --run_root_dir "$RUN_DIR" \ + --use_l1_regression true \ + --use_film false \ + --num_images_in_input 2 \ + --use_proprio true \ + --batch_size 2 \ + --grad_accumulation_steps 4 \ + --learning_rate 5e-4 \ + --max_steps $MAX_STEPS \ + --save_freq 5000 \ + --save_latest_checkpoint_only false \ + --image_aug true \ + --use_lora true \ + --lora_rank 32 \ + --merge_lora_during_training true \ + --wandb_entity disabled \ + --wandb_project disabled diff --git a/prismatic/__init__.py b/prismatic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fad1d6a59fcb09f71bf70a2a9f3b890f8476c18f --- /dev/null +++ 
b/prismatic/__init__.py @@ -0,0 +1 @@ +from .models import available_model_names, available_models, get_model_description, load diff --git a/prismatic/conf/__init__.py b/prismatic/conf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0af60ce04bf5b23d2cec9380f575d523e61997f --- /dev/null +++ b/prismatic/conf/__init__.py @@ -0,0 +1,3 @@ +from .datasets import DatasetConfig, DatasetRegistry +from .models import ModelConfig, ModelRegistry +from .vla import VLAConfig, VLARegistry diff --git a/prismatic/conf/datasets.py b/prismatic/conf/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..897ab3092e232321628f284a5e1926db21feb2bf --- /dev/null +++ b/prismatic/conf/datasets.py @@ -0,0 +1,133 @@ +""" +datasets.py + +Draccus Dataclass Definition for a DatasetConfig object, with various registered subclasses for each dataset variant +and processing scheme. A given dataset variant (e.g., `llava-lightning`) configures the following attributes: + - Dataset Variant (Identifier) --> e.g., "llava-v15" + - Align Stage Dataset Components (annotations, images) + - Finetune Stage Dataset Components (annotations, images) + - Dataset Root Directory (Path) +""" + +from dataclasses import dataclass +from enum import Enum, unique +from pathlib import Path +from typing import Tuple + +from draccus import ChoiceRegistry + + +@dataclass +class DatasetConfig(ChoiceRegistry): + # fmt: off + dataset_id: str # Unique ID that fully specifies a dataset variant + + # Dataset Components for each Stage in < align | finetune > + align_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `align` stage + finetune_stage_components: Tuple[Path, Path] # Path to annotation file and images directory for `finetune` stage + + dataset_root_dir: Path # Path to dataset root directory; other paths are relative to root + # fmt: on + + +# [Reproduction] LLaVa-v15 (exact dataset used in all public LLaVa-v15 models)
+@dataclass +class LLaVa_V15_Config(DatasetConfig): + dataset_id: str = "llava-v15" + + align_stage_components: Tuple[Path, Path] = ( + Path("download/llava-laion-cc-sbu-558k/chat.json"), + Path("download/llava-laion-cc-sbu-558k/"), + ) + finetune_stage_components: Tuple[Path, Path] = ( + Path("download/llava-v1.5-instruct/llava_v1_5_mix665k.json"), + Path("download/llava-v1.5-instruct/"), + ) + dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") + + +# [Multimodal-Only] LLava-v15 WITHOUT the Language-Only ShareGPT Data (No Co-Training) +@dataclass +class LLaVa_Multimodal_Only_Config(DatasetConfig): + dataset_id: str = "llava-multimodal" + + align_stage_components: Tuple[Path, Path] = ( + Path("download/llava-laion-cc-sbu-558k/chat.json"), + Path("download/llava-laion-cc-sbu-558k/"), + ) + finetune_stage_components: Tuple[Path, Path] = ( + Path("download/llava-v1.5-instruct/llava_v1_5_stripped625k.json"), + Path("download/llava-v1.5-instruct/"), + ) + dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") + + +# LLaVa-v15 + LVIS-Instruct-4V +@dataclass +class LLaVa_LVIS4V_Config(DatasetConfig): + dataset_id: str = "llava-lvis4v" + + align_stage_components: Tuple[Path, Path] = ( + Path("download/llava-laion-cc-sbu-558k/chat.json"), + Path("download/llava-laion-cc-sbu-558k/"), + ) + finetune_stage_components: Tuple[Path, Path] = ( + Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_mix888k.json"), + Path("download/llava-v1.5-instruct/"), + ) + dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") + + +# LLaVa-v15 + LRV-Instruct +@dataclass +class LLaVa_LRV_Config(DatasetConfig): + dataset_id: str = "llava-lrv" + + align_stage_components: Tuple[Path, Path] = ( + Path("download/llava-laion-cc-sbu-558k/chat.json"), + Path("download/llava-laion-cc-sbu-558k/"), + ) + finetune_stage_components: Tuple[Path, Path] = ( + Path("download/llava-v1.5-instruct/llava_v1_5_lrv_mix1008k.json"), + 
Path("download/llava-v1.5-instruct/"), + ) + dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") + + +# LLaVa-v15 + LVIS-Instruct-4V + LRV-Instruct +@dataclass +class LLaVa_LVIS4V_LRV_Config(DatasetConfig): + dataset_id: str = "llava-lvis4v-lrv" + + align_stage_components: Tuple[Path, Path] = ( + Path("download/llava-laion-cc-sbu-558k/chat.json"), + Path("download/llava-laion-cc-sbu-558k/"), + ) + finetune_stage_components: Tuple[Path, Path] = ( + Path("download/llava-v1.5-instruct/llava_v1_5_lvis4v_lrv_mix1231k.json"), + Path("download/llava-v1.5-instruct/"), + ) + dataset_root_dir: Path = Path("/mnt/fsx/skaramcheti/datasets/prismatic-vlms") + + +# === Define a Dataset Registry Enum for Reference & Validation =>> all *new* datasets must be added here! === +@unique +class DatasetRegistry(Enum): + # === LLaVa v1.5 === + LLAVA_V15 = LLaVa_V15_Config + + LLAVA_MULTIMODAL_ONLY = LLaVa_Multimodal_Only_Config + + LLAVA_LVIS4V = LLaVa_LVIS4V_Config + LLAVA_LRV = LLaVa_LRV_Config + + LLAVA_LVIS4V_LRV = LLaVa_LVIS4V_LRV_Config + + @property + def dataset_id(self) -> str: + return self.value.dataset_id + + +# Register Datasets in Choice Registry +for dataset_variant in DatasetRegistry: + DatasetConfig.register_subclass(dataset_variant.dataset_id, dataset_variant.value) diff --git a/prismatic/conf/models.py b/prismatic/conf/models.py new file mode 100644 index 0000000000000000000000000000000000000000..6f507b0dd0d7df45f1d12de304425753a04aa732 --- /dev/null +++ b/prismatic/conf/models.py @@ -0,0 +1,584 @@ +""" +models.py + +Draccus Dataclass Definition for a ModelConfig object, with various registered subclasses for each model family and +variant thereof. A given model variant configures the following attributes: + - Pretrained Visual Representation (e.g., OpenAI CLIP ViT-L/14) + Pretrained LLM Backbone (e.g., LLaMa-2 7B) + - VLM Configuration + Parameters (e.g., MLP Projector, Image Preprocessing, etc.) 
+ - [Optional] Stage 1 (`align`) Optimization Hyperparameters + - Stage 2 (`finetune`) Optimization Hyperparameters +""" + +from dataclasses import dataclass +from enum import Enum, unique +from typing import Optional + +from draccus import ChoiceRegistry + + +@dataclass +class ModelConfig(ChoiceRegistry): + # fmt: off + model_id: str # Unique Model ID that fully specifies a given variant + arch_specifier: str # Architecture specifier string (e.g., "gelu-mlp") + + # Pretrained Backbones + vision_backbone_id: str # Pretrained Visual Featurizer (from TIMM) to load + llm_backbone_id: str # Pretrained LLM (from HF Transformers) to load + + # Backbone Parameters + image_resize_strategy: str # Resizing strategy in < resize-naive | resize-crop | letterbox > + llm_max_length: int # Maximum context length for LLM (can be less than the model's max!) + + # === Multi-Stage Optimization Hyperparameters === + # By default, we assume an AdamW optimizer with FSDP (Gradient Sharding or Full Sharding depending on stage) + + # Align Stage Optimization Parameters + align_epochs: int # Epochs to Run (in case `max_steps` is not specified) + align_max_steps: Optional[int] # [Optional] Max Gradient Steps (overrides epochs) + align_global_batch_size: int # Global Batch Size (divided across processes) + align_per_device_batch_size: int # Per-Device Batch Size (per-process) + # => # of accumulation steps is auto-computed + + align_learning_rate: float # Peak Learning Rate (lr_scheduler sets warmup/decay) + align_weight_decay: float # Weight Decay for AdamW Optimizer + align_max_grad_norm: float # Max Grad Norm (for global gradient clipping) + align_lr_scheduler_type: str # LR Scheduler (default: "linear-warmup+cosine-decay") + align_warmup_ratio: float # Fraction of total steps to warmup + + align_train_strategy: str # Align Train Strategy (default: "fsdp-shard-grad-op") + + # Finetune Stage Optimization Parameters + finetune_epochs: int # Epochs to Run (in case `max_steps` is not specified) + finetune_max_steps:
Optional[int] # [Optional] Max Gradient Steps (overrides epochs) + finetune_global_batch_size: int # Global Batch Size (divided across processes) + finetune_per_device_batch_size: int # Per-Device Batch Size (per-process) + # => # of accumulation steps is auto-computed + + finetune_learning_rate: float # Peak Learning Rate (lr_scheduler sets warmup/decay) + finetune_weight_decay: float # Weight Decay for AdamW Optimizer + finetune_max_grad_norm: float # Max Grad Norm (for global gradient clipping) + finetune_lr_scheduler_type: str # LR Scheduler (default: "linear-warmup+cosine-decay") + finetune_warmup_ratio: float # Fraction of total steps to warmup + + finetune_train_strategy: str # Finetune Train Strategy (default: "fsdp-full-shard") + + # Enable Gradient/Activation Checkpointing (for the LLM Backbone) + enable_gradient_checkpointing: bool = True + + # Enable Traditional Mixed Precision Training via Torch Native AMP (`autocast`) + enable_mixed_precision_training: bool = True # Whether to enable mixed precision training + reduce_in_full_precision: bool = False # Whether to run gradient reduction in FP32 + + # fmt: on + + +# === LLaVa v1.5 Reproduction - Fully Specified Configurations === +@dataclass +class LLaVa_v15_Reproduction_7B(ModelConfig): + model_id: str = "reproduction-llava-v15+7b" + arch_specifier: str = "gelu-mlp" + + vision_backbone_id: str = "clip-vit-l-336px" + llm_backbone_id: str = "vicuna-v15-7b" + + image_resize_strategy: str = "letterbox" + llm_max_length: int = 2048 + + # Align Stage Optimization Parameters + align_epochs: int = 1 + align_max_steps: Optional[int] = None + align_global_batch_size: int = 256 + align_per_device_batch_size: int = 16 + + align_learning_rate: float = 1e-3 + align_weight_decay: float = 0.0 + align_max_grad_norm: float = 1.0 + align_lr_scheduler_type: str = "linear-warmup+cosine-decay" + align_warmup_ratio: float = 0.03 + + align_train_strategy: str = "fsdp-shard-grad-op" + + # Finetune Stage Optimization Parameters + 
finetune_epochs: int = 1 + finetune_max_steps: Optional[int] = None + finetune_global_batch_size: int = 128 + finetune_per_device_batch_size: int = 16 + + finetune_learning_rate: float = 2e-5 + finetune_weight_decay: float = 0.1 + finetune_max_grad_norm: float = 1.0 + finetune_lr_scheduler_type: str = "linear-warmup+cosine-decay" + finetune_warmup_ratio: float = 0.03 + + finetune_train_strategy: str = "fsdp-full-shard" + + +@dataclass +class LLaVa_v15_Reproduction_13B(LLaVa_v15_Reproduction_7B): + model_id: str = "reproduction-llava-v15+13b" + llm_backbone_id: str = "vicuna-v15-13b" + + +# === Section 4.1 :: Optimization Procedure === + + +# Section 4.1A :: 🚀 --> Necessity of Multi-Stage Training +@dataclass +class Exp_7B_One_Stage(LLaVa_v15_Reproduction_7B): + model_id: str = "one-stage+7b" + arch_specifier: str = "no-align+gelu-mlp" + + +@dataclass +class Exp_13B_One_Stage(LLaVa_v15_Reproduction_13B): + model_id: str = "one-stage+13b" + arch_specifier: str = "no-align+gelu-mlp" + + +# Section 4.1B :: 🛠️ --> Full Finetuning through Visual Backbones +# =>> Note :: Run with `--stage full-finetune` +@dataclass +class Exp_7B_Full_Finetune_Multi_Stage(LLaVa_v15_Reproduction_7B): + model_id: str = "full-ft-multi-stage+7b" + + +@dataclass +class Exp_7B_Full_Finetune_One_Stage(Exp_7B_One_Stage): + model_id: str = "full-ft-one-stage+7b" + + +# === Section 4.2 :: Image Processing and Visual Representations === + + +# Section 4.2A :: 📸 --> Choosing a Pretrained Representation +@dataclass +class Exp_7B_IN1K_ViT_L_p16_224px(Exp_7B_One_Stage): + model_id: str = "in1k-224px+7b" + vision_backbone_id: str = "in1k-vit-l" + + +@dataclass +class Exp_7B_DINOv2_ViT_L_p14_224px(Exp_7B_One_Stage): + model_id: str = "dinov2-224px+7b" + vision_backbone_id: str = "dinov2-vit-l" + + +@dataclass +class Exp_7B_CLIP_ViT_L_p14_224px(Exp_7B_One_Stage): + model_id: str = "clip-224px+7b" + vision_backbone_id: str = "clip-vit-l" + + +@dataclass +class 
Exp_7B_SigLIP_ViT_SO_p14_224px(Exp_7B_One_Stage): + model_id: str = "siglip-224px+7b" + vision_backbone_id: str = "siglip-vit-so400m" + + +# Section 4.2B :: 📐 --> Choosing an Image Preprocessing Strategy +@dataclass +class Exp_7B_CLIP_ViT_L_p14_336px_Resize_Crop(Exp_7B_One_Stage): + model_id: str = "clip-336px-resize-crop+7b" + image_resize_strategy: str = "resize-crop" + + +@dataclass +class Exp_7B_CLIP_ViT_L_p14_336px_Resize_Naive(Exp_7B_One_Stage): + model_id: str = "clip-336px-resize-naive+7b" + image_resize_strategy: str = "resize-naive" + + +@dataclass +class Exp_7B_SigLIP_ViT_SO_p14_384px_Letterbox(Exp_7B_One_Stage): + model_id: str = "siglip-384px-letterbox+7b" + vision_backbone_id: str = "siglip-vit-so400m-384px" + image_resize_strategy: str = "letterbox" + + +@dataclass +class Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Crop(Exp_7B_One_Stage): + model_id: str = "siglip-384px-resize-crop+7b" + vision_backbone_id: str = "siglip-vit-so400m-384px" + image_resize_strategy: str = "resize-crop" + + +@dataclass +class Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Naive(Exp_7B_One_Stage): + model_id: str = "siglip-384px-resize-naive+7b" + vision_backbone_id: str = "siglip-vit-so400m-384px" + image_resize_strategy: str = "resize-naive" + + +# Section 4.2D :: 🥞 --> Stacking/Ensembling Visual Representations +@dataclass +class Exp_7B_DINOCLIP_ViT_L_p14_336px_Letterbox(Exp_7B_One_Stage): + model_id: str = "dinoclip-336px-letterbox+7b" + vision_backbone_id: str = "dinoclip-vit-l-336px" + image_resize_strategy: str = "letterbox" + arch_specifier: str = "no-align+fused-gelu-mlp" + + +@dataclass +class Exp_7B_DINOCLIP_ViT_L_p14_336px_Resize_Naive(Exp_7B_One_Stage): + model_id: str = "dinoclip-336px-resize-naive+7b" + vision_backbone_id: str = "dinoclip-vit-l-336px" + image_resize_strategy: str = "resize-naive" + arch_specifier: str = "no-align+fused-gelu-mlp" + + +@dataclass +class Exp_7B_DINOSigLIP_ViT_L_p14_384px_Letterbox(Exp_7B_One_Stage): + model_id: str = 
"dinosiglip-384px-letterbox+7b" + vision_backbone_id: str = "dinosiglip-vit-so-384px" + image_resize_strategy: str = "letterbox" + arch_specifier: str = "no-align+fused-gelu-mlp" + + +@dataclass +class Exp_7B_DINOSigLIP_ViT_L_p14_384px_Resize_Naive(Exp_7B_One_Stage): + model_id: str = "dinosiglip-384px-resize-naive+7b" + vision_backbone_id: str = "dinosiglip-vit-so-384px" + image_resize_strategy: str = "resize-naive" + arch_specifier: str = "no-align+fused-gelu-mlp" + + +# === Section 4.3 :: Language Models === + + +# Section 4.3A :: 📝 --> Base vs. Instruct-Tuned (Chat) LLMs +@dataclass +class Exp_7B_Llama2(Exp_7B_One_Stage): + model_id: str = "llama2+7b" + llm_backbone_id: str = "llama2-7b-pure" + + +@dataclass +class Exp_13B_Llama2(Exp_13B_One_Stage): + model_id: str = "llama2+13b" + llm_backbone_id: str = "llama2-13b-pure" + + +# ~ Additional LLM Backbones :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct, Phi-2 ~ +@dataclass +class Ext_Exp_7B_Llama2_Chat(Exp_7B_One_Stage): + model_id: str = "llama2-chat+7b" + llm_backbone_id: str = "llama2-7b-chat" + + +@dataclass +class Ext_Exp_13B_Llama2_Chat(Exp_13B_One_Stage): + model_id: str = "llama2-chat+13b" + llm_backbone_id: str = "llama2-13b-chat" + + +@dataclass +class Ext_Exp_7B_Mistral_V1(Exp_7B_One_Stage): + model_id: str = "mistral-v0.1+7b" + llm_backbone_id: str = "mistral-v0.1-7b-pure" + + +@dataclass +class Ext_Exp_7B_Mistral_Instruct_V1(Exp_7B_One_Stage): + model_id: str = "mistral-instruct-v0.1+7b" + llm_backbone_id: str = "mistral-v0.1-7b-instruct" + + +@dataclass +class Ext_Exp_3B_Phi_2(Exp_7B_One_Stage): + model_id: str = "phi-2+3b" + llm_backbone_id: str = "phi-2-3b" + + +# Section 4.3B :: ✌️ --> Co-training on Language-only Data +# =>> Note :: Run with `--dataset.type "llava-multimodal"` (multimodal data only / no co-training) +@dataclass +class Exp_7B_Vicuna_No_Cotraining(Exp_7B_One_Stage): + model_id: str = "vicuna-no-cotraining+7b" + + +@dataclass +class
Exp_7B_Llama2_No_Cotraining(Exp_7B_One_Stage): + model_id: str = "llama2-no-cotraining+7b" + llm_backbone_id: str = "llama2-7b-pure" + + +# === Section 4.4 :: Scaling Properties - Train Time & Data === + + +# Section 4.4A :: ⏰ --> Scaling Train Time +@dataclass +class Exp_7B_1p25_Epochs(Exp_7B_One_Stage): + model_id: str = "train-1.25-epochs+7b" + finetune_max_steps: int = 6500 + + +@dataclass +class Exp_7B_1p5_Epochs(Exp_7B_One_Stage): + model_id: str = "train-1.5-epochs+7b" + finetune_max_steps: int = 7800 + + +@dataclass +class Exp_7B_2_Epochs(Exp_7B_One_Stage): + model_id: str = "train-2-epochs+7b" + finetune_epochs: int = 2 + + +@dataclass +class Exp_7B_3_Epochs(Exp_7B_One_Stage): + model_id: str = "train-3-epochs+7b" + finetune_epochs: int = 3 + + +# Section 4.4B :: 📚 --> Scaling Data +# =>> Note :: Run with `--dataset.type "llava-lvis4v"` +@dataclass +class Exp_7B_LLaVa_LVIS4V(Exp_7B_One_Stage): + model_id: str = "llava-lvis4v+7b" + + +# =>> Note :: Run with `--dataset.type "llava-lrv"` +@dataclass +class Exp_7B_LLaVa_LRV(Exp_7B_One_Stage): + model_id: str = "llava-lrv+7b" + + +# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"` +@dataclass +class Exp_7B_LLaVa_LVIS4V_LRV(Exp_7B_One_Stage): + model_id: str = "llava-lvis4v-lrv+7b" + + +# === Section 5 :: Prisms === + + +# Prism-CLIP +@dataclass +class Prism_7B_CLIP_Controlled(Exp_7B_One_Stage): + model_id: str = "prism-clip-controlled+7b" + vision_backbone_id: str = "clip-vit-l-336px" + image_resize_strategy: str = "resize-naive" + llm_backbone_id: str = "llama2-7b-pure" + + +@dataclass +class Prism_13B_CLIP_Controlled(Exp_13B_One_Stage): + model_id: str = "prism-clip-controlled+13b" + vision_backbone_id: str = "clip-vit-l-336px" + image_resize_strategy: str = "resize-naive" + llm_backbone_id: str = "llama2-13b-pure" + + +# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"` +@dataclass +class Prism_7B_CLIP(Exp_7B_One_Stage): + model_id: str = "prism-clip+7b" + vision_backbone_id: str = 
"clip-vit-l-336px" + image_resize_strategy: str = "resize-naive" + llm_backbone_id: str = "llama2-7b-pure" + finetune_epochs: int = 2 + + +# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"` +@dataclass +class Prism_13B_CLIP(Exp_13B_One_Stage): + model_id: str = "prism-clip+13b" + vision_backbone_id: str = "clip-vit-l-336px" + image_resize_strategy: str = "resize-naive" + llm_backbone_id: str = "llama2-13b-pure" + finetune_epochs: int = 2 + + +# Prism-SigLIP +@dataclass +class Prism_7B_SigLIP_Controlled(Exp_7B_One_Stage): + model_id: str = "prism-siglip-controlled+7b" + vision_backbone_id: str = "siglip-vit-so400m-384px" + image_resize_strategy: str = "resize-naive" + llm_backbone_id: str = "llama2-7b-pure" + + +@dataclass +class Prism_13B_SigLIP_Controlled(Exp_13B_One_Stage): + model_id: str = "prism-siglip-controlled+13b" + vision_backbone_id: str = "siglip-vit-so400m-384px" + image_resize_strategy: str = "resize-naive" + llm_backbone_id: str = "llama2-13b-pure" + + +# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"` +@dataclass +class Prism_7B_SigLIP(Exp_7B_One_Stage): + model_id: str = "prism-siglip+7b" + vision_backbone_id: str = "siglip-vit-so400m-384px" + image_resize_strategy: str = "resize-naive" + llm_backbone_id: str = "llama2-7b-pure" + finetune_epochs: int = 2 + + +# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"` +@dataclass +class Prism_13B_SigLIP(Exp_13B_One_Stage): + model_id: str = "prism-siglip+13b" + vision_backbone_id: str = "siglip-vit-so400m-384px" + image_resize_strategy: str = "resize-naive" + llm_backbone_id: str = "llama2-13b-pure" + finetune_epochs: int = 2 + + +# Prism-DINOSigLIP +@dataclass +class Prism_7B_DINOSigLIP_Controlled(Exp_7B_One_Stage): + model_id: str = "prism-dinosiglip-controlled+7b" + vision_backbone_id: str = "dinosiglip-vit-so-384px" + image_resize_strategy: str = "resize-naive" + llm_backbone_id: str = "llama2-7b-pure" + arch_specifier: str = "no-align+fused-gelu-mlp" + + +@dataclass +class
Prism_13B_DINOSigLIP_Controlled(Exp_13B_One_Stage): + model_id: str = "prism-dinosiglip-controlled+13b" + vision_backbone_id: str = "dinosiglip-vit-so-384px" + image_resize_strategy: str = "resize-naive" + llm_backbone_id: str = "llama2-13b-pure" + arch_specifier: str = "no-align+fused-gelu-mlp" + + +# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"` +@dataclass +class Prism_7B_DINOSigLIP(Exp_7B_One_Stage): + model_id: str = "prism-dinosiglip+7b" + vision_backbone_id: str = "dinosiglip-vit-so-384px" + image_resize_strategy: str = "resize-naive" + llm_backbone_id: str = "llama2-7b-pure" + arch_specifier: str = "no-align+fused-gelu-mlp" + finetune_epochs: int = 2 + + +# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"` +@dataclass +class Prism_13B_DINOSigLIP(Exp_13B_One_Stage): + model_id: str = "prism-dinosiglip+13b" + vision_backbone_id: str = "dinosiglip-vit-so-384px" + image_resize_strategy: str = "resize-naive" + llm_backbone_id: str = "llama2-13b-pure" + arch_specifier: str = "no-align+fused-gelu-mlp" + finetune_epochs: int = 2 + + +# [Inference-Optimized] 224px Prisms +@dataclass +class Opt_7B_DINOSigLIP_ViT_SO_p14_224px_Resize_Naive(Exp_7B_One_Stage): + model_id: str = "dinosiglip-224px-resize-naive+7b" + vision_backbone_id: str = "dinosiglip-vit-so-224px" + image_resize_strategy: str = "resize-naive" + arch_specifier: str = "no-align+fused-gelu-mlp" + + +@dataclass +class Prism_7B_DINOSigLIP_224px_Controlled(Exp_7B_One_Stage): + model_id: str = "prism-dinosiglip-224px-controlled+7b" + vision_backbone_id: str = "dinosiglip-vit-so-224px" + image_resize_strategy: str = "resize-naive" + llm_backbone_id: str = "llama2-7b-pure" + arch_specifier: str = "no-align+fused-gelu-mlp" + + +# =>> Note :: Run with `--dataset.type "llava-lvis4v-lrv"` +@dataclass +class Prism_7B_DINOSigLIP_224px(Exp_7B_One_Stage): + model_id: str = "prism-dinosiglip-224px+7b" + vision_backbone_id: str = "dinosiglip-vit-so-224px" + image_resize_strategy: str = "resize-naive" + 
llm_backbone_id: str = "llama2-7b-pure" + arch_specifier: str = "no-align+fused-gelu-mlp" + finetune_epochs: int = 2 + + +# === Define a Model Registry Enum for Reference & Validation === +@unique +class ModelRegistry(Enum): + # === LLaVa v1.5 Base Reproductions === + REPRODUCTION_7B = LLaVa_v15_Reproduction_7B + REPRODUCTION_13B = LLaVa_v15_Reproduction_13B + + # === Section 4.1 :: Optimization Procedure === + EXP_ONE_STAGE_7B = Exp_7B_One_Stage + EXP_ONE_STAGE_13B = Exp_13B_One_Stage + + EXP_FULL_FT_MULTI_STAGE = Exp_7B_Full_Finetune_Multi_Stage + EXP_FULL_FT_ONE_STAGE = Exp_7B_Full_Finetune_One_Stage + + # === Section 4.2 :: Image Processing and Visual Representations === + EXP_IN1K_224PX = Exp_7B_IN1K_ViT_L_p16_224px + EXP_DINOV2_224PX = Exp_7B_DINOv2_ViT_L_p14_224px + EXP_CLIP_224PX = Exp_7B_CLIP_ViT_L_p14_224px + EXP_SIGLIP_224PX = Exp_7B_SigLIP_ViT_SO_p14_224px + + EXP_CLIP_336PX_RESIZE_CROP = Exp_7B_CLIP_ViT_L_p14_336px_Resize_Crop + EXP_CLIP_336PX_RESIZE_NAIVE = Exp_7B_CLIP_ViT_L_p14_336px_Resize_Naive + EXP_SIGLIP_384PX_LETTERBOX = Exp_7B_SigLIP_ViT_SO_p14_384px_Letterbox + EXP_SIGLIP_384PX_RESIZE_CROP = Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Crop + EXP_SIGLIP_384PX_RESIZE_NAIVE = Exp_7B_SigLIP_ViT_SO_p14_384px_Resize_Naive + + EXP_DINOCLIP_336PX_LETTERBOX = Exp_7B_DINOCLIP_ViT_L_p14_336px_Letterbox + EXP_DINOCLIP_336PX_RESIZE_NAIVE = Exp_7B_DINOCLIP_ViT_L_p14_336px_Resize_Naive + EXP_DINOSIGLIP_384PX_LETTERBOX = Exp_7B_DINOSigLIP_ViT_L_p14_384px_Letterbox + EXP_DINOSIGLIP_384PX_RESIZE_NAIVE = Exp_7B_DINOSigLIP_ViT_L_p14_384px_Resize_Naive + + # === Section 4.3 :: Language Models === + EXP_LLAMA2_7B = Exp_7B_Llama2 + EXP_LLAMA2_13B = Exp_13B_Llama2 + + # ~ Additional LLM Backbone Experiments :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct ~ + EXT_EXP_LLAMA2_CHAT_7B = Ext_Exp_7B_Llama2_Chat + EXT_EXP_LLAMA2_CHAT_13B = Ext_Exp_13B_Llama2_Chat + EXT_EXP_MISTRAL_V1_7B = Ext_Exp_7B_Mistral_V1 + EXT_EXP_MISTRAL_INSTRUCT_V1_7B = 
Ext_Exp_7B_Mistral_Instruct_V1 + EXT_EXP_PHI_2_3B = Ext_Exp_3B_Phi_2 + + # Cotraining w/ Unimodal Data + EXP_VICUNA_NO_COTRAINING_7B = Exp_7B_Vicuna_No_Cotraining + EXP_LLAMA2_NO_COTRAINING_7B = Exp_7B_Llama2_No_Cotraining + + # === Section 4.4 :: Scaling Properties - Train Time & Data === + EXP_1P25_EPOCHS = Exp_7B_1p25_Epochs + EXP_1P5_EPOCHS = Exp_7B_1p5_Epochs + EXP_2_EPOCHS = Exp_7B_2_Epochs + EXP_3_EPOCHS = Exp_7B_3_Epochs + + EXP_LLAVA_LVIS4V = Exp_7B_LLaVa_LVIS4V + EXP_LLAVA_LRV = Exp_7B_LLaVa_LRV + EXP_LLAVA_LVIS4V_LRV = Exp_7B_LLaVa_LVIS4V_LRV + + # === Section 5 :: Prisms === + PRISM_CLIP_CONTROLLED_7B = Prism_7B_CLIP_Controlled + PRISM_CLIP_CONTROLLED_13B = Prism_13B_CLIP_Controlled + PRISM_CLIP_7B = Prism_7B_CLIP + PRISM_CLIP_13B = Prism_13B_CLIP + + PRISM_SIGLIP_CONTROLLED_7B = Prism_7B_SigLIP_Controlled + PRISM_SIGLIP_CONTROLLED_13B = Prism_13B_SigLIP_Controlled + PRISM_SIGLIP_7B = Prism_7B_SigLIP + PRISM_SIGLIP_13B = Prism_13B_SigLIP + + PRISM_DINOSIGLIP_CONTROLLED_7B = Prism_7B_DINOSigLIP_Controlled + PRISM_DINOSIGLIP_CONTROLLED_13B = Prism_13B_DINOSigLIP_Controlled + PRISM_DINOSIGLIP_7B = Prism_7B_DINOSigLIP + PRISM_DINOSIGLIP_13B = Prism_13B_DINOSigLIP + + # === Inference Optimized :: 224px Prisms === + OPT_DINOSIGLIP_224PX_RESIZE_NAIVE = Opt_7B_DINOSigLIP_ViT_SO_p14_224px_Resize_Naive + PRISM_DINOSIGLIP_224PX_CONTROLLED_7B = Prism_7B_DINOSigLIP_224px_Controlled + PRISM_DINOSIGLIP_224PX_7B = Prism_7B_DINOSigLIP_224px + + @property + def model_id(self) -> str: + return self.value.model_id + + +# Register Models in Choice Registry +for model_variant in ModelRegistry: + ModelConfig.register_subclass(model_variant.model_id, model_variant.value) diff --git a/prismatic/conf/vla.py b/prismatic/conf/vla.py new file mode 100644 index 0000000000000000000000000000000000000000..94d2a2b701629d99bd8b87ab0c36e13470b691a8 --- /dev/null +++ b/prismatic/conf/vla.py @@ -0,0 +1,235 @@ +""" +vla.py + +Draccus Dataclass Definition for a VLAConfig object, with various 
registered subclasses for each VLA experiment and +model configuration thereof. A given VLA model (`policy`) configures the following attributes: + - Data Mixture (e.g., Bridge, OXE_MAGIC_SOUP, etc.) + - Base VLM from Prismatic Registry (e.g., `prism-dinosiglip+7b`) + - VLA Model Architecture / Parameters (e.g., freeze vision encoder, last layer finetuning) + - Training / Optimization Hyperparameters +""" + +from dataclasses import dataclass +from enum import Enum, unique +from pathlib import Path +from typing import Optional, Union + +from draccus import ChoiceRegistry + + +@dataclass +class VLAConfig(ChoiceRegistry): + # fmt: off + vla_id: str # Unique VLA Policy ID that fully specifies a configuration variant + base_vlm: Union[str, Path] # Base VLM as ID/Path to Run Directory (e.g., `prism-dinosiglip+7b`) + freeze_vision_backbone: bool # Freeze Vision Backbone Parameters (akin to pretraining) + freeze_llm_backbone: bool # Freeze LLM Backbone parameters + unfreeze_last_llm_layer: bool # Unfreeze final layer of LLM (only takes effect if LLM is frozen) + + # Data Mixture Parameters + data_mix: str # Open-X Embodiment Dataset =>> Unique Mixture ID (e.g., `bridge`) + shuffle_buffer_size: int # Size of Shuffle Buffer (100K for Bridge, 1M for OXE) + + # Optimization Parameters + epochs: int # Epochs to Run (in case `max_steps` is not specified) + max_steps: Optional[int] # [Optional] Max Gradient Steps to Run (overrides `epochs`) + + expected_world_size: int # Expected # of GPUs =>> allows us to gate training on hardware + global_batch_size: int # Global Batch Size (divided across processes / world size) + per_device_batch_size: int # Per-Device Batch Size (per-process / individual GPU) + # =>> # of accumulation steps is auto-computed + + learning_rate: float # Peak Learning Rate (`lr_scheduler_type` sets warmup/decay) + weight_decay: float # Weight Decay for AdamW Optimizer + max_grad_norm: float # Max Grad Norm (for global gradient clipping) + lr_scheduler_type: str 
# LR Scheduler (usually: "constant" | "linear-warmup+cosine-decay") + warmup_ratio: float # Fraction of Steps to Warmup (for warmup LR schedulers) + + train_strategy: str # Train Strategy (default "fsdp-full-shard") + + # Enable Gradient/Activation Checkpointing (for the LLM Backbone) + enable_gradient_checkpointing: bool = True # Enable Gradient/Activation Checkpointing during Training + + # Mixed Precision Training via Torch Native AMP (`autocast`) + enable_mixed_precision_training: bool = True # Enable Traditional BF16 Mixed Precision + reduce_in_full_precision: bool = True # Accumulate/Reduce All-Gather Gradients in FP32 Full Precision + + # fmt: on + + +# === OpenVLA Training Configurations === + + +# = [8 GPU] Fast Iteration =>> SigLIP 224px + Bridge = +@dataclass +class Exp_SigLIP_224px_Bridge(VLAConfig): + vla_id: str = "siglip-224px+mx-bridge" + base_vlm: Union[str, Path] = "siglip-224px+7b" + + freeze_vision_backbone: bool = False + freeze_llm_backbone: bool = False + unfreeze_last_llm_layer: bool = False + + # Data Mixture Parameters + data_mix: str = "bridge" + shuffle_buffer_size: int = 256_000 + + # Optimization Parameters + epochs: int = 1000 + max_steps: Optional[int] = None + + expected_world_size: int = 8 + global_batch_size: int = 256 + per_device_batch_size: int = 32 + + learning_rate: float = 2e-5 + weight_decay: float = 0.0 + max_grad_norm: float = 1.0 + lr_scheduler_type: str = "constant" + warmup_ratio: float = 0.0 + + train_strategy: str = "fsdp-full-shard" + + +# = [8 GPU] SigLIP 224px Frozen Vision Backbone + Bridge = +@dataclass +class Exp_FreezeVIT_SigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge): + vla_id: str = "siglip-224px-icy+mx-bridge" + base_vlm: Union[str, Path] = "siglip-224px+7b" + freeze_vision_backbone: bool = True + + +# = [8 GPU] Fast Iteration =>> DINO-SigLIP 224px + Bridge = +@dataclass +class Exp_DinoSigLIP_224px_Bridge(Exp_SigLIP_224px_Bridge): + vla_id: str = "prism-dinosiglip-224px+mx-bridge" + base_vlm: Union[str, Path] 
= "prism-dinosiglip-224px+7b" + + data_mix: str = "bridge" + + +# = [64 GPU] SigLIP 224px + OXE Magic Soup = +@dataclass +class Exp_SigLIP_224px_OXE_Magic_Soup(Exp_SigLIP_224px_Bridge): + vla_id: str = "siglip-224px+mx-oxe-magic-soup" + base_vlm: Union[str, Path] = "siglip-224px+7b" + + data_mix: str = "oxe_magic_soup" + + expected_world_size: int = 64 + global_batch_size: int = 2048 + per_device_batch_size: int = 32 + + +# = [64 GPU] DINO-SigLIP 224px + OXE Magic Soup++ = +@dataclass +class Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus(Exp_SigLIP_224px_Bridge): + vla_id: str = "prism-dinosiglip-224px+mx-oxe-magic-soup-plus" + base_vlm: Union[str, Path] = "prism-dinosiglip-224px+7b" + + # Note =>> We adopt two stages, training on a mixture including DROID for 70% of training, before resampling! + # data_mix: str = "oxe_magic_soup_plus" + data_mix: str = "oxe_magic_soup_plus_minus" + + expected_world_size: int = 64 + global_batch_size: int = 2048 + per_device_batch_size: int = 32 + + +# === OpenVLA Fine-tuning Configurations === + + +# = [8 GPU] SigLIP 224px + T-DROID = +@dataclass +class Exp_SigLIP_224px_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge): + vla_id: str = "siglip-224px+mx-tdroid_carrot_in_bowl" + base_vlm: Union[str, Path] = "siglip-224px+7b" + + data_mix: str = "tdroid_carrot_in_bowl" + + +@dataclass +class Exp_SigLIP_224px_TDROID_PourCornInPot(Exp_SigLIP_224px_Bridge): + vla_id: str = "siglip-224px+mx-tdroid_pour_corn_in_pot" + base_vlm: Union[str, Path] = "siglip-224px+7b" + + data_mix: str = "tdroid_pour_corn_in_pot" + + +# = [8 GPU] SigLIP 224px + T-DROID -- Partial Finetuning = +@dataclass +class Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge): + vla_id: str = "siglip-224px-icy+mx-tdroid_carrot_in_bowl" + base_vlm: Union[str, Path] = "siglip-224px+7b" + freeze_vision_backbone: bool = True + freeze_llm_backbone: bool = False + + data_mix: str = "tdroid_carrot_in_bowl" + + +@dataclass +class 
Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge): + vla_id: str = "siglip-224px-last_layer+mx-tdroid_carrot_in_bowl" + base_vlm: Union[str, Path] = "siglip-224px+7b" + freeze_vision_backbone: bool = True + freeze_llm_backbone: bool = True + unfreeze_last_llm_layer: bool = True + + data_mix: str = "tdroid_carrot_in_bowl" + + +@dataclass +class Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl(Exp_SigLIP_224px_Bridge): + vla_id: str = "siglip-224px-sandwich+mx-tdroid_carrot_in_bowl" + base_vlm: Union[str, Path] = "siglip-224px+7b" + freeze_vision_backbone: bool = False + freeze_llm_backbone: bool = True + unfreeze_last_llm_layer: bool = True + + data_mix: str = "tdroid_carrot_in_bowl" + + +# === [8 GPU] SigLIP 224px + FrankaWipe === +@dataclass +class Exp_SigLIP_224px_Droid_Wipe(Exp_SigLIP_224px_Bridge): + vla_id: str = "siglip-224px+mx-droid_wipe" + base_vlm: Union[str, Path] = "siglip-224px+7b" + + data_mix: str = "droid_wipe" + + +# === Define a VLA Registry Enum for Reference & Validation === +@unique +class VLARegistry(Enum): + # Sanity Check Configurations =>> BridgeV2 + SIGLIP_224PX_MX_BRIDGE = Exp_SigLIP_224px_Bridge + DINOSIGLIP_224PX_MX_BRIDGE = Exp_DinoSigLIP_224px_Bridge + + # SigLIP Frozen Backbone Experiment + FREEZE_SIGLIP_224PX_MX_BRIDGE = Exp_FreezeVIT_SigLIP_224px_Bridge + + # [OpenVLA v0.1 7B] SigLIP 224px + OXE Magic Soup + SIGLIP_224PX_MX_OXE_MAGIC_SOUP = Exp_SigLIP_224px_OXE_Magic_Soup + + # [OpenVLA 7B] DINO + SigLIP 224px + OXE Magic Soup++ + DINOSIGLIP_224PX_MX_OXE_MAGIC_SOUP_PLUS = Exp_DinoSigLIP_224px_OXE_Magic_Soup_Plus + + # === TDROID Fine-tuning Configs === + SIGLIP_224PX_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_TDROID_CarrotInBowl + SIGLIP_224PX_MX_TDROID_POUR_CORN_IN_POT = Exp_SigLIP_224px_TDROID_PourCornInPot + + SIGLIP_224PX_ICY_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Icy_TDROID_CarrotInBowl + SIGLIP_224PX_LASTLAYER_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_LastLayer_TDROID_CarrotInBowl + 
SIGLIP_224PX_SANDWICH_MX_TDROID_CARROT_IN_BOWL = Exp_SigLIP_224px_Sandwich_TDROID_CarrotInBowl + + # === DROID Fine-tuning Configs === + SIGLIP_224PX_MX_DROID_WIPE = Exp_SigLIP_224px_Droid_Wipe + + @property + def vla_id(self) -> str: + return self.value.vla_id + + +# Register VLAs in Choice Registry +for vla_variant in VLARegistry: + VLAConfig.register_subclass(vla_variant.vla_id, vla_variant.value) diff --git a/prismatic/extern/__init__.py b/prismatic/extern/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/prismatic/extern/hf/__init__.py b/prismatic/extern/hf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/prismatic/extern/hf/configuration_prismatic.py b/prismatic/extern/hf/configuration_prismatic.py new file mode 100644 index 0000000000000000000000000000000000000000..c2625753c4da1a6ef274a02645d4086bc7a7fb2b --- /dev/null +++ b/prismatic/extern/hf/configuration_prismatic.py @@ -0,0 +1,140 @@ +""" +configuration_prismatic.py + +HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`. +Default configuration specifies `siglip-224px+7b`. 
+""" + +from typing import Any, Dict, List, Optional + +from transformers import PretrainedConfig +from transformers.models.auto import CONFIG_MAPPING + +# === Utilities for Mapping Prismatic names to HF names === +# fmt: off +VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = { + "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224], + + "clip-vit-l-336px": [336], + "siglip-vit-so400m-384px": [384], + + "dinoclip-vit-l-336px": [336, 336], + "dinosiglip-vit-so-224px": [224, 224], + "dinosiglip-vit-so-384px": [384, 384], +} +VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = { + "clip-vit-l": ["vit_large_patch14_clip_224.openai"], + "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"], + + "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"], + "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"], + + "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"], + "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"], + + "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"], + "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"], + "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"], +} +TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = { + "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"], + "dinov2-vit-l": [None], "in1k-vit-l": [None], + "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None], + "dinoclip-vit-l-336px": [None, "quick_gelu"], + "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None] +} + +LLM_BACKBONE_TO_HF_PATH = { + "llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf", + "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf", + + "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": 
"lmsys/vicuna-13b-v1.5", + + "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1", + "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1", + + "phi-2-3b": "microsoft/phi-2", +} +LLM_BACKBONE_TO_HF_METACLASS = { + "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama", + "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama", + + "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral", + + "phi-2-3b": "phi", +} + +VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys()) +VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH) +# fmt: on + + +class PrismaticConfig(PretrainedConfig): + model_type: str = "prismatic" + is_composition: bool = False + + def __init__( + self, + vision_backbone_id: str = "siglip-vit-so400m", + llm_backbone_id: str = "vicuna-v15-7b", + arch_specifier: str = "no-align+gelu-mlp", + use_fused_vision_backbone: Optional[bool] = None, + image_resize_strategy: str = "letterbox", + text_config: Optional[Dict[str, Any]] = None, + llm_max_length: int = 2048, + pad_token_id: int = 32000, + pad_to_multiple_of: int = 64, + output_projector_states: bool = False, + **kwargs: str, + ) -> None: + if vision_backbone_id not in VALID_VISION_BACKBONES: + raise ValueError(f"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }") + + if llm_backbone_id not in VALID_LLM_BACKBONES: + raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }") + + # Set Prismatic Configuration Fields + self.vision_backbone_id = vision_backbone_id + self.llm_backbone_id = llm_backbone_id + self.arch_specifier = arch_specifier + self.output_projector_states = output_projector_states + + # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing + self.use_fused_vision_backbone = ( + use_fused_vision_backbone + if use_fused_vision_backbone is not None + else any(self.vision_backbone_id.startswith(v) for v 
in ["dinoclip", "dinosiglip"]) + ) + + self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id] + self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id] + self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id] + self.image_resize_strategy = image_resize_strategy + + self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id] + self.llm_max_length = llm_max_length + self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of + + # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming! + self.text_config = ( + CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config) + if text_config is not None + else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]() + ) + + # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well... + super().__init__(pad_token_id=pad_token_id, **kwargs) + + +class OpenVLAConfig(PrismaticConfig): + model_type: str = "openvla" + + def __init__( + self, + norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None, + n_action_bins: int = 256, + **kwargs: str, + ) -> None: + self.norm_stats, self.n_action_bins = norm_stats, n_action_bins + + super().__init__(**kwargs) diff --git a/prismatic/extern/hf/modeling_prismatic.py b/prismatic/extern/hf/modeling_prismatic.py new file mode 100644 index 0000000000000000000000000000000000000000..19ddca82714f7f6ec5f64d1e48b3c9ed9096c4b9 --- /dev/null +++ b/prismatic/extern/hf/modeling_prismatic.py @@ -0,0 +1,1085 @@ +""" +modeling_prismatic.py + +Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions. +Inherits from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained, +but exactly replicate the logic in `prismatic.models.vlms.prismatic.py`. 
+""" + +import logging +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union + +import numpy as np +import timm +import tokenizers +import torch +import torch.nn as nn +import transformers +from timm.models.vision_transformer import LayerScale +from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import ModelOutput + +from prismatic.training.train_utils import ( + get_current_action_mask, + get_next_actions_mask, +) +from prismatic.vla.constants import ( + ACTION_DIM, + ACTION_PROPRIO_NORMALIZATION_TYPE, + ACTION_TOKEN_BEGIN_IDX, + IGNORE_INDEX, + NUM_ACTIONS_CHUNK, + STOP_INDEX, + NormalizationType, +) + +from .configuration_prismatic import OpenVLAConfig, PrismaticConfig + +# Set up logger +logger = logging.getLogger(__name__) + + +# === Utility Functions for Monkey-Patching === +def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]: + def wrapper(*args: Any, **kwargs: Any) -> Any: + result = fn(*args, **kwargs) + return result[0] if isinstance(result, tuple) else result + + return wrapper + + +# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale. 
+# =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109 +# =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960 +def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor + + +def ls_apply_patch(ls_module: LayerScale): + ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone()) + ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale) + del ls_module.gamma + + +# === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) === +class PrismaticVisionBackbone(nn.Module): + """ + Vision backbone for Prismatic models that handles image feature extraction. + + Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations. + For fused backbones, features from both models are concatenated along the feature dimension. + """ + + def __init__( + self, + use_fused_vision_backbone: bool, + image_sizes: List[int], + timm_model_ids: List[str], + timm_override_act_layers: List[Optional[str]], + ) -> None: + """ + Initialize the vision backbone. 
+ + Args: + use_fused_vision_backbone: Whether to use two backbones and fuse their features + image_sizes: List of image sizes for each backbone + timm_model_ids: List of TIMM model IDs to use for each backbone + timm_override_act_layers: List of activation layer overrides for each backbone + """ + super().__init__() + self.use_fused_vision_backbone = use_fused_vision_backbone + self.num_images_in_input = 1 # Default value, can be overridden later + + # Validate number of (fused) vision backbones + if len(timm_model_ids) > 2: + raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!") + + # Create primary featurizer + self.featurizer = self._create_featurizer( + model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0] + ) + self.embed_dim = self.featurizer.embed_dim + + # Create secondary featurizer if using fused backbone + if self.use_fused_vision_backbone: + self.fused_featurizer = self._create_featurizer( + model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1] + ) + self.embed_dim += self.fused_featurizer.embed_dim + + # Patch LayerScale modules for HF compatibility + self._patch_layer_scales() + + def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module: + """ + Create a TIMM-based featurizer model with appropriate configurations. 
+ + Args: + model_id: The TIMM model ID to load + img_size: Input image size for the model + act_layer: Override for the activation layer type + + Returns: + A configured featurizer model + """ + featurizer = timm.create_model( + model_id, + pretrained=False, + num_classes=0, + img_size=img_size, + act_layer=act_layer, + ) + + # Monkey-patch the forward function to extract the second-to-last layer features + num_blocks = len(featurizer.blocks) + featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2})) + + return featurizer + + def _patch_layer_scales(self) -> None: + """ + Patch all LayerScale modules to be compatible with HF's parameter naming. + + HF Transformers overwrites parameters with names containing 'gamma', + so we need to rename and modify the forward method. + """ + # Patch primary featurizer + for module in self.featurizer.modules(): + if isinstance(module, LayerScale): + ls_apply_patch(module) + + # Patch secondary featurizer if it exists + if self.use_fused_vision_backbone: + for module in self.fused_featurizer.modules(): + if isinstance(module, LayerScale): + ls_apply_patch(module) + + def get_num_patches(self) -> int: + """ + Returns the number of vision patches output by the vision backbone. + + Returns: + Number of patches per image + """ + return self.featurizer.patch_embed.num_patches + + def get_num_images_in_input(self) -> int: + """ + Returns the number of input images for the vision backbone. + + Returns: + Number of images expected in the input + """ + return self.num_images_in_input + + def set_num_images_in_input(self, num_images_in_input: int) -> None: + """ + Sets the number of input images for the vision backbone. + + Args: + num_images_in_input: Number of images to expect in the input + """ + self.num_images_in_input = num_images_in_input + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + Implements the forward pass for the vision backbone. 
+ + If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features + (otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone). + + Args: + pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W). + """ + if self.num_images_in_input == 1: + if not self.use_fused_vision_backbone: + return self.featurizer(pixel_values) + + # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack + img, img_fused = torch.split(pixel_values, [3, 3], dim=1) + patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused) + + return torch.cat([patches, patches_fused], dim=2) + + else: + assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!" + + # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2) + images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1) + + # Process each image and collect patches + all_patches = [] + for img in images: + # Split each image further into two stacks of channels (each with 3 channels) + img_regular, img_fused = torch.split(img, [3, 3], dim=1) + + # Get patches from both SigLIP and DINOv2 vision transformers + patches = self.featurizer(img_regular) + patches_fused = self.fused_featurizer(img_fused) + + # Concatenate SigLIP and DINOv2 patches along the hidden dimension + combined_patches = torch.cat([patches, patches_fused], dim=2) + all_patches.append(combined_patches) + + # Concatenate all patches along the patch dimension + return torch.cat(all_patches, dim=1) + + +# === Prismatic Projector (nn.Module) Definitions === +class PrismaticProjector(nn.Module): + def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None: + super().__init__() + self.use_fused_vision_backbone = use_fused_vision_backbone + self.vision_dim, self.llm_dim = vision_dim, llm_dim + + # Switch on 
`use_fused_vision_backbone` =>> use slightly different MLPs and projection factors! + if not self.use_fused_vision_backbone: + self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True) + self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True) + self.act_fn1 = nn.GELU() + else: + initial_projection_dim = 4 * vision_dim + self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True) + self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True) + self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True) + self.act_fn1 = nn.GELU() + self.act_fn2 = nn.GELU() + + def forward(self, img_patches: torch.Tensor) -> torch.Tensor: + if not self.use_fused_vision_backbone: + projected_features = self.fc1(img_patches) + projected_features = self.act_fn1(projected_features) + projected_features = self.fc2(projected_features) + else: + projected_features = self.fc1(img_patches) + projected_features = self.act_fn1(projected_features) + projected_features = self.fc2(projected_features) + projected_features = self.act_fn2(projected_features) + projected_features = self.fc3(projected_features) + + return projected_features + + +# === Main HF Class Definitions === +@dataclass +class PrismaticCausalLMOutputWithPast(ModelOutput): + """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features.""" + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + # Additions for VLMs + projector_features: Optional[torch.FloatTensor] = None + + +class PrismaticPreTrainedModel(PreTrainedModel): + config_class: PretrainedConfig = PrismaticConfig + base_model_prefix: str = "model" + supports_gradient_checkpointing: bool = True + + _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"] + 
_skip_keys_device_placement: str = "past_key_values" + _supports_flash_attn_2: bool = True + + def _init_weights(self, module: nn.Module) -> None: + # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning! + # => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at + # https://github.com/TRI-ML/prismatic-vlms + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def _supports_sdpa(self) -> bool: + """Check LLM supports SDPA Attention""" + return self.language_model._supports_sdpa + + +class PrismaticForConditionalGeneration(PrismaticPreTrainedModel): + def __init__(self, config: PrismaticConfig) -> None: + super().__init__(config) + + # [Validation] Lightweight Validate on `config` Fields + Dependency Versions + if config.use_fused_vision_backbone is None: + raise ValueError("Missing config field `use_fused_vision_backbone`") + + if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}: + raise NotImplementedError( + "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue " + "if you urgently need support for latest TIMM versions." 
+ ) + + if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"): + logger.warning( + f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got " + f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; " + f"there might be inference-time regressions due to dependency changes. If in doubt, please " + f"use the above versions." + ) + + # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone) + self.vision_backbone = PrismaticVisionBackbone( + config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers + ) + + # Create Multimodal Projector + self.projector = PrismaticProjector( + config.use_fused_vision_backbone, + vision_dim=self.vision_backbone.embed_dim, + llm_dim=config.text_config.hidden_size, + ) + + # Instantiate LLM Backbone + self.language_model = AutoModelForCausalLM.from_config( + config.text_config, attn_implementation=config._attn_implementation + ) + self.vocab_size = config.text_config.vocab_size + self.pad_token_id = config.pad_token_id + self.llm_dim = config.text_config.hidden_size + + # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing + self.post_init() + + # === `PreTrainedModel` Boilerplate === + def get_input_embeddings(self) -> nn.Module: + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value: nn.Module) -> None: + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings: nn.Module) -> None: + self.language_model.set_output_embeddings(new_embeddings) + + def get_decoder(self) -> nn.Module: + return self.language_model.get_decoder() + + def set_decoder(self, decoder: nn.Module) -> None: + self.language_model.set_decoder(decoder) + + def tie_weights(self) -> None: + self.language_model.tie_weights() 
# Note: `Llama-2` and `Mistral` don't tie weights (no-op) + + def resize_token_embeddings( + self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None + ) -> nn.Embedding: + updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + + # Update config/instance variables + self.config.text_config.vocab_size = updated_embeddings.num_embeddings + self.vocab_size = updated_embeddings.num_embeddings + + return updated_embeddings + + def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features): + """ + Replace embeddings in input_embeddings at positions where all_actions_mask is True + with embeddings from noisy_action_features, using vectorized operations. + + Args: + input_embeddings: Tensor of shape (B, S, D) + all_actions_mask: Boolean tensor of shape (B, S) + noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample + + Returns: + Modified input_embeddings tensor + """ + # Clone input to avoid modifying the original tensor + new_input_embeddings = input_embeddings.clone() + + # Create a tensor with the same shape of input_embeddings to hold the noisy action features + repositioned_noisy_action_features = torch.zeros_like(input_embeddings) + + # Create batch indices for splicing + batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device) + batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1]) + + # Get indices where mask is True for each sample + masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask]) + + # Move the noisy action features into their correct positions + repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features + + # Combine original input embeddings and noisy action embeddings using the mask + new_input_embeddings = torch.where( + all_actions_mask.unsqueeze(-1), 
repositioned_noisy_action_features, new_input_embeddings + ) + + return new_input_embeddings + + def _process_action_masks(self, labels): + """Helper to get action masks from labels""" + current_action_mask = get_current_action_mask(labels) + next_actions_mask = get_next_actions_mask(labels) + all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len) + return all_actions_mask + + def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False): + """Process vision features with optional FiLM conditioning""" + if use_film: + # FiLM: Infuse language inputs into visual features + patch_features = self.vision_backbone(pixel_values, language_embeddings) # (bsz, 256 * num_images, D) + else: + patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D) + + # Project patch embeddings into language embedding space + return self.projector(patch_features) + + def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector): + """Process proprioceptive features and append to vision features""" + if proprio_projector is not None and proprio is not None: + # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim) + # proprio: (bsz, proprio_dim) or (proprio_dim,) + proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1) # (bsz, proprio_dim) + proprio_features = proprio_projector(proprio) # (bsz, llm_dim) + proprio_features = proprio_features.unsqueeze(dim=1) # (bsz, 1, llm_dim) + # For simplicity, just append proprio token to the end of projected vision patch tokens + return torch.cat((projected_patch_embeddings, proprio_features), dim=1) + return projected_patch_embeddings + + def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask): + """Build multimodal embeddings and attention mask""" + # Update attention mask + projected_patch_attention_mask = None + if attention_mask is not None: + projected_patch_attention_mask = 
torch.full( + (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]), + fill_value=True, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + # Build multimodal embeddings & attention mask; insert embeddings after token (1:) + multimodal_embeddings = torch.cat( + [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1 + ) + + multimodal_attention_mask = None + if attention_mask is not None: + multimodal_attention_mask = torch.cat( + [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1 + ) + + return multimodal_embeddings, multimodal_attention_mask + + def _build_multimodal_labels(self, labels, projected_patch_embeddings): + """Build multimodal labels with IGNORE_INDEX for patch embeddings""" + if labels is not None: + projected_patch_labels = torch.full( + (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]), + fill_value=IGNORE_INDEX, + dtype=labels.dtype, + device=labels.device, + ) + return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1) + return None + + # === Core Prismatic VLM `forward()` Logic === + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_projector_features: Optional[bool] = None, + return_dict: Optional[bool] = None, + proprio=None, + proprio_projector=None, + noisy_actions=None, + noisy_action_projector=None, + diffusion_timestep_embeddings=None, + use_film: bool = False, + ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]: + """Run a forward pass through the VLM, returning a 
PrismaticCausalLMOutputWithPast instance.""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_projector_features = output_projector_features if output_projector_features is not None else False + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off) + use_cache = use_cache and not self.training + + # Instantiate Placeholder for Projector Features + projected_patch_embeddings = None + + # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_key_values` === + if input_ids.shape[1] == 1: + assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!" + assert past_key_values is not None, "You must provide `past_key_values` during cached generation!" + assert labels is None, "Unexpected key `labels` provided during cached generation!" + + language_model_output = self.language_model( + input_ids=input_ids, + attention_mask=None, + position_ids=None, + past_key_values=past_key_values, + inputs_embeds=None, + labels=None, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # === Handle Unimodal Forward === + elif pixel_values is None: + assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!" + assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!" 
+ + language_model_output = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # === Handle Multimodal Forward === + elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]): + assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!" + + # Get input embeddings (from language model embeddings) + input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D) + + # Extract action masks + all_actions_mask = self._process_action_masks(labels) + + # Extract the language portion of the input embeddings (i.e. remove the action tokens portion) + language_embeddings = input_embeddings[~all_actions_mask].reshape( + input_embeddings.shape[0], -1, input_embeddings.shape[2] + ) # (B, lang_seq_len, llm_dim) + + # Get visual features + projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film) + + # Add proprioceptive state if provided + projected_patch_embeddings = self._process_proprio_features( + projected_patch_embeddings, proprio, proprio_projector + ) + + # [Diffusion] Add diffusion timestep embedding if provided + if diffusion_timestep_embeddings is not None: + # For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens + projected_patch_embeddings = torch.cat( + (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1 + ) + + # Process action embeddings + if noisy_actions is not None: + # Get mask corresponding to all action tokens + all_actions_mask = self._process_action_masks(labels) + + # Reshape noisy actions into individual action tokens + # noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1) + B 
= noisy_actions.shape[0] + noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1) + + # Project noisy action tokens into language model embedding space + noisy_action_features = noisy_action_projector(noisy_actions) # (B, chunk_len * action_dim, llm_dim) + + # Replace embeddings of the action tokens with noisy action embeddings + input_embeddings = self._replace_input_embeddings( + input_embeddings, all_actions_mask, noisy_action_features + ) + else: + # Replace the embeddings of the action tokens with zeros + # (Later on, the positional embeddings will be added to them) + all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1) + input_embeddings = input_embeddings * ~all_actions_mask + + # Build multimodal embeddings & attention mask + multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention( + input_embeddings, projected_patch_embeddings, attention_mask + ) + + # Build labels for multimodal sequence if needed + multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings) + + # Dispatch to language model + language_model_output = self.language_model( + input_ids=None, + attention_mask=multimodal_attention_mask, + position_ids=None, + past_key_values=None, + inputs_embeds=multimodal_embeddings, + labels=multimodal_labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # === Otherwise =>> Assume Invalid! 
=== + elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]): + raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!") + + else: + raise ValueError( + "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n" + f"=> `input_ids` = {input_ids is not None}\n" + f"=> `attention_mask` = {attention_mask is not None}\n" + f"=> `pixel_values` = {pixel_values is not None}\n" + f"=> `labels` = {labels is not None}\n" + f"=> `input_embeds` = {inputs_embeds is not None}\n" + f"=> `past_key_values` = {past_key_values is not None}\n" + f"=> `use_cache` = {use_cache}" + ) + + # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`) + if not return_dict: + if output_projector_features and (projected_patch_embeddings is not None): + return *language_model_output, projected_patch_embeddings + + return language_model_output + + return PrismaticCausalLMOutputWithPast( + loss=language_model_output.loss, + logits=language_model_output.logits, + past_key_values=language_model_output.past_key_values, + hidden_states=language_model_output.hidden_states, + attentions=language_model_output.attentions, + projector_features=projected_patch_embeddings, + ) + + # === GenerationMixin Methods === + def prepare_inputs_for_generation( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: str, + ) -> Dict[str, torch.Tensor]: + """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic.""" + if ((input_ids is not None) and (input_ids.shape[0] > 1)) or ( + (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1) + ): + raise ValueError("Generation 
with batch size > 1 is not currently supported!") + + # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + # If `inputs_embeds` are passed, we only want to use them in the 1st generation step + # (the key must match the `inputs_embeds` argument name of `forward()`) + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + # Make sure `pixel_values` are preserved in `model_inputs` + model_inputs.update( + { + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + } + ) + + return model_inputs + + # Defer to Language Model (all handle this differently, with different return types) + def _reorder_cache(self, *args, **kwargs) -> Any: + return self.language_model._reorder_cache(*args, **kwargs) + + +class OpenVLAForActionPrediction(PrismaticForConditionalGeneration): + config_class: PretrainedConfig = OpenVLAConfig + + def __init__(self, config: OpenVLAConfig) -> None: + super().__init__(config) + self.norm_stats = config.norm_stats + + # Compute action bins + self.bins = np.linspace(-1, 1, config.n_action_bins) + self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0 + + # Compute vocab size for de-tokenization -- revert added "multiple of" + self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of + + def _prepare_input_for_action_prediction(self, input_ids, attention_mask): + """Prepares input for action prediction by adding necessary tokens""" + # Add (ACTION_DIM * NUM_ACTIONS_CHUNK) placeholder tokens to input_ids to simulate action tokens + placeholder_action_token_ids = ( + torch.ones((input_ids.shape[0], ACTION_DIM * NUM_ACTIONS_CHUNK)).to(input_ids.device).to(input_ids.dtype) + ) + input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1) + + # Add stop token to sequence (needed in non-causal 
bi-directional self-attention, as it appears at train time) + stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX + input_ids = torch.cat([input_ids, stop_token_id], dim=-1) + + # Extend the attention mask to fit the new shape of input + # Note: Only batch size == 1 supported right now + mask_extension = ( + torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1])) + .to(attention_mask.device) + .to(attention_mask.dtype) + ) + attention_mask = torch.cat([attention_mask, mask_extension], dim=-1) + + return input_ids, attention_mask + + def _prepare_labels_for_action_prediction(self, labels, input_ids): + """Creates labels tensor for action prediction if not provided""" + # Extend labels tensor with fake action labels + ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1 + labels_extension = ( + torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype) + * ARBITRARY_ACTION_TOKEN_IDX + ) + labels = torch.cat([labels, labels_extension], dim=-1) + + # Replace last label token with stop token + labels[:, -1] = STOP_INDEX + + return labels + + def _unnormalize_actions(self, normalized_actions, unnorm_key=None): + """Unnormalize actions using dataset statistics""" + action_norm_stats = self.get_action_stats(unnorm_key) + + if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS: + mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool)) + action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"]) + elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99: + mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool)) + action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"]) + else: + raise ValueError("Unsupported action/proprio normalization type detected!") + + actions = np.where( + 
mask, + 0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low, + normalized_actions, + ) + + return actions + + def _run_diffusion_prediction( + self, + input_embeddings, + all_actions_mask, + noise, + action_head, + projected_patch_embeddings, + labels, + attention_mask, + NUM_PATCHES, + NUM_PROMPT_TOKENS, + noisy_action_projector, + ): + """Run diffusion-based action prediction""" + # Clone embedding for reuse in each timestep + orig_projected_patch_embeddings = projected_patch_embeddings.clone() + curr_noisy_actions = noise + + # Reverse diffusion: Iteratively denoise to generate action prediction + for t in action_head.noise_scheduler.timesteps: + # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action + # embedding, and diffusion timestep embedding) + timesteps = torch.Tensor([t]).to(labels.device) + diffusion_timestep_embeddings = ( + action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device) + ) # (B, llm_dim) + diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim) + + # [Diffusion] Replace the embeddings of the action tokens with noisy actions + # (Later on, the positional embeddings will be added to them) + + # For simplicity, append diffusion timestep embedding to the end of projected vision tokens + projected_patch_embeddings = torch.cat( + (orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1 + ) + + # Reshape and project noisy actions into language embedding space + B = curr_noisy_actions.shape[0] + orig_curr_noisy_actions_shape = curr_noisy_actions.shape + curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1) + noisy_action_features = noisy_action_projector(curr_noisy_actions) + curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape) + + # Replace action token embeddings with noisy action embeddings + input_embeddings = self._replace_input_embeddings( + 
input_embeddings.clone(), all_actions_mask, noisy_action_features + ) + + # Build multimodal embeddings and attention mask + multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention( + input_embeddings, projected_patch_embeddings, attention_mask + ) + + # Forward pass through language model + language_model_output = self.language_model( + input_ids=None, + attention_mask=multimodal_attention_mask, + position_ids=None, + past_key_values=None, + inputs_embeds=multimodal_embeddings, + labels=None, + use_cache=None, + output_attentions=False, + output_hidden_states=True, + return_dict=True, + ) + + # Extract hidden states for action portion of response + last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D) + actions_hidden_states = last_hidden_states[ + :, + NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK, + :, + ] # (B, act_chunk_len, D) + + # Predict noise and update noisy actions: x_t -> x_{t-1} + noise_pred = action_head.predict_noise(actions_hidden_states) + curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample + + curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM) + + # Return final actions + return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states + + def _regression_or_discrete_prediction( + self, + input_embeddings, + all_actions_mask, + projected_patch_embeddings, + attention_mask, + labels, + NUM_PATCHES, + NUM_PROMPT_TOKENS, + action_head=None, + ): + """Run L1 regression-based continuous action prediction or discrete action tokens prediction.""" + # Zero out action token embeddings + all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1) + input_embeddings = input_embeddings * ~all_actions_mask + + # Build multimodal embeddings and attention mask + multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention( + 
input_embeddings, projected_patch_embeddings, attention_mask + ) + + # Forward pass through language model + language_model_output = self.language_model( + input_ids=None, + attention_mask=multimodal_attention_mask, + position_ids=None, + past_key_values=None, + inputs_embeds=multimodal_embeddings, + labels=None, + use_cache=None, + output_attentions=False, + output_hidden_states=True, + return_dict=True, + ) + + # Extract hidden states for action tokens + last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D) + actions_hidden_states = last_hidden_states[ + :, + NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK, + :, + ] # (B, act_chunk_len, D) + + # Handle different prediction methods + if action_head is not None: + # L1 regression prediction + normalized_actions = action_head.predict_action(actions_hidden_states) + normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM) + normalized_actions = normalized_actions.float().cpu().detach().numpy() + else: + # Discrete token-based prediction + predicted_action_token_ids = ( + language_model_output.logits[ + :, + NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK, + ] + .argmax(dim=2) + .cpu() + .numpy() + ) + discretized_actions = self.vocab_size - predicted_action_token_ids + discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1) + normalized_actions = self.bin_centers[discretized_actions] + normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM) + + return normalized_actions, actions_hidden_states + + def predict_action( + self, + input_ids: Optional[torch.LongTensor] = None, + unnorm_key: Optional[str] = None, + proprio=None, + proprio_projector=None, + action_head=None, + noisy_action_projector=None, + use_film: bool = False, + **kwargs: str, + ) -> np.ndarray: + """Predict actions from input sequence, with 
options for different prediction methods. + + Args: + input_ids: Input token ids + unnorm_key: Key for unnormalization statistics + proprio: Proprioceptive features + proprio_projector: Projector for proprioceptive features + action_head: Optional head for L1 regression or diffusion-based prediction + noisy_action_projector: Projector for noisy actions in diffusion-based prediction + use_film: Whether to use FiLM conditioning + **kwargs: Additional arguments including pixel_values and attention_mask + + Returns: + Tuple of (unnormalized_actions, action_hidden_states) + """ + # If the special empty token ('') does not already appear after the colon (':') token in the prompt + # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time + if not torch.all(input_ids[:, -1] == 29871): + input_ids = torch.cat( + (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1 + ) + + pixel_values = kwargs["pixel_values"] + attention_mask = kwargs["attention_mask"] + + # Create fake labels tensor (needed for action mask) + labels = input_ids.clone() + labels[:] = IGNORE_INDEX + + # Get number of tokens in prompt (excluding the start token) + NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract action tokens and stop token + + # Prepare inputs by adding necessary tokens + input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask) + + # Update labels tensor for action mask computation later + labels = self._prepare_labels_for_action_prediction(labels, input_ids) + + # Get input embeddings and action masks + input_embeddings = self.get_input_embeddings()(input_ids) + all_actions_mask = self._process_action_masks(labels) + + # Extract language embeddings + language_embeddings = input_embeddings[~all_actions_mask].reshape( + input_embeddings.shape[0], -1, input_embeddings.shape[2] + ) + + # Process vision features + projected_patch_embeddings = 
self._process_vision_features(pixel_values, language_embeddings, use_film) + + # Add proprioceptive features if provided + use_proprio = proprio_projector is not None and proprio is not None + if use_proprio: + proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype) + projected_patch_embeddings = self._process_proprio_features( + projected_patch_embeddings, proprio, proprio_projector + ) + + # Use diffusion if provided, otherwise use regression or discrete prediction + use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler") + + # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present) + NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input() + if use_proprio: + NUM_PATCHES += 1 + if use_diffusion: + NUM_PATCHES += 1 + + if use_diffusion: + # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion + noise = torch.randn( + size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype + ) + + # Run diffusion-based prediction + normalized_actions, actions_hidden_states = self._run_diffusion_prediction( + input_embeddings, + all_actions_mask, + noise, + action_head, + projected_patch_embeddings, + labels, + attention_mask, + NUM_PATCHES, + NUM_PROMPT_TOKENS, + noisy_action_projector, + ) + else: + # Run regression or discrete token-based prediction + normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction( + input_embeddings, + all_actions_mask, + projected_patch_embeddings, + attention_mask, + labels, + NUM_PATCHES, + NUM_PROMPT_TOKENS, + action_head, + ) + + # Unnormalize predicted actions + actions = self._unnormalize_actions(normalized_actions, unnorm_key) + + return actions, actions_hidden_states + + @staticmethod + def _check_unnorm_key(norm_stats: Dict[str, Dict[str, 
Any]], unnorm_key: Optional[str]) -> str: + """Validate and resolve the unnormalization key for action statistics""" + if unnorm_key is None: + assert len(norm_stats) == 1, ( + f"Your model was trained on more than one dataset, " + f"please pass a `unnorm_key` from the following options to choose the statistics " + f"used for un-normalizing actions: {norm_stats.keys()}" + ) + unnorm_key = next(iter(norm_stats.keys())) + + assert unnorm_key in norm_stats, ( + f"The `unnorm_key` you chose is not in the set of available dataset statistics, " + f"please choose from: {norm_stats.keys()}" + ) + return unnorm_key + + def get_action_dim(self, unnorm_key: Optional[str] = None) -> int: + """Get the dimensionality of the policy's action space.""" + unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key) + return len(self.norm_stats[unnorm_key]["action"]["min"]) + + def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]: + """Get all the logged statistics for the given dataset.""" + unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key) + return self.norm_stats[unnorm_key]["action"] diff --git a/prismatic/extern/hf/processing_prismatic.py b/prismatic/extern/hf/processing_prismatic.py new file mode 100644 index 0000000000000000000000000000000000000000..c7ae121b87a8aa76ee63ea2cde9a033d264f4d06 --- /dev/null +++ b/prismatic/extern/hf/processing_prismatic.py @@ -0,0 +1,252 @@ +""" +processing_prismatic.py + +HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration +specifies `siglip-224px+7b`. 
+""" + +from typing import Any, ClassVar, List, Optional, Tuple, Union + +import timm.data +import torch +import torchvision.transforms.functional as TVF +from PIL import Image +from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor +from transformers import PreTrainedTokenizerBase +from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from transformers.utils import TensorType + + +# === Image Processing === +def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image: + """Given a PIL.Image, pad to square by adding a symmetric border around the height/width.""" + (w, h), max_wh = image.size, max(image.size) + horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2) + padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad) + + return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant") + + +class PrismaticImageProcessor(ImageProcessingMixin): + model_input_names: ClassVar[List[str]] = ["pixel_values"] + + def __init__( + self, + use_fused_vision_backbone: bool = False, + image_resize_strategy: str = "letterbox", + input_sizes: Optional[List[Tuple[int, int, int]]] = None, + interpolations: Optional[List[str]] = None, + means: Optional[List[Tuple[float, float, float]]] = None, + stds: Optional[List[Tuple[float, float, float]]] = None, + **kwargs: str, + ) -> None: + """ + Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be + created by TIMM, and edited to follow our custom `image_resize_strategy` logic. 
+ @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone + @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox > + @param input_sizes: [TIMM :: `data_cfg`] Input image sizes as a list of (channels, height, width) tuples + @param interpolations: [TIMM :: `data_cfg`] Interpolations as a list of strings (e.g., "bicubic") + @param means: [TIMM :: `data_cfg`] Normalization means as a list of float tuples (two entries if `fused_backbone`) + @param stds: [TIMM :: `data_cfg`] Normalization stds as a list of float tuples (two entries if `fused_backbone`) + """ + self.use_fused_vision_backbone = use_fused_vision_backbone + self.image_resize_strategy = image_resize_strategy + + # Handle `None` default values + input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes + means = [(0.5, 0.5, 0.5)] if means is None else means + stds = [(0.5, 0.5, 0.5)] if stds is None else stds + + # TIMM `data_cfg` Parameters + self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds + + # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values! + self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], [] + self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None + + for idx in range(len(input_sizes)): + transform = timm.data.create_transform( + input_size=self.input_sizes[idx], + interpolation=self.interpolations[idx], + mean=self.means[idx], + std=self.stds[idx], + crop_pct=1.0, # Set to 1.0 to ignore cropping (initial Resize sets `input_size`) + crop_mode="center", # Default crop mode -- no-op when `crop_pct == 1.0` + is_training=False, # No image augmentations when loading the transform! 
+ ) + + # [Validation] Ensure appropriate transform structure, expected sizes + if not ( + isinstance(transform, Compose) + and (len(transform.transforms) == 4) + and isinstance(transform.transforms[0], Resize) + and isinstance(transform.transforms[1], CenterCrop) + and isinstance(transform.transforms[2], ToTensor) + and isinstance(transform.transforms[3], Normalize) + and (transform.transforms[0].size == self.input_sizes[idx][-1]) + and (transform.transforms[1].size == self.input_sizes[idx][-2:]) + ): + raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`") + + # HF Image Processors *must* be JSON-serializable; as such, cannot have torchvision. as an attribute. + # => Instead, we're going to parse the transform and call "torchvision.transforms.functional" (`tvf`) + resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3] + self.tvf_resize_params.append( + { + "size": resize_t.size, + "interpolation": TVF.pil_modes_mapping[resize_t.interpolation], + "max_size": None, + "antialias": True, + } + ) + self.tvf_crop_params.append({"output_size": crop_t.size}) + self.tvf_normalize_params.append( + { + "mean": norm_t.mean.float().numpy().tolist(), + "std": norm_t.std.float().numpy().tolist(), + "inplace": False, + } + ) + self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None + + # Handle Prismatic `image_resize_strategy` + if self.image_resize_strategy == "resize-naive": + self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size) + elif self.image_resize_strategy == "letterbox": + self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]]) + elif self.image_resize_strategy == "resize-crop": + pass + else: + raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!") + + # Dispatch **kwargs to super() + super().__init__(**kwargs) + + def apply_transform(self, img: Image.Image) -> torch.Tensor: + """Apply 
`functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])""" + if self.tvf_do_letterbox: + img = letterbox_pad_transform(img, self.tvf_letterbox_fill) + + # [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side! + imgs_t = [] + for idx in range(len(self.input_sizes)): + img_idx = TVF.resize(img, **self.tvf_resize_params[idx]) + img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx]) + img_idx_t = TVF.to_tensor(img_idx) + img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx]) + imgs_t.append(img_idx_t) + + # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0 + img_t = torch.vstack(imgs_t) + + return img_t + + def preprocess( + self, + images: Union[Image.Image, List[Image.Image]], + return_tensors: Optional[Union[str, TensorType]] = None, + **_: str, + ) -> BatchFeature: + """ + Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor` we + explicitly only handle PIL.Image.Image instances for simplicity. + @param images: A (batch of) PIL.Image.Image instance(s) to preprocess. 
+ @param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray + @return: Instance of `transformers :: BatchFeature` with a single key "pixel_values" + """ + if not isinstance(images, list): + images = [images] + + # Apply `self.apply_transform` to each image (will return list of torch.Tensors); stack into "batched" Tensor + pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images]) + + # Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert + return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors) + + def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature: + return self.preprocess(images, **kwargs) + + +# === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer === +# =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py +class PrismaticProcessor(ProcessorMixin): + attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"] + image_processor_class: str = "AutoImageProcessor" + tokenizer_class: str = "AutoTokenizer" + + def __init__( + self, + image_processor: Optional[ImageProcessingMixin] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + ) -> None: + super().__init__(image_processor, tokenizer) + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], + images: Union[Image.Image, List[Image.Image]], + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Optional[Union[bool, str, TruncationStrategy]] = None, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + ) -> BatchFeature: + """ + Preprocess a given (batch of) text/images for a Prismatic VLM; forwards text to the underlying LLM's tokenizer, + forwards images to PrismaticImageProcessor.
+ @param text: The (batch) of text to encode; must be a string or list of strings. + @param images: A (batch of) PIL.Image.Image instance(s) to preprocess. + @param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False > + @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified + @param max_length: Maximum length (in tokens) to truncate + @param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH) + @return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`. + """ + pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] + text_inputs = self.tokenizer( + text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length + ) + + # [Validate] Need same number of images and text inputs! + if pixel_values.shape[0] != text_inputs.input_ids.shape[0]: + raise ValueError("Batch is malformed; expected same number of images and text inputs!") + + return BatchFeature(data={**text_inputs, "pixel_values": pixel_values}) + + # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation === + def batch_decode( + self, + sequences: Union[List[int], List[List[int]], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = None, + **kwargs: str, + ) -> List[str]: + return self.tokenizer.batch_decode( + sequences=sequences, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + def decode( + self, + token_ids: Union[int, List[int], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = None, + **kwargs: str, + ) -> str: + return self.tokenizer.decode( + token_ids=token_ids, + skip_special_tokens=skip_special_tokens, + 
clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def model_input_names(self) -> List[str]: + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/prismatic/models/__init__.py b/prismatic/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..073c620e0e23c09e05ac0a4877458777fc981f22 --- /dev/null +++ b/prismatic/models/__init__.py @@ -0,0 +1,2 @@ +from .load import available_model_names, available_models, get_model_description, load, load_vla +from .materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform, get_vlm diff --git a/prismatic/models/action_heads.py b/prismatic/models/action_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..0f5d63b8f60fddf59fe40f93395d99ee829746a0 --- /dev/null +++ b/prismatic/models/action_heads.py @@ -0,0 +1,211 @@ +"""Implementations of various action heads, which serve as alternatives to VLM sequential token prediction.""" + +import math + +import numpy as np +import torch +import torch.nn as nn +from diffusers.schedulers.scheduling_ddim import DDIMScheduler +from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX + + +class SinusoidalPositionalEncoding(nn.Module): + """ + Sine- and cosine-based positional encoding that produces embeddings of a batch of timesteps. 
+ + For example, at train time, the input might be a batch of 32 randomly sampled diffusion timesteps -> shape (32,) + Then the output would be a batch of 32 timestep embeddings -> shape (32, D) + + Adapted from: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/positional_embedding.py + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim # dimensionality of the positional encoding + + def forward(self, x): + # x: (batch_size,) + device = x.device + assert self.dim % 2 == 0, f"# dimensions must be even but got {self.dim}" + half_dim = self.dim // 2 + exponent = torch.arange(half_dim, device=device) * -math.log(10000) / (half_dim - 1) # shape: (D/2,) + emb = torch.exp(exponent) # shape: (D/2,) + emb = x[:, None] * emb[None, :] # shape: (batch_size, 1) * (1, D/2) -> (batch_size, D/2) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) # shape: (batch_size, D) + return emb + + +class MLPResNetBlock(nn.Module): + """One MLP ResNet block with a residual connection.""" + def __init__(self, dim): + super().__init__() + self.dim = dim + self.ffn = nn.Sequential( # feedforward network, similar to the ones in Transformers + nn.LayerNorm(dim), + nn.Linear(dim, dim), + nn.ReLU(), + ) + + def forward(self, x): + # x: (batch_size, hidden_dim) + # We follow the module ordering of "Pre-Layer Normalization" feedforward networks in Transformers as + # described here: https://arxiv.org/pdf/2002.04745.pdf + identity = x + x = self.ffn(x) + x = x + identity + return x + + +class MLPResNet(nn.Module): + """MLP with residual connection blocks.""" + def __init__(self, num_blocks, input_dim, hidden_dim, output_dim): + super().__init__() + self.layer_norm1 = nn.LayerNorm(input_dim) + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.relu = nn.ReLU() + self.mlp_resnet_blocks = nn.ModuleList() + for _ in range(num_blocks): + self.mlp_resnet_blocks.append(MLPResNetBlock(dim=hidden_dim)) + self.layer_norm2 = nn.LayerNorm(hidden_dim) + 
self.fc2 = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + # x: (batch_size, input_dim) + x = self.layer_norm1(x) # shape: (batch_size, input_dim) + x = self.fc1(x) # shape: (batch_size, hidden_dim) + x = self.relu(x) # shape: (batch_size, hidden_dim) + for block in self.mlp_resnet_blocks: + x = block(x) # shape: (batch_size, hidden_dim) + x = self.layer_norm2(x) # shape: (batch_size, hidden_dim) + x = self.fc2(x) # shape: (batch_size, output_dim) + return x + + +class L1RegressionActionHead(nn.Module): + """Simple MLP-based action head that generates continuous actions via L1 regression.""" + def __init__( + self, + input_dim=4096, + hidden_dim=4096, + action_dim=7, + ): + super().__init__() + self.action_dim = action_dim + self.model = MLPResNet( + num_blocks=2, input_dim=input_dim*ACTION_DIM, hidden_dim=hidden_dim, output_dim=action_dim + ) + + def predict_action(self, actions_hidden_states): + # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence + # - shape: (batch_size, chunk_len * action_dim, hidden_dim) + # ground_truth_actions: ground-truth actions + # - shape: (batch_size, chunk_len, action_dim) + batch_size = actions_hidden_states.shape[0] + device = actions_hidden_states.device + rearranged_actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1) + action = self.model(rearranged_actions_hidden_states) + return action + + +class NoisePredictionModel(nn.Module): + """ + Diffusion noise prediction model that takes an observation embedding (which fuses the + noisy action, diffusion timestep, and image-language observation embeddings) and + outputs a noise prediction. 
+ """ + + def __init__( + self, + transformer_hidden_dim, # Transformer hidden embedding size + hidden_dim, # MLP hidden size + action_dim=7, # action dimensionality + ): + super().__init__() + self.mlp_resnet = MLPResNet( + num_blocks=2, + input_dim=transformer_hidden_dim, + hidden_dim=hidden_dim, + output_dim=action_dim, + ) + + def forward( + self, + obs, + ): + # obs: observation embeddings to condition the generation on + # - shape: (batch_size, chunk_len, rearranged_hidden_dim=action_dim*hidden_dim) + # + # output: predicted noise + # - shape: (batch_size, chunk_len, action_dim) + output = self.mlp_resnet(obs) + return output + + +class DiffusionActionHead(nn.Module): + """ + Simple MLP-based action head that generates continuous actions via a conditional denoising diffusion process. + + Loosely inspired by: https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/model/diffusion/transformer_for_diffusion.py + """ + + def __init__( + self, + input_dim=4096, + hidden_dim=4096, + action_dim=7, + num_diffusion_steps_train=50, + ): + super().__init__() + self.action_dim = action_dim + self.noise_predictor = NoisePredictionModel( + transformer_hidden_dim=hidden_dim*ACTION_DIM, hidden_dim=hidden_dim, action_dim=action_dim + ) + self.num_diffusion_steps_train = num_diffusion_steps_train + self.noise_scheduler = DDIMScheduler(num_train_timesteps=num_diffusion_steps_train, beta_schedule="squaredcos_cap_v2") + self.time_encoder = SinusoidalPositionalEncoding(dim=hidden_dim) + + def sample_noisy_actions(self, ground_truth_actions): + """ + Samples noise and applies noise to ground-truth actions to produce noisy actions, which are + used as input in the noise prediction network. Returns noise, noisy actions, and the + corresponding diffusion timestep embeddings.
+ """ + # ground_truth_actions: ground-truth actions + # - shape: (batch_size, chunk_len, action_dim) + batch_size = ground_truth_actions.shape[0] + device = ground_truth_actions.device + # Sample random noise with shape equal to actions, used for closed-form forward diffusion. + noise = torch.randn(size=(batch_size, NUM_ACTIONS_CHUNK, ACTION_DIM), device=device, dtype=ground_truth_actions.dtype) # (B, chunk_len, action_dim) + # Sample random diffusion timesteps (one for each action in batch). + timesteps = torch.randint( + low=0, high=self.noise_scheduler.config.num_train_timesteps, size=(batch_size,), device=device + ) + # Add noise to clean actions according to the magnitude at each diffusion timestep via + # closed-form forward diffusion. + noisy_actions = self.noise_scheduler.add_noise(ground_truth_actions, noise, timesteps) # (B, chunk_len, action_dim) + + # Get diffusion timestep embeddings as well + diffusion_timestep_embeddings = self.time_encoder(timesteps).to(noisy_actions.dtype).to(noisy_actions.device) # (B, llm_dim) + diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim) + + return_dict = dict( + noise=noise, + noisy_actions=noisy_actions, + diffusion_timestep_embeddings=diffusion_timestep_embeddings, + ) + + return return_dict + + def predict_noise(self, actions_hidden_states): + """ + Given a batch of last hidden Transformer layer embeddings (which fuse the vision-language observation embeddings, + noisy action embeddings, and diffusion timestep embedding), predicts the noise applied to the actions. 
+ """ + # actions_hidden_states: last hidden states of Transformer corresponding to action tokens in sequence + # - shape: (batch_size, chunk_len * action_dim, hidden_dim) + batch_size = actions_hidden_states.shape[0] + device = actions_hidden_states.device + rearranged_actions_hidden_states = actions_hidden_states.reshape(batch_size, NUM_ACTIONS_CHUNK, -1) # (batch_size, chunk_len, action_dim * hidden_dim) + # Get diffusion model's noise prediction. + noise_pred = self.noise_predictor(rearranged_actions_hidden_states) + return noise_pred diff --git a/prismatic/models/backbones/__init__.py b/prismatic/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/prismatic/models/backbones/llm/__init__.py b/prismatic/models/backbones/llm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1cdb19953b6529cb4015544c9d0cf8c43a22c2e9 --- /dev/null +++ b/prismatic/models/backbones/llm/__init__.py @@ -0,0 +1,4 @@ +from .base_llm import LLMBackbone +from .llama2 import LLaMa2LLMBackbone +from .mistral import MistralLLMBackbone +from .phi import PhiLLMBackbone diff --git a/prismatic/models/backbones/llm/base_llm.py b/prismatic/models/backbones/llm/base_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..8d5a91913dc851b16684a4b4729f031ee48fce50 --- /dev/null +++ b/prismatic/models/backbones/llm/base_llm.py @@ -0,0 +1,223 @@ +""" +base_llm.py + +Abstract class definition of a large (autoregressive) language model backbone (LLM), with full annotations of class +methods, utility functions, and initialization logic. + +We also define the generic HFLLMBackbone class here, providing a default interface for loading any HF +AutoModelForCausalLM (e.g., LLamaForCausalLM). In general, we make the assumption that any given LLM backbone implements +the AutoModelForCausalLM API (though we may add Seq2Seq models in the future). 
+ +We make this assumption to keep the LLM handling in this codebase relatively lightweight, and to inherit all the nice HF +utilities around different types of decoding/generation strategies. +""" + +import warnings +from abc import ABC, abstractmethod +from functools import partial +from typing import Callable, List, Optional, Sequence, Type + +import torch +import torch.nn as nn +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from transformers import AutoConfig, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase +from transformers.modeling_outputs import CausalLMOutputWithPast + +from prismatic.models.backbones.llm.prompting import PromptBuilder +from prismatic.overwatch import initialize_overwatch + +# Suppress HF Deprecation Warnings +warnings.filterwarnings("ignore", category=FutureWarning) + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +# === Abstract Base Class for arbitrary HF LLM Backbones === +class LLMBackbone(nn.Module, ABC): + def __init__(self, llm_backbone_id: str) -> None: + super().__init__() + self.identifier = llm_backbone_id + + # Instance attributes for an LLM Backbone + self.llm: PreTrainedModel = None + self.tokenizer: PreTrainedTokenizerBase = None + + def get_tokenizer(self) -> PreTrainedTokenizerBase: + return self.tokenizer + + @abstractmethod + def get_fsdp_wrapping_policy(self) -> Callable: ... + + @abstractmethod + def enable_gradient_checkpointing(self) -> None: ... 
+ + @abstractmethod + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> CausalLMOutputWithPast: + """Run a forward pass through the LLM given targets (labels), returning the scalar Cross-Entropy Loss""" + raise NotImplementedError + + @abstractmethod + def embed_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor: ... + + @property + @abstractmethod + def prompt_builder_fn(self) -> Type[PromptBuilder]: ... + + @property + @abstractmethod + def transformer_layer_cls(self) -> Type[nn.Module]: ... + + @property + @abstractmethod + def half_precision_dtype(self) -> torch.dtype: ... + + @property + @abstractmethod + def last_layer_finetune_modules(self) -> Sequence[nn.Module]: ... + + @property + def embed_dim(self) -> int: + return self.llm.config.hidden_size + + @property + def pad_token_id(self) -> int: + return self.tokenizer.pad_token_id + + +# === Abstract Base Class for Arbitrary HF Causal LLMs === +class HFCausalLLMBackbone(LLMBackbone, ABC): + def __init__( + self, + llm_backbone_id: str, + llm_family: str, + llm_cls: Type[PreTrainedModel], + hf_hub_path: str, + llm_max_length: int = 2048, + hf_token: Optional[str] = None, + inference_mode: bool = False, + use_flash_attention_2: bool = False, + ) -> None: + super().__init__(llm_backbone_id) + self.llm_family = llm_family + self.llm_max_length = llm_max_length + self.inference_mode = inference_mode + + # Initialize LLM (downloading from HF Hub if necessary) --> `llm_cls` is the actual {Model}ForCausalLM class! 
+ # => Note: We're eschewing use of the AutoModel API so that we can be more explicit about LLM-specific details + if not self.inference_mode: + overwatch.info(f"Loading [bold]{llm_family}[/] LLM from [underline]`{hf_hub_path}`[/]", ctx_level=1) + self.llm = llm_cls.from_pretrained( + hf_hub_path, + token=hf_token, + use_flash_attention_2=use_flash_attention_2 if not self.inference_mode else False, + # The following parameters are set to prevent `UserWarnings` from HF; we want greedy decoding! + do_sample=False, + temperature=1.0, + top_p=1.0, + ) + + # [Contract] `inference_mode` means we're loading from a pretrained checkpoint; no need to load base weights! + else: + overwatch.info(f"Building empty [bold]{llm_family}[/] LLM from [underline]`{hf_hub_path}`[/]", ctx_level=1) + llm_config = AutoConfig.from_pretrained(hf_hub_path, token=hf_token) + self.llm = llm_cls._from_config(llm_config) + + # Lightweight Handling (with extended explanation) for setting some LLM Parameters + # => Set `decoder.use_cache = False` --> incompatible with gradient checkpointing (+ training in general) + # + # Reference: https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958 + self.llm.config.use_cache = False if not self.inference_mode else True + + # => Turns out that when gradient checkpointing is on and the underlying LLM has no "trainable" parameters + # (requires_grad is False), backprop will fail; calling `enable_input_require_grads()` registers a new + # forward hook that fixes this =>> also totally safe for the "full finetuning" setting!
+ if not self.inference_mode: + self.llm.enable_input_require_grads() + + # Load (Fast) Tokenizer + overwatch.info(f"Loading [bold]{llm_family}[/] (Fast) Tokenizer via the AutoTokenizer API", ctx_level=1) + self.tokenizer = AutoTokenizer.from_pretrained( + hf_hub_path, model_max_length=self.llm_max_length, token=hf_token, padding_side="right" + ) + + # Validation =>> Our VLM logic currently operates under the assumption that the tokenization of a new input + # starts with a <BOS> token unless `add_special_tokens = False`; for these models, we empirically + # find that adding image patches *after* the BOS leads to much better performance. + # + # As a result we explicitly validate that a tokenizer conforms to the expected behavior; if you're reading this + # line, it's probably because you're adding a new LLM with a different tokenizer behavior. If so, feel free to + # override the `SPECIAL_CASES` set below, but make sure to make the appropriate changes in the `datasets.py` + # and VLM `forward()` logic! + SPECIAL_CASES = { + # Phi-2 Tokenizer doesn't add any BOS tokens by default, and sets BOS == EOS == "<|endoftext|>" + # =>> We'll prepend BOS to first input (to play nicely with image token insertion logic; verified that + # this works well with base LLM generation). + # =>> Like Llama-2 Tokenizers -- we'll add a special PAD token for training purposes. + "phi-2-3b", + } + if self.identifier in SPECIAL_CASES: + return + + # Note =>> this assert should hold for all Llama-derived tokenizers (`LlamaTokenizerFast` ==> includes Mistral)! + assert (self.tokenizer("Test 123", add_special_tokens=True).input_ids[0] == self.tokenizer.bos_token_id) and ( + self.tokenizer("Test 123", add_special_tokens=False).input_ids[0] != self.tokenizer.bos_token_id + ), ( + f"Default Tokenizer of type `{type(self.tokenizer)}` does not automatically prefix inputs with BOS token!\n" + "Please read the comment in `base_llm.py` for more information!"
+ ) + + def get_fsdp_wrapping_policy(self) -> Callable: + """Return a `transformer_auto_wrap_policy` where we wrap each instance of `self.transformer_layer_cls`""" + transformer_block_policy = partial( + transformer_auto_wrap_policy, transformer_layer_cls={self.transformer_layer_cls} + ) + + return transformer_block_policy + + def enable_gradient_checkpointing(self) -> None: + """Dispatch to underlying LLM instance's `gradient_checkpointing_enable`; defined for all `PretrainedModel`.""" + self.llm.gradient_checkpointing_enable() + + def embed_input_ids(self, input_ids: torch.LongTensor) -> torch.Tensor: + return self.llm.get_input_embeddings()(input_ids) + + # [Contract] Should match the `forward` call of the underlying `llm` instance! + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> CausalLMOutputWithPast: + output: CausalLMOutputWithPast = self.llm( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return output diff --git a/prismatic/models/backbones/llm/llama2.py b/prismatic/models/backbones/llm/llama2.py new file mode 100644 index 0000000000000000000000000000000000000000..60f1f5756b564a67693870399d18d574eedbfbd9 --- /dev/null +++ b/prismatic/models/backbones/llm/llama2.py @@ -0,0 +1,102 @@ +""" +llama2.py + +Class definition for all LLMs derived from 
LlamaForCausalLM. +""" + +from typing import Optional, Sequence, Type + +import torch +from torch import nn as nn +from transformers import LlamaForCausalLM +from transformers.models.llama.modeling_llama import LlamaDecoderLayer + +from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone +from prismatic.models.backbones.llm.prompting import ( + LLaMa2ChatPromptBuilder, + PromptBuilder, + PurePromptBuilder, + VicunaV15ChatPromptBuilder, +) + +# Registry =>> Support LLaMa-2 Models (from HF Transformers) +# fmt: off +LLAMA2_MODELS = { + # === Pure Meta LLaMa-2 (non-instruct/chat-tuned) Models === + "llama2-7b-pure": { + "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-hf" + }, + + "llama2-13b-pure": { + "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-hf" + }, + + # === Meta LLaMa-2 Chat Models === + "llama2-7b-chat": { + "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-7b-chat-hf" + }, + + "llama2-13b-chat": { + "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "meta-llama/Llama-2-13b-chat-hf" + }, + + # === Vicuna v1.5 Chat Models === + "vicuna-v15-7b": { + "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-7b-v1.5" + }, + + "vicuna-v15-13b": { + "llm_family": "llama2", "llm_cls": LlamaForCausalLM, "hf_hub_path": "lmsys/vicuna-13b-v1.5" + }, +} +# fmt: on + + +class LLaMa2LLMBackbone(HFCausalLLMBackbone): + def __init__( + self, + llm_backbone_id: str, + llm_max_length: int = 2048, + hf_token: Optional[str] = None, + inference_mode: bool = False, + use_flash_attention_2: bool = True, + ) -> None: + super().__init__( + llm_backbone_id, + llm_max_length=llm_max_length, + hf_token=hf_token, + inference_mode=inference_mode, + use_flash_attention_2=use_flash_attention_2, + **LLAMA2_MODELS[llm_backbone_id], + ) + + # [Special Case] LLaMa-2 PAD Token Handling --> for clarity, we add an 
extra token (and resize) + self.tokenizer.add_special_tokens({"pad_token": "<PAD>"}) + self.llm.config.pad_token_id = self.tokenizer.pad_token_id + self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) + + @property + def prompt_builder_fn(self) -> Type[PromptBuilder]: + if self.identifier.startswith("llama2-") and self.identifier.endswith("-pure"): + return PurePromptBuilder + + elif self.identifier.startswith("llama2-") and self.identifier.endswith("-chat"): + return LLaMa2ChatPromptBuilder + + elif self.identifier.startswith("vicuna"): + return VicunaV15ChatPromptBuilder + + raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") + + @property + def transformer_layer_cls(self) -> Type[nn.Module]: + return LlamaDecoderLayer + + @property + def half_precision_dtype(self) -> torch.dtype: + """LLaMa-2 was trained in BF16; see https://huggingface.co/docs/transformers/main/model_doc/llama2.""" + return torch.bfloat16 + + @property + def last_layer_finetune_modules(self) -> Sequence[nn.Module]: + return (self.llm.model.embed_tokens, self.llm.model.layers[-1], self.llm.lm_head) diff --git a/prismatic/models/backbones/llm/mistral.py b/prismatic/models/backbones/llm/mistral.py new file mode 100644 index 0000000000000000000000000000000000000000..9bc25553db6f9d2cbd292c98615b67c2105345f4 --- /dev/null +++ b/prismatic/models/backbones/llm/mistral.py @@ -0,0 +1,72 @@ +""" +mistral.py + +Class definition for all LLMs derived from MistralForCausalLM.
+""" + +from typing import Optional, Type + +import torch +from torch import nn as nn +from transformers import MistralForCausalLM +from transformers.models.mistral.modeling_mistral import MistralDecoderLayer + +from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone +from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder, PromptBuilder, PurePromptBuilder + +# Registry =>> Support Mistral Models (from HF Transformers) +# fmt: off +MISTRAL_MODELS = { + # === Base Mistral v0.1 === + "mistral-v0.1-7b-pure": { + "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-v0.1" + }, + + # === Mistral Instruct v0.1 === + "mistral-v0.1-7b-instruct": { + "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-Instruct-v0.1" + } +} +# fmt: on + + +class MistralLLMBackbone(HFCausalLLMBackbone): + def __init__( + self, + llm_backbone_id: str, + llm_max_length: int = 2048, + hf_token: Optional[str] = None, + inference_mode: bool = False, + use_flash_attention_2: bool = True, + ) -> None: + super().__init__( + llm_backbone_id, + llm_max_length=llm_max_length, + hf_token=hf_token, + inference_mode=inference_mode, + use_flash_attention_2=use_flash_attention_2, + **MISTRAL_MODELS[llm_backbone_id], + ) + + # [Special Case] Mistral PAD Token Handling --> for clarity, we add an extra token (and resize) + self.tokenizer.add_special_tokens({"pad_token": "<PAD>"}) + self.llm.config.pad_token_id = self.tokenizer.pad_token_id + self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) + + @property + def prompt_builder_fn(self) -> Type[PromptBuilder]: + if self.identifier.endswith("-pure"): + return PurePromptBuilder + + elif self.identifier.endswith("-instruct"): + return MistralInstructPromptBuilder + + raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") + + @property + def transformer_layer_cls(self) -> Type[nn.Module]: + return
MistralDecoderLayer + + @property + def half_precision_dtype(self) -> torch.dtype: + return torch.bfloat16 diff --git a/prismatic/models/backbones/llm/phi.py b/prismatic/models/backbones/llm/phi.py new file mode 100644 index 0000000000000000000000000000000000000000..27bb8b79b0e384af02071a21589050e41aebb4ae --- /dev/null +++ b/prismatic/models/backbones/llm/phi.py @@ -0,0 +1,64 @@ +""" +phi.py + +Class definition for all LLMs derived from PhiForCausalLM. +""" + +from typing import Optional, Type + +import torch +from torch import nn as nn +from transformers import PhiForCausalLM +from transformers.models.phi.modeling_phi import PhiDecoderLayer + +from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone +from prismatic.models.backbones.llm.prompting import PhiPromptBuilder, PromptBuilder + +# Registry ==> Support Phi Models (from HF Transformers) +# fmt: off +PHI_MODELS = { + # === Phi-2 === + "phi-2-3b": { + "llm_family": "phi", "llm_cls": PhiForCausalLM, "hf_hub_path": "microsoft/phi-2" + } +} +# fmt: on + + +class PhiLLMBackbone(HFCausalLLMBackbone): + def __init__( + self, + llm_backbone_id: str, + llm_max_length: int = 2048, + hf_token: Optional[str] = None, + inference_mode: bool = False, + use_flash_attention_2: bool = True, + ) -> None: + super().__init__( + llm_backbone_id, + llm_max_length=llm_max_length, + hf_token=hf_token, + inference_mode=inference_mode, + use_flash_attention_2=use_flash_attention_2, + **PHI_MODELS[llm_backbone_id], + ) + + # [Special Case] Phi PAD Token Handling --> for clarity, we add an extra token (and resize) + self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"}) + self.llm.config.pad_token_id = self.tokenizer.pad_token_id + self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) + + @property + def prompt_builder_fn(self) -> Type[PromptBuilder]: + if self.identifier.startswith("phi-2"): + return PhiPromptBuilder + + raise ValueError(f"No PromptBuilder defined for LLM Backbone 
`{self.identifier}`") + + @property + def transformer_layer_cls(self) -> Type[nn.Module]: + return PhiDecoderLayer + + @property + def half_precision_dtype(self) -> torch.dtype: + return torch.bfloat16 diff --git a/prismatic/models/backbones/llm/prompting/__init__.py b/prismatic/models/backbones/llm/prompting/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8292789c29f76344bf400f844cb946e8a4081a98 --- /dev/null +++ b/prismatic/models/backbones/llm/prompting/__init__.py @@ -0,0 +1,5 @@ +from .base_prompter import PromptBuilder, PurePromptBuilder +from .llama2_chat_prompter import LLaMa2ChatPromptBuilder +from .mistral_instruct_prompter import MistralInstructPromptBuilder +from .phi_prompter import PhiPromptBuilder +from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder diff --git a/prismatic/models/backbones/llm/prompting/base_prompter.py b/prismatic/models/backbones/llm/prompting/base_prompter.py new file mode 100644 index 0000000000000000000000000000000000000000..3816b1831177d128bbe872c041ed5b0ffaf5b5ed --- /dev/null +++ b/prismatic/models/backbones/llm/prompting/base_prompter.py @@ -0,0 +1,73 @@ +""" +base_prompter.py + +Abstract class definition of a multi-turn prompt builder for ensuring consistent formatting for chat-based LLMs. +""" + +from abc import ABC, abstractmethod +from typing import Optional + + +class PromptBuilder(ABC): + def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: + self.model_family = model_family + + # Only some models define a system prompt => let subclasses handle this logic! + self.system_prompt = system_prompt + + @abstractmethod + def add_turn(self, role: str, message: str) -> str: ... + + @abstractmethod + def get_potential_prompt(self, user_msg: str) -> None: ... + + @abstractmethod + def get_prompt(self) -> str: ... 
+ + +class PurePromptBuilder(PromptBuilder): + def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: + super().__init__(model_family, system_prompt) + + # TODO (siddk) =>> Can't always assume LlamaTokenizer --> FIX ME! + self.bos, self.eos = "<s>", "</s>" + + # Get role-specific "wrap" functions + self.wrap_human = lambda msg: f"In: {msg}\nOut: " + self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" + + # === `self.prompt` gets built up over multiple turns === + self.prompt, self.turn_count = "", 0 + + def add_turn(self, role: str, message: str) -> str: + assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") + message = message.replace("<image>", "").strip() + + if (self.turn_count % 2) == 0: + human_message = self.wrap_human(message) + wrapped_message = human_message + else: + gpt_message = self.wrap_gpt(message) + wrapped_message = gpt_message + + # Update Prompt + self.prompt += wrapped_message + + # Bump Turn Counter + self.turn_count += 1 + + # Return "wrapped_message" (effective string added to context) + return wrapped_message + + def get_potential_prompt(self, message: str) -> None: + # Assumes that it's always the user's (human's) turn! + prompt_copy = str(self.prompt) + + human_message = self.wrap_human(message) + prompt_copy += human_message + + return prompt_copy.removeprefix(self.bos).rstrip() + + def get_prompt(self) -> str: + # Remove prefix (if exists) because it gets auto-inserted by tokenizer!
+ return self.prompt.removeprefix(self.bos).rstrip() diff --git a/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py b/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py new file mode 100644 index 0000000000000000000000000000000000000000..b094675a6ee550a8d2d99ed35be2d3c4a71e861d --- /dev/null +++ b/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py @@ -0,0 +1,91 @@ +""" +llama2_prompter.py + +Defines a PromptBuilder for building LLaMa-2 Chat Prompts --> not sure if this is "optimal", but this is the pattern +that's used by HF and other online tutorials. + +Reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 +""" + +from typing import Optional + +from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder + +# Default System Prompt for Prismatic Models +SYS_PROMPTS = { + "prismatic": ( + "You are a helpful language and vision assistant. " + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language." + ), + "openvla": ( + "You are a helpful language and vision assistant. " + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language." 
+ ), +} + + +def format_system_prompt(system_prompt: str) -> str: + return f"<<SYS>\n{system_prompt.strip()}\n<</SYS>>\n\n" + + +class LLaMa2ChatPromptBuilder(PromptBuilder): + def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: + super().__init__(model_family, system_prompt) + self.system_prompt = format_system_prompt( + SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt + ) + + # LLaMa-2 Specific + self.bos, self.eos = "<s>", "</s>" + + # Get role-specific "wrap" functions + self.wrap_human = lambda msg: f"[INST] {msg} [/INST] " + self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" + + # === `self.prompt` gets built up over multiple turns === + self.prompt, self.turn_count = "", 0 + + def add_turn(self, role: str, message: str) -> str: + assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") + message = message.replace("<image>", "").strip() + + # Special Handling for "system" prompt (turn_count == 0) + if self.turn_count == 0: + sys_message = self.wrap_human(self.system_prompt + message) + wrapped_message = sys_message + elif (self.turn_count % 2) == 0: + human_message = self.wrap_human(message) + wrapped_message = human_message + else: + gpt_message = self.wrap_gpt(message) + wrapped_message = gpt_message + + # Update Prompt + self.prompt += wrapped_message + + # Bump Turn Counter + self.turn_count += 1 + + # Return "wrapped_message" (effective string added to context) + return wrapped_message + + def get_potential_prompt(self, message: str) -> None: + # Assumes that it's always the user's (human's) turn!
+ prompt_copy = str(self.prompt) + + # Special Handling for "system" prompt (turn_count == 0) + if self.turn_count == 0: + sys_message = self.wrap_human(self.system_prompt + message) + prompt_copy += sys_message + + else: + human_message = self.wrap_human(message) + prompt_copy += human_message + + return prompt_copy.removeprefix(self.bos).rstrip() + + def get_prompt(self) -> str: + # Remove prefix because it gets auto-inserted by tokenizer! + return self.prompt.removeprefix(self.bos).rstrip() diff --git a/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py b/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py new file mode 100644 index 0000000000000000000000000000000000000000..35a5eab8b3caf31ea671727230d1aed29903dde3 --- /dev/null +++ b/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py @@ -0,0 +1,60 @@ +""" +mistral_instruct_prompter.py + +Defines a PromptBuilder for building Mistral Instruct Chat Prompts --> recommended pattern used by HF / online tutorials. + +Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format +""" + +from typing import Optional + +from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder + + +class MistralInstructPromptBuilder(PromptBuilder): + def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: + super().__init__(model_family, system_prompt) + + # Note =>> Mistral Tokenizer is an instance of `LlamaTokenizer(Fast)` + # =>> Mistral Instruct *does not* use a System Prompt + self.bos, self.eos = "<s>", "</s>" + + # Get role-specific "wrap" functions + self.wrap_human = lambda msg: f"[INST] {msg} [/INST] " + self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" + + # === `self.prompt` gets built up over multiple turns === + self.prompt, self.turn_count = "", 0 + + def add_turn(self, role: str, message: str) -> str: + assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") +
message = message.replace("<image>", "").strip() + + if (self.turn_count % 2) == 0: + human_message = self.wrap_human(message) + wrapped_message = human_message + else: + gpt_message = self.wrap_gpt(message) + wrapped_message = gpt_message + + # Update Prompt + self.prompt += wrapped_message + + # Bump Turn Counter + self.turn_count += 1 + + # Return "wrapped_message" (effective string added to context) + return wrapped_message + + def get_potential_prompt(self, message: str) -> None: + # Assumes that it's always the user's (human's) turn! + prompt_copy = str(self.prompt) + + human_message = self.wrap_human(message) + prompt_copy += human_message + + return prompt_copy.removeprefix(self.bos).rstrip() + + def get_prompt(self) -> str: + # Remove prefix because it gets auto-inserted by tokenizer! + return self.prompt.removeprefix(self.bos).rstrip() diff --git a/prismatic/models/backbones/llm/prompting/phi_prompter.py b/prismatic/models/backbones/llm/prompting/phi_prompter.py new file mode 100644 index 0000000000000000000000000000000000000000..b350ea3a6b6c448cc65caa6fa72db3c253dd0b7b --- /dev/null +++ b/prismatic/models/backbones/llm/prompting/phi_prompter.py @@ -0,0 +1,65 @@ +""" +phi_prompter.py + +Defines a PromptBuilder for building Phi-2 Input/Output Prompts --> recommended pattern used by HF / Microsoft. +Also handles Phi special case BOS token additions. + +Reference: https://huggingface.co/microsoft/phi-2#qa-format +""" + +from typing import Optional + +from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder + + +class PhiPromptBuilder(PromptBuilder): + def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: + super().__init__(model_family, system_prompt) + + # Note =>> Phi Tokenizer is an instance of `CodeGenTokenizer(Fast)` + # =>> By default, does *not* append <BOS> / <EOS> tokens --> we handle that here (IMPORTANT)!
+ self.bos, self.eos = "<|endoftext|>", "<|endoftext|>" + + # Get role-specific "wrap" functions + # =>> Note that placement of <BOS>/<EOS> were based on experiments generating from Phi-2 in Input/Output mode + self.wrap_human = lambda msg: f"Input: {msg}\nOutput: " + self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}\n{self.eos}" + + # === `self.prompt` gets built up over multiple turns === + self.prompt, self.turn_count = "", 0 + + def add_turn(self, role: str, message: str) -> str: + assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") + message = message.replace("<image>", "").strip() + + # Special Handling for "first" input --> prepend a <BOS> token (expected by Prismatic) + if self.turn_count == 0: + bos_human_message = f"{self.bos}{self.wrap_human(message)}" + wrapped_message = bos_human_message + elif (self.turn_count % 2) == 0: + human_message = self.wrap_human(message) + wrapped_message = human_message + else: + gpt_message = self.wrap_gpt(message) + wrapped_message = gpt_message + + # Update Prompt + self.prompt += wrapped_message + + # Bump Turn Counter + self.turn_count += 1 + + # Return "wrapped_message" (effective string added to context) + return wrapped_message + + def get_potential_prompt(self, message: str) -> None: + # Assumes that it's always the user's (human's) turn! + prompt_copy = str(self.prompt) + + human_message = self.wrap_human(message) + prompt_copy += human_message + + return prompt_copy.rstrip() + + def get_prompt(self) -> str: + return self.prompt.rstrip() diff --git a/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py b/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py new file mode 100644 index 0000000000000000000000000000000000000000..f5e79e443a791a149cec09a81d55b1a3788eb9de --- /dev/null +++ b/prismatic/models/backbones/llm/prompting/vicuna_v15_prompter.py @@ -0,0 +1,82 @@ +""" +vicuna_v15_prompter.py + +Defines a PromptBuilder for building Vicuna-v1.5 Chat Prompts.
+ +Reference: https://huggingface.co/lmsys/vicuna-13b-v1.5 +""" + +from typing import Optional + +from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder + +# Default System Prompt for LLaVa Models +SYS_PROMPTS = { + "prismatic": ( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), + "openvla": ( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), +} + + +class VicunaV15ChatPromptBuilder(PromptBuilder): + def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: + super().__init__(model_family, system_prompt) + self.system_prompt = (SYS_PROMPTS[self.model_family] if system_prompt is None else system_prompt).strip() + " " + + # LLaMa-2 Specific + self.bos, self.eos = "<s>", "</s>" + + # Get role-specific "wrap" functions + self.wrap_human = lambda msg: f"USER: {msg} ASSISTANT: " + self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}" + + # === `self.prompt` gets built up over multiple turns === + self.prompt, self.turn_count = "", 0 + + def add_turn(self, role: str, message: str) -> str: + assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt") + message = message.replace("<image>", "").strip() + + # Special Handling for "system" prompt (turn_count == 0) + if self.turn_count == 0: + sys_message = self.system_prompt + self.wrap_human(message) + wrapped_message = sys_message + elif (self.turn_count % 2) == 0: + human_message = self.wrap_human(message) + wrapped_message = human_message + else: + gpt_message = self.wrap_gpt(message) + wrapped_message = gpt_message + + # Update Prompt + self.prompt += wrapped_message + + # Bump Turn Counter + self.turn_count += 1 + + # Return "wrapped_message" (effective string added to context) + return wrapped_message + + def
get_potential_prompt(self, message: str) -> None: + # Assumes that it's always the user's (human's) turn! + prompt_copy = str(self.prompt) + + # Special Handling for "system" prompt (turn_count == 0) + if self.turn_count == 0: + sys_message = self.system_prompt + self.wrap_human(message) + prompt_copy += sys_message + + else: + human_message = self.wrap_human(message) + prompt_copy += human_message + + return prompt_copy.removeprefix(self.bos).rstrip() + + def get_prompt(self) -> str: + # Remove prefix (if exists) because it gets auto-inserted by tokenizer! + return self.prompt.removeprefix(self.bos).rstrip() diff --git a/prismatic/models/backbones/vision/__init__.py b/prismatic/models/backbones/vision/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d9f33f8aa95245bdeb84f16b78ebb97b0e2bcd13 --- /dev/null +++ b/prismatic/models/backbones/vision/__init__.py @@ -0,0 +1,7 @@ +from .base_vision import ImageTransform, VisionBackbone +from .clip_vit import CLIPViTBackbone +from .dinoclip_vit import DinoCLIPViTBackbone +from .dinosiglip_vit import DinoSigLIPViTBackbone +from .dinov2_vit import DinoV2ViTBackbone +from .in1k_vit import IN1KViTBackbone +from .siglip_vit import SigLIPViTBackbone diff --git a/prismatic/models/backbones/vision/base_vision.py b/prismatic/models/backbones/vision/base_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..e9ccadee35f08bf81e370824dcab5feffdfffd0c --- /dev/null +++ b/prismatic/models/backbones/vision/base_vision.py @@ -0,0 +1,207 @@ +""" +base_vision.py + +Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility +functions, and initialization logic. + +We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision +Transformer model for feature extraction. 
+""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union + +import timm +import torch +import torch.nn as nn +import torchvision.transforms.functional as TVF +from PIL.Image import Image +from timm.models.vision_transformer import Block, VisionTransformer +from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy +from torchvision.transforms import Compose, Resize + + +# === Utility Functions for Monkey-Patching === +def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]: + def wrapper(*args: Any, **kwargs: Any) -> Any: + result = fn(*args, **kwargs) + return result[0] if isinstance(result, tuple) else result + + return wrapper + + +# === Interface for an Image Transform === +class ImageTransform(Protocol): + def __call__(self, img: Image, **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ... 
+ + +# === Custom Torchvision Image Transforms === +@dataclass +class LetterboxPad: + padding_fill_value: Tuple[int, int, int] + + def __call__(self, image: Image) -> Image: + """Given a PIL.Image, pad to square by adding a symmetric border around the height/width.""" + (w, h), max_wh = image.size, max(image.size) + horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2) + padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad) + return TVF.pad(image, padding, fill=self.padding_fill_value, padding_mode="constant") + + +# === Abstract Base Class for arbitrary Vision Backbones === +class VisionBackbone(nn.Module, ABC): + def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: + super().__init__() + self.identifier: str = vision_backbone_id + self.image_resize_strategy: str = image_resize_strategy + self.default_image_size: int = default_image_size + + # Instance attributes for a Vision Backbone + self.featurizer: nn.Module = None + self.image_transform: ImageTransform = None + + def get_image_transform(self) -> ImageTransform: + return self.image_transform + + @abstractmethod + def get_fsdp_wrapping_policy(self) -> Callable: ... + + @abstractmethod + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features.""" + raise NotImplementedError + + @property + @abstractmethod + def default_image_resolution(self) -> Tuple[int, int, int]: ... + + @property + @abstractmethod + def embed_dim(self) -> int: ... + + @property + @abstractmethod + def num_patches(self) -> int: ... + + @property + @abstractmethod + def half_precision_dtype(self) -> torch.dtype: ... 
+ + +# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones === +class TimmViTBackbone(VisionBackbone, ABC): + def __init__( + self, + vision_backbone_id: str, + timm_path_or_url: str, + image_resize_strategy: str, + default_image_size: int = 224, + override_act_layer: Optional[str] = None, + ) -> None: + super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size) + self.timm_path_or_url = timm_path_or_url + self.override_act_layer = override_act_layer + self.dtype = torch.bfloat16 + + # Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary + if self.override_act_layer is None: + self.featurizer: VisionTransformer = timm.create_model( + self.timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size + ) + else: + self.featurizer: VisionTransformer = timm.create_model( + self.timm_path_or_url, + pretrained=True, + num_classes=0, + img_size=self.default_image_size, + act_layer=self.override_act_layer, + ) + self.featurizer.eval() + + # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility + # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! + # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 + self.featurizer.forward = unpack_tuple( + partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2}) + ) + + # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!) + assert isinstance(self.featurizer, VisionTransformer), ( + "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, " + "file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!" 
+ ) + + # Get Config =>> Note :: Override default image size to ensure correct image transform + self.data_cfg = timm.data.resolve_model_data_config(self.featurizer) + self.data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) + + # Initialize Default Image Transform --> Modified by `self.image_resize_strategy` + default_image_transform = timm.data.create_transform(**self.data_cfg, is_training=False) + + # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)! + if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url: + assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!" + assert isinstance(default_image_transform.transforms[0], Resize) + default_image_transform = Compose( + [ + Resize(self.default_image_size, interpolation=default_image_transform.transforms[0].interpolation), + *default_image_transform.transforms[1:], + ] + ) + + # Switch on `image_resize_strategy` + if self.image_resize_strategy == "resize-naive": + assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!" + assert isinstance(default_image_transform.transforms[0], Resize) + + target_size = (self.default_image_size, self.default_image_size) + self.image_transform = Compose( + [ + Resize(target_size, interpolation=default_image_transform.transforms[0].interpolation), + *default_image_transform.transforms[1:], + ] + ) + + elif self.image_resize_strategy == "resize-crop": + self.image_transform = default_image_transform + + elif self.image_resize_strategy == "letterbox": + assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!" + assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!" 
+ + # Compute Padding Fill Value (rescaled normalization mean if applicable) + fill = tuple([int(x * 255) for x in self.data_cfg["mean"]]) + + # Build New Transform + self.image_transform = Compose([LetterboxPad(fill), *default_image_transform.transforms]) + + else: + raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") + + def get_fsdp_wrapping_policy(self) -> Callable: + """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer.""" + vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) + transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) + return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) + + def forward(self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor: + """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features.""" + return self.featurizer(pixel_values) + + @property + def default_image_resolution(self) -> Tuple[int, int, int]: + return self.data_cfg["input_size"] + + @property + def embed_dim(self) -> int: + return self.featurizer.embed_dim + + @property + def num_patches(self) -> int: + return self.featurizer.patch_embed.num_patches + + @property + def half_precision_dtype(self) -> torch.dtype: + return self.dtype diff --git a/prismatic/models/backbones/vision/clip_vit.py b/prismatic/models/backbones/vision/clip_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..c5d162ec9c7d22a475a71917ea4458c2351835c6 --- /dev/null +++ b/prismatic/models/backbones/vision/clip_vit.py @@ -0,0 +1,27 @@ +""" +clip_vit.py +""" + +from prismatic.models.backbones.vision.base_vision import TimmViTBackbone + +# Registry =>> Supported CLIP Vision Backbones (from TIMM) +CLIP_VISION_BACKBONES = { + "clip-vit-b": "vit_base_patch16_clip_224.openai", + "clip-vit-l": "vit_large_patch14_clip_224.openai", + "clip-vit-l-336px": 
"vit_large_patch14_clip_336.openai", +} + + +# [IMPORTANT] By Default, TIMM initialized OpenAI CLIP models with the standard GELU activation from PyTorch. +# HOWEVER =>> Original OpenAI models were trained with the quick_gelu *approximation* -- while it's +# a decent approximation, the resulting features are *worse*; this was a super tricky bug +# to identify, but luckily there's an easy fix (`override_act_layer`) +class CLIPViTBackbone(TimmViTBackbone): + def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: + super().__init__( + vision_backbone_id, + CLIP_VISION_BACKBONES[vision_backbone_id], + image_resize_strategy, + default_image_size=default_image_size, + override_act_layer="quick_gelu" if CLIP_VISION_BACKBONES[vision_backbone_id].endswith(".openai") else None, + ) diff --git a/prismatic/models/backbones/vision/dinoclip_vit.py b/prismatic/models/backbones/vision/dinoclip_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..6653d69242104fecfcad8605223508edb96a7357 --- /dev/null +++ b/prismatic/models/backbones/vision/dinoclip_vit.py @@ -0,0 +1,147 @@ +""" +dinoclip_vit.py + +Vision backbone that returns concatenated features from both DINOv2 and CLIP. 
+""" + +from dataclasses import dataclass +from functools import partial +from typing import Callable, Dict, Tuple + +import timm +import torch +from PIL import Image +from timm.models.vision_transformer import Block, VisionTransformer +from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy +from torchvision.transforms import Compose, Resize + +from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple + +# Registry =>> Supported DinoCLIP Pairs (as TIMM identifiers) +DINOCLIP_VISION_BACKBONES = { + "dinoclip-vit-l-336px": { + "dino": "vit_large_patch14_reg4_dinov2.lvd142m", + "clip": "vit_large_patch14_clip_336.openai", + }, +} + + +@dataclass +class DinoCLIPImageTransform: + dino_image_transform: ImageTransform + clip_image_transform: ImageTransform + is_prismatic: bool = True + + def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]: + return {"dino": self.dino_image_transform(img, **kwargs), "clip": self.clip_image_transform(img, **kwargs)} + + +class DinoCLIPViTBackbone(VisionBackbone): + def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: + super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size) + self.dino_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["dino"] + self.clip_timm_path_or_url = DINOCLIP_VISION_BACKBONES[vision_backbone_id]["clip"] + + # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary + self.dino_featurizer: VisionTransformer = timm.create_model( + self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size + ) + self.dino_featurizer.eval() + + self.clip_featurizer: VisionTransformer = timm.create_model( + self.clip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size + ) + self.clip_featurizer.eval() + + # 
Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility + # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! + # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 + self.dino_featurizer.forward = unpack_tuple( + partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2}) + ) + self.clip_featurizer.forward = unpack_tuple( + partial(self.clip_featurizer.get_intermediate_layers, n={len(self.clip_featurizer.blocks) - 2}) + ) + + # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models + self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer) + self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) + + self.clip_data_cfg = timm.data.resolve_model_data_config(self.clip_featurizer) + self.clip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) + + # Initialize *both* Transforms + default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False) + default_clip_transform = timm.data.create_transform(**self.clip_data_cfg, is_training=False) + if self.image_resize_strategy == "resize-naive": + assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!" + assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_image_transform`!" 
+ assert isinstance(default_dino_transform.transforms[0], Resize) + assert isinstance(default_clip_transform.transforms[0], Resize) + + target_size = (self.default_image_size, self.default_image_size) + dino_transform = Compose( + [ + Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation), + *default_dino_transform.transforms[1:], + ] + ) + clip_transform = Compose( + [ + Resize(target_size, interpolation=default_clip_transform.transforms[0].interpolation), + *default_clip_transform.transforms[1:], + ] + ) + + self.image_transform = DinoCLIPImageTransform(dino_transform, clip_transform) + + elif self.image_resize_strategy == "resize-crop": + self.image_transform = DinoCLIPImageTransform(default_dino_transform, default_clip_transform) + + elif self.image_resize_strategy == "letterbox": + assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!" + assert isinstance(default_clip_transform, Compose), "Unexpected `default_clip_transform`!" + assert "mean" in self.dino_data_cfg and "mean" in self.clip_data_cfg, "DinoCLIP `data_cfg` missing `mean`!" 
+ + # Compute Padding Fill Value(s) (rescaled normalization mean if applicable) + dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]]) + clip_fill = tuple([int(x * 255) for x in self.clip_data_cfg["mean"]]) + + # Build New Transform + self.image_transform = DinoCLIPImageTransform( + Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]), + Compose([LetterboxPad(clip_fill), *default_clip_transform.transforms]), + ) + + else: + raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") + + def get_fsdp_wrapping_policy(self) -> Callable: + """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers.""" + vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) + transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) + return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) + + def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor: + """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches.""" + dino_patches = self.dino_featurizer(pixel_values["dino"]) + clip_patches = self.clip_featurizer(pixel_values["clip"]) + + return torch.cat([dino_patches, clip_patches], dim=2) + + @property + def default_image_resolution(self) -> Tuple[int, int, int]: + return self.dino_data_cfg["input_size"] + + @property + def embed_dim(self) -> int: + return self.dino_featurizer.embed_dim + self.clip_featurizer.embed_dim + + @property + def num_patches(self) -> int: + assert self.dino_featurizer.patch_embed.num_patches == self.clip_featurizer.patch_embed.num_patches + return self.dino_featurizer.patch_embed.num_patches + + @property + def half_precision_dtype(self) -> torch.dtype: + return torch.bfloat16 diff --git a/prismatic/models/backbones/vision/dinosiglip_vit.py b/prismatic/models/backbones/vision/dinosiglip_vit.py new file 
mode 100644 index 0000000000000000000000000000000000000000..bced8d179db3bf3ee488bbec36b0af7b4a1001cb --- /dev/null +++ b/prismatic/models/backbones/vision/dinosiglip_vit.py @@ -0,0 +1,164 @@ +""" +dinosiglip_vit.py + +Vision backbone that returns concatenated features from both DINOv2 and SigLIP. +""" + +from dataclasses import dataclass +from functools import partial +from typing import Callable, Dict, Tuple + +import timm +import torch +from PIL import Image +from timm.models.vision_transformer import Block, VisionTransformer +from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy +from torchvision.transforms import Compose, Resize + +from prismatic.models.backbones.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple + +# Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers) +DINOSigLIP_VISION_BACKBONES = { + "dinosiglip-vit-so-224px": { + "dino": "vit_large_patch14_reg4_dinov2.lvd142m", + "siglip": "vit_so400m_patch14_siglip_224", + }, + "dinosiglip-vit-so-384px": { + "dino": "vit_large_patch14_reg4_dinov2.lvd142m", + "siglip": "vit_so400m_patch14_siglip_384", + }, +} + + +@dataclass +class DinoSigLIPImageTransform: + dino_image_transform: ImageTransform + siglip_image_transform: ImageTransform + is_prismatic: bool = True + + def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]: + return {"dino": self.dino_image_transform(img, **kwargs), "siglip": self.siglip_image_transform(img, **kwargs)} + + +class DinoSigLIPViTBackbone(VisionBackbone): + def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: + super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size) + self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["dino"] + self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[vision_backbone_id]["siglip"] + + # Initialize both 
Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary + self.dino_featurizer: VisionTransformer = timm.create_model( + self.dino_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size + ) + self.dino_featurizer.eval() + + self.siglip_featurizer: VisionTransformer = timm.create_model( + self.siglip_timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size + ) + self.siglip_featurizer.eval() + + # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility + # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches! + # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385 + self.dino_featurizer.forward = unpack_tuple( + partial(self.dino_featurizer.get_intermediate_layers, n={len(self.dino_featurizer.blocks) - 2}) + ) + self.siglip_featurizer.forward = unpack_tuple( + partial(self.siglip_featurizer.get_intermediate_layers, n={len(self.siglip_featurizer.blocks) - 2}) + ) + + # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models + self.dino_data_cfg = timm.data.resolve_model_data_config(self.dino_featurizer) + self.dino_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) + + self.siglip_data_cfg = timm.data.resolve_model_data_config(self.siglip_featurizer) + self.siglip_data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size) + + # Initialize *both* Transforms + default_dino_transform = timm.data.create_transform(**self.dino_data_cfg, is_training=False) + default_siglip_transform = timm.data.create_transform(**self.siglip_data_cfg, is_training=False) + + # Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!! + assert isinstance(default_siglip_transform, Compose), "Unexpected `default_image_transform`!" 
+ assert isinstance(default_siglip_transform.transforms[0], Resize) + default_siglip_transform = Compose( + [ + Resize(self.default_image_size, interpolation=default_siglip_transform.transforms[0].interpolation), + *default_siglip_transform.transforms[1:], + ] + ) + + if self.image_resize_strategy == "resize-naive": + assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_image_transform`!" + assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_image_transform`!" + assert isinstance(default_dino_transform.transforms[0], Resize) + assert isinstance(default_siglip_transform.transforms[0], Resize) + + target_size = (self.default_image_size, self.default_image_size) + dino_transform = Compose( + [ + Resize(target_size, interpolation=default_dino_transform.transforms[0].interpolation), + *default_dino_transform.transforms[1:], + ] + ) + siglip_transform = Compose( + [ + Resize(target_size, interpolation=default_siglip_transform.transforms[0].interpolation), + *default_siglip_transform.transforms[1:], + ] + ) + + self.image_transform = DinoSigLIPImageTransform(dino_transform, siglip_transform) + + elif self.image_resize_strategy == "resize-crop": + self.image_transform = DinoSigLIPImageTransform(default_dino_transform, default_siglip_transform) + + elif self.image_resize_strategy == "letterbox": + assert isinstance(default_dino_transform, Compose), "Unexpected `default_dino_transform`!" + assert isinstance(default_siglip_transform, Compose), "Unexpected `default_siglip_transform`!" + assert ( + "mean" in self.dino_data_cfg and "mean" in self.siglip_data_cfg + ), "DinoSigLIP `data_cfg` missing `mean`!" 
+ + # Compute Padding Fill Value(s) (rescaled normalization mean if applicable) + dino_fill = tuple([int(x * 255) for x in self.dino_data_cfg["mean"]]) + siglip_fill = tuple([int(x * 255) for x in self.siglip_data_cfg["mean"]]) + + # Build New Transform + self.image_transform = DinoSigLIPImageTransform( + Compose([LetterboxPad(dino_fill), *default_dino_transform.transforms]), + Compose([LetterboxPad(siglip_fill), *default_siglip_transform.transforms]), + ) + + else: + raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!") + + def get_fsdp_wrapping_policy(self) -> Callable: + """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers.""" + vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer}) + transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}) + return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy]) + + def forward(self, pixel_values: Dict[str, torch.Tensor]) -> torch.Tensor: + """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches.""" + dino_patches = self.dino_featurizer(pixel_values["dino"]) + siglip_patches = self.siglip_featurizer(pixel_values["siglip"]) + + return torch.cat([dino_patches, siglip_patches], dim=2) + + @property + def default_image_resolution(self) -> Tuple[int, int, int]: + return self.dino_data_cfg["input_size"] + + @property + def embed_dim(self) -> int: + return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim + + @property + def num_patches(self) -> int: + assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches + return self.dino_featurizer.patch_embed.num_patches + + @property + def half_precision_dtype(self) -> torch.dtype: + return torch.bfloat16 diff --git a/prismatic/models/backbones/vision/dinov2_vit.py 
b/prismatic/models/backbones/vision/dinov2_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..e1e858827100833764394f062693b1dd38c6015a --- /dev/null +++ b/prismatic/models/backbones/vision/dinov2_vit.py @@ -0,0 +1,19 @@ +""" +dinov2_vit.py +""" + +from prismatic.models.backbones.vision.base_vision import TimmViTBackbone + +# Registry =>> Supported DINOv2 Vision Backbones (from TIMM) =>> Note:: Using DINOv2 w/ Registers! +# => Reference: https://arxiv.org/abs/2309.16588 +DINOv2_VISION_BACKBONES = {"dinov2-vit-l": "vit_large_patch14_reg4_dinov2.lvd142m"} + + +class DinoV2ViTBackbone(TimmViTBackbone): + def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: + super().__init__( + vision_backbone_id, + DINOv2_VISION_BACKBONES[vision_backbone_id], + image_resize_strategy, + default_image_size=default_image_size, + ) diff --git a/prismatic/models/backbones/vision/in1k_vit.py b/prismatic/models/backbones/vision/in1k_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..9f0f6d3e0082e7e15f963b4123cdabe1deceb55d --- /dev/null +++ b/prismatic/models/backbones/vision/in1k_vit.py @@ -0,0 +1,22 @@ +""" +in1k_vit.py + +Vision Transformers trained / finetuned on ImageNet (ImageNet-21K =>> ImageNet-1K) +""" + +from prismatic.models.backbones.vision.base_vision import TimmViTBackbone + +# Registry =>> Supported Vision Backbones (from TIMM) +IN1K_VISION_BACKBONES = { + "in1k-vit-l": "vit_large_patch16_224.augreg_in21k_ft_in1k", +} + + +class IN1KViTBackbone(TimmViTBackbone): + def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: + super().__init__( + vision_backbone_id, + IN1K_VISION_BACKBONES[vision_backbone_id], + image_resize_strategy, + default_image_size=default_image_size, + ) diff --git a/prismatic/models/backbones/vision/siglip_vit.py b/prismatic/models/backbones/vision/siglip_vit.py new file mode 100644 index 
0000000000000000000000000000000000000000..2ffdbbe8d343524d0914e40742dcfa2f11e7cd4b --- /dev/null +++ b/prismatic/models/backbones/vision/siglip_vit.py @@ -0,0 +1,24 @@ +""" +siglip_vit.py +""" + +from prismatic.models.backbones.vision.base_vision import TimmViTBackbone + +# Registry =>> Supported SigLIP Vision Backbones (from TIMM) =>> Note:: Using SigLIP w/ Patch = 14 (but SO400M Arch) +SIGLIP_VISION_BACKBONES = { + "siglip-vit-b16-224px": "vit_base_patch16_siglip_224", + "siglip-vit-b16-256px": "vit_base_patch16_siglip_256", + "siglip-vit-b16-384px": "vit_base_patch16_siglip_384", + "siglip-vit-so400m": "vit_so400m_patch14_siglip_224", + "siglip-vit-so400m-384px": "vit_so400m_patch14_siglip_384", +} + + +class SigLIPViTBackbone(TimmViTBackbone): + def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None: + super().__init__( + vision_backbone_id, + SIGLIP_VISION_BACKBONES[vision_backbone_id], + image_resize_strategy, + default_image_size=default_image_size, + ) diff --git a/prismatic/models/film_vit_wrapper.py b/prismatic/models/film_vit_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..94618ca967783d94ff347195641335b912cc66b2 --- /dev/null +++ b/prismatic/models/film_vit_wrapper.py @@ -0,0 +1,276 @@ +"""Implementation of additional modules for the VLA's vision transformer.""" + +from functools import partial +from typing import Any, Callable, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from timm.models.vision_transformer import VisionTransformer + + +class FiLMedVisionTransformerBlock(nn.Module): + """ + Wrapper for ViT blocks that adds components to implement FiLM language conditioning. + + Modulates visual feature embeddings via + x = (1 + gamma) * x + beta, + where x is visual feature and gamma and beta are learned projections of the average language embedding. 
+ gamma and beta have D dimensions each, where D is the number of hidden dimensions in the ViT's features. + + NOTE #1 (Moo Jin): + In convolutional neural architectures, the "feature" in FiLM is an entire feature map, i.e., each channel in a + convolutional layer (so gamma and beta have C dimensions, where C is the number of channels). Therefore, FiLM's + scaling and shifting is applied across all spatial locations for conv nets -- i.e., it is spatially agnostic. + + For vision transformer architectures, you may consider individual patch embeddings as individual "features" at first + instinct, but this would make FiLM scaling and shifting spatially local. In order to make the modulation spatially + global like in convolutional architectures, we should apply the scaling and shifting to each dimension of each patch + embedding. I.e., gamma and beta should have D dimensions, where D is the number of dimensions in a visual embedding. + + NOTE #2 (Moo Jin): + x = (1 + gamma) * x + beta is used in the original FiLM paper as opposed to x = gamma * x + beta (see section 7.2 in + https://arxiv.org/pdf/1709.07871.pdf). Since gamma and beta are close to zero upon initialization, this leads to an + identity transformation at the beginning of training, which minimizes perturbation to the pretrained representation. + """ + + def __init__( + self, + block, + vision_dim: int, + llm_dim: int, + ): + """ + Initializes FiLM ViT block wrapper. + + Args: + block (timm.models.vision_transformer.Block): Vision transformer block. + vision_dim (int): Number of hidden dimensions in visual embeddings. + llm_dim (int): Number of hidden dimensions in language embeddings. + """ + super().__init__() + self.block = block + # Initialize gamma and beta projectors + self.scale = nn.Linear(llm_dim, vision_dim) + self.shift = nn.Linear(llm_dim, vision_dim) + + def forward(self, x, average_language_embedding): + """ + Overrides the vision transformer block forward pass to use FiLM. 
+ + Args: + x (torch.Tensor): Visual input embeddings, (batch_size, vision_seq_len, vision_dim). + average_language_embedding (torch.Tensor): Average language embedding for task, (batch_size, llm_dim). + """ + # Project average language embedding to visual embedding space to get gamma and beta + gamma = self.scale(average_language_embedding) # (batch_size, vision_dim) + beta = self.shift(average_language_embedding) # (batch_size, vision_dim) + + # Pass visual inputs through attention portion of original block + x = x + self.block.drop_path1(self.block.ls1(self.block.attn(self.block.norm1(x)))) + + # Modulate intermediate visual representations via FiLM + x = x * (1 + gamma.view(gamma.shape[0], 1, gamma.shape[1])) + beta.view(beta.shape[0], 1, beta.shape[1]) + + # Pass visual inputs through feedforward portion of original block + x = x + self.block.drop_path2(self.block.ls2(self.block.mlp(self.block.norm2(x)))) + + return x + + +class NullVisionTransformerBlockWrapper(nn.Module): + """ + Null wrapper for ViT blocks that doesn't do anything; just calls the original block's forward function. + Useful if you want to use a block wrapper every X blocks instead of every block (e.g., to reduce the number of new + parameters introduced by a new wrapper). + """ + + def __init__( + self, + block, + ): + super().__init__() + self.block = block + + def forward(self, x, average_language_embedding): + return self.block(x) + + +def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]: + """Utility function for monkey-patching functions.""" + + def wrapper(*args: Any, **kwargs: Any) -> Any: + result = fn(*args, **kwargs) + return result[0] if isinstance(result, tuple) else result + + return wrapper + + +class FiLMedVisionTransformer(VisionTransformer): + """ + Wrapper for timm.models.vision_transformer.VisionTransformer that overrides functions to enable infusing language + embeddings into visual embeddings via FiLM. 
+ """ + + def _intermediate_layers( + self, + x: torch.Tensor, + language_embeddings: torch.Tensor, + n: Union[int, Sequence] = 1, + ): + """ + Copy of timm.models.vision_transformer.VisionTransformer._intermediate_layers() with modifications + to take in language embeddings as additional input. + """ + outputs, num_blocks = [], len(self.blocks) + take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n) + + # forward pass + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + for i, blk in enumerate(self.blocks): + x = blk(x, language_embeddings) # Modified to receive language_embeddings + if i in take_indices: + outputs.append(x) + + return outputs + + def get_intermediate_layers( + self, + x: torch.Tensor, + language_embeddings: torch.Tensor, + n: Union[int, Sequence] = 1, + reshape: bool = False, + return_prefix_tokens: bool = False, + norm: bool = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + """ + Copy of timm.models.vision_transformer.VisionTransformer.get_intermediate_layers() with modifications + to allow language embeddings as additional input. + """ + # take last n blocks if n is an int; if n is a sequence, select by matching indices + outputs = self._intermediate_layers(x, language_embeddings, n) + if norm: + outputs = [self.norm(out) for out in outputs] + prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs] + outputs = [out[:, self.num_prefix_tokens :] for out in outputs] + + if reshape: + grid_size = self.patch_embed.grid_size + outputs = [ + out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + + if return_prefix_tokens: + return tuple(zip(outputs, prefix_tokens)) + return tuple(outputs) + + +class FiLMedPrismaticVisionBackbone(nn.Module): + """ + Wrapper for OpenVLA's vision backbone that implements feature-wise linear modulation (FiLM).
+ + Wraps the Vision Transformers in the vision backbone to enable language conditioning through FiLM. + Supports processing 1-3 images using dual vision backbones (SigLIP + DINOv2). + """ + + def __init__( + self, + vision_backbone, + llm_dim: int = 4096, # 4096 for Llama-2 7B + ) -> None: + """ + Initializes FiLM wrapper. + + Args: + vision_backbone (PrismaticVisionBackbone): Base vision backbone. + llm_dim (int): Dimension of language model embeddings. + """ + super().__init__() + self.vision_backbone = vision_backbone + self.llm_dim = llm_dim + + # Wrap vision transformers + self._wrap_vit(self.vision_backbone.featurizer) # SigLIP + if self.vision_backbone.use_fused_vision_backbone: + self._wrap_vit(self.vision_backbone.fused_featurizer) # DINOv2 + + def _wrap_vit(self, vit) -> None: + """ + Creates wrapper around an individual vision transformer to allow for infusion of language inputs. + + Args: + vit (VisionTransformer): Original vision transformer. + """ + # Wrap vision transformer blocks + block_wrappers = [] + for block in vit.blocks: + block_wrappers.append( + FiLMedVisionTransformerBlock(block=block, vision_dim=vit.num_features, llm_dim=self.llm_dim) + ) + vit.blocks = nn.Sequential(*block_wrappers) + + # Wrap vision transformer with new class that overrides functions used for forward pass + vit.__class__ = FiLMedVisionTransformer + vit.forward = unpack_tuple(partial(vit.get_intermediate_layers, n={len(vit.blocks) - 2})) + + def get_num_patches(self) -> int: + """Returns the number of vision patches output by the vision backbone.""" + return self.vision_backbone.get_num_patches() + + def get_num_images_in_input(self) -> int: + """Returns the number of input images for the vision backbone.""" + return self.vision_backbone.get_num_images_in_input() + + def set_num_images_in_input(self, num_images_in_input: int) -> None: + """Sets the number of input images for the vision backbone.""" + self.vision_backbone.set_num_images_in_input(num_images_in_input) + + 
def forward(self, pixel_values: torch.Tensor, language_embeddings: torch.Tensor) -> torch.Tensor: + """ + Implements the forward pass for the vision backbone with FiLM to infuse language inputs into visual features. + + Identical to PrismaticVisionBackbone.forward() except that language embeddings are also used as input. + + Args: + pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W). + language_embeddings (torch.Tensor): Language embeddings for the task description, (B, seq_len, llm_dim). + """ + # For FiLM: Average the language embeddings of the task description + average_language_embedding = language_embeddings.mean(dim=1) + + if self.get_num_images_in_input() == 1: + if not self.vision_backbone.use_fused_vision_backbone: + return self.vision_backbone.featurizer(pixel_values, average_language_embedding) + + # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack + img, img_fused = torch.split(pixel_values, [3, 3], dim=1) + patches = self.vision_backbone.featurizer(img, average_language_embedding) + patches_fused = self.vision_backbone.fused_featurizer(img_fused, average_language_embedding) + + return torch.cat([patches, patches_fused], dim=2) + + else: + assert self.vision_backbone.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!" 
+ + # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2) + images = torch.split(pixel_values, [6] * self.get_num_images_in_input(), dim=1) + + # Process each image and collect patches + all_patches = [] + for img in images: + # Split each image further into two stacks of channels (each with 3 channels) + img_regular, img_fused = torch.split(img, [3, 3], dim=1) + + # Get patches from both SigLIP and DINOv2 vision transformers + patches = self.vision_backbone.featurizer(img_regular, average_language_embedding) + patches_fused = self.vision_backbone.fused_featurizer(img_fused, average_language_embedding) + + # Concatenate SigLIP and DINOv2 patches along the hidden dimension + combined_patches = torch.cat([patches, patches_fused], dim=2) + all_patches.append(combined_patches) + + # Concatenate all patches along the patch dimension + return torch.cat(all_patches, dim=1) diff --git a/prismatic/models/load.py b/prismatic/models/load.py new file mode 100644 index 0000000000000000000000000000000000000000..dba78abea7f4847d64a014ca90a527671351724c --- /dev/null +++ b/prismatic/models/load.py @@ -0,0 +1,226 @@ +""" +load.py + +Entry point for loading pretrained VLMs for inference; exposes functions for listing available models (with canonical +IDs, mappings to paper experiments, and short descriptions), as well as for loading models (from disk or HF Hub). 
+""" + +import json +import os +from pathlib import Path +from typing import List, Optional, Union + +from huggingface_hub import HfFileSystem, hf_hub_download + +from prismatic.conf import ModelConfig +from prismatic.models.materialize import get_llm_backbone_and_tokenizer, get_vision_backbone_and_transform +from prismatic.models.registry import GLOBAL_REGISTRY, MODEL_REGISTRY +from prismatic.models.vlas import OpenVLA +from prismatic.models.vlms import PrismaticVLM +from prismatic.overwatch import initialize_overwatch +from prismatic.vla.action_tokenizer import ActionTokenizer + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +# === HF Hub Repository === +HF_HUB_REPO = "TRI-ML/prismatic-vlms" +VLA_HF_HUB_REPO = "openvla/openvla-dev" + + +# === Available Models === +def available_models() -> List[str]: + return list(MODEL_REGISTRY.keys()) + + +def available_model_names() -> List[str]: + return list(GLOBAL_REGISTRY.keys()) + + +def get_model_description(model_id_or_name: str) -> str: + if model_id_or_name not in GLOBAL_REGISTRY: + raise ValueError(f"Couldn't find `{model_id_or_name = }`; check `prismatic.available_model_names()`") + + # Print Description & Return + print(json.dumps(description := GLOBAL_REGISTRY[model_id_or_name]["description"], indent=2)) + + return description + + +# === Load Pretrained Model === +def load( + model_id_or_path: Union[str, Path], + hf_token: Optional[str] = None, + cache_dir: Optional[Union[str, Path]] = None, + load_for_training: bool = False, +) -> PrismaticVLM: + """Loads a pretrained PrismaticVLM from either local disk or the HuggingFace Hub.""" + if os.path.isdir(model_id_or_path): + overwatch.info(f"Loading from local path `{(run_dir := Path(model_id_or_path))}`") + + # Get paths for `config.json` and pretrained checkpoint + config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt" + assert config_json.exists(), f"Missing 
`config.json` for `{run_dir = }`" + assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`" + else: + if model_id_or_path not in GLOBAL_REGISTRY: + raise ValueError(f"Couldn't find `{model_id_or_path = }`; check `prismatic.available_model_names()`") + + overwatch.info(f"Downloading `{(model_id := GLOBAL_REGISTRY[model_id_or_path]['model_id'])}` from HF Hub") + with overwatch.local_zero_first(): + config_json = hf_hub_download(repo_id=HF_HUB_REPO, filename=f"{model_id}/config.json", cache_dir=cache_dir) + checkpoint_pt = hf_hub_download( + repo_id=HF_HUB_REPO, filename=f"{model_id}/checkpoints/latest-checkpoint.pt", cache_dir=cache_dir + ) + + # Load Model Config from `config.json` + with open(config_json, "r") as f: + model_cfg = json.load(f)["model"] + + # = Load Individual Components necessary for Instantiating a VLM = + # =>> Print Minimal Config + overwatch.info( + f"Found Config =>> Loading & Freezing [bold blue]{model_cfg['model_id']}[/] with:\n" + f" Vision Backbone =>> [bold]{model_cfg['vision_backbone_id']}[/]\n" + f" LLM Backbone =>> [bold]{model_cfg['llm_backbone_id']}[/]\n" + f" Arch Specifier =>> [bold]{model_cfg['arch_specifier']}[/]\n" + f" Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]" + ) + + # Load Vision Backbone + overwatch.info(f"Loading Vision Backbone [bold]{model_cfg['vision_backbone_id']}[/]") + vision_backbone, image_transform = get_vision_backbone_and_transform( + model_cfg["vision_backbone_id"], + model_cfg["image_resize_strategy"], + ) + + # Load LLM Backbone --> note `inference_mode = True` by default when calling `load()` + overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg['llm_backbone_id']}[/] via HF Transformers") + llm_backbone, tokenizer = get_llm_backbone_and_tokenizer( + model_cfg["llm_backbone_id"], + llm_max_length=model_cfg.get("llm_max_length", 2048), + hf_token=hf_token, + inference_mode=not load_for_training, + ) + + # Load VLM using `from_pretrained` (clobbers HF syntax... 
eventually should reconcile) + overwatch.info(f"Loading VLM [bold blue]{model_cfg['model_id']}[/] from Checkpoint") + vlm = PrismaticVLM.from_pretrained( + checkpoint_pt, + model_cfg["model_id"], + vision_backbone, + llm_backbone, + arch_specifier=model_cfg["arch_specifier"], + freeze_weights=not load_for_training, + ) + + return vlm + + +# === Load Pretrained VLA Model === +def load_vla( + model_id_or_path: Union[str, Path], + hf_token: Optional[str] = None, + cache_dir: Optional[Union[str, Path]] = None, + load_for_training: bool = False, + step_to_load: Optional[int] = None, + model_type: str = "pretrained", +) -> OpenVLA: + """Loads a pretrained OpenVLA from either local disk or the HuggingFace Hub.""" + + # TODO (siddk, moojink) :: Unify semantics with `load()` above; right now, `load_vla()` assumes path points to + # checkpoint `.pt` file, rather than the top-level run directory! + if os.path.isfile(model_id_or_path): + overwatch.info(f"Loading from local checkpoint path `{(checkpoint_pt := Path(model_id_or_path))}`") + + # [Validate] Checkpoint Path should look like `...//checkpoints/.pt` + assert (checkpoint_pt.suffix == ".pt") and (checkpoint_pt.parent.name == "checkpoints"), "Invalid checkpoint!" 
+ run_dir = checkpoint_pt.parents[1] + + # Get paths for `config.json`, `dataset_statistics.json` and pretrained checkpoint + config_json, dataset_statistics_json = run_dir / "config.json", run_dir / "dataset_statistics.json" + assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`" + assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`" + + # Otherwise =>> try looking for a match on `model_id_or_path` on the HF Hub (`VLA_HF_HUB_REPO`) + else: + # Search HF Hub Repo via fsspec API + overwatch.info(f"Checking HF for `{(hf_path := str(Path(VLA_HF_HUB_REPO) / model_type / model_id_or_path))}`") + if not (tmpfs := HfFileSystem()).exists(hf_path): + raise ValueError(f"Couldn't find valid HF Hub Path `{hf_path = }`") + + # Identify Checkpoint to Load (via `step_to_load`) + step_to_load = f"{step_to_load:06d}" if step_to_load is not None else None + valid_ckpts = tmpfs.glob(f"{hf_path}/checkpoints/step-{step_to_load if step_to_load is not None else ''}*.pt") + if (len(valid_ckpts) == 0) or (step_to_load is not None and len(valid_ckpts) != 1): + raise ValueError(f"Couldn't find a valid checkpoint to load from HF Hub Path `{hf_path}/checkpoints/`") + + # Call to `glob` will sort steps in ascending order (if `step_to_load` is None); just grab last element + target_ckpt = Path(valid_ckpts[-1]).name + + overwatch.info(f"Downloading Model `{model_id_or_path}` Config & Checkpoint `{target_ckpt}`") + with overwatch.local_zero_first(): + relpath = Path(model_type) / model_id_or_path + config_json = hf_hub_download( + repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'config.json')!s}", cache_dir=cache_dir + ) + dataset_statistics_json = hf_hub_download( + repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'dataset_statistics.json')!s}", cache_dir=cache_dir + ) + checkpoint_pt = hf_hub_download( + repo_id=VLA_HF_HUB_REPO, filename=f"{(relpath / 'checkpoints' / target_ckpt)!s}", cache_dir=cache_dir + ) + + # Load VLA Config (and 
corresponding base VLM `ModelConfig`) from `config.json` + with open(config_json, "r") as f: + vla_cfg = json.load(f)["vla"] + model_cfg = ModelConfig.get_choice_class(vla_cfg["base_vlm"])() + + # Load Dataset Statistics for Action Denormalization + with open(dataset_statistics_json, "r") as f: + norm_stats = json.load(f) + + # = Load Individual Components necessary for Instantiating a VLA (via base VLM components) = + # =>> Print Minimal Config + overwatch.info( + f"Found Config =>> Loading & Freezing [bold blue]{model_cfg.model_id}[/] with:\n" + f" Vision Backbone =>> [bold]{model_cfg.vision_backbone_id}[/]\n" + f" LLM Backbone =>> [bold]{model_cfg.llm_backbone_id}[/]\n" + f" Arch Specifier =>> [bold]{model_cfg.arch_specifier}[/]\n" + f" Checkpoint Path =>> [underline]`{checkpoint_pt}`[/]" + ) + + # Load Vision Backbone + overwatch.info(f"Loading Vision Backbone [bold]{model_cfg.vision_backbone_id}[/]") + vision_backbone, image_transform = get_vision_backbone_and_transform( + model_cfg.vision_backbone_id, + model_cfg.image_resize_strategy, + ) + + # Load LLM Backbone --> note `inference_mode = True` by default when calling `load()` + overwatch.info(f"Loading Pretrained LLM [bold]{model_cfg.llm_backbone_id}[/] via HF Transformers") + llm_backbone, tokenizer = get_llm_backbone_and_tokenizer( + model_cfg.llm_backbone_id, + llm_max_length=model_cfg.llm_max_length, + hf_token=hf_token, + inference_mode=not load_for_training, + ) + + # Create Action Tokenizer + action_tokenizer = ActionTokenizer(llm_backbone.get_tokenizer()) + + # Load VLM using `from_pretrained` (clobbers HF syntax... 
eventually should reconcile) + overwatch.info(f"Loading VLA [bold blue]{model_cfg.model_id}[/] from Checkpoint") + vla = OpenVLA.from_pretrained( + checkpoint_pt, + model_cfg.model_id, + vision_backbone, + llm_backbone, + arch_specifier=model_cfg.arch_specifier, + freeze_weights=not load_for_training, + norm_stats=norm_stats, + action_tokenizer=action_tokenizer, + ) + + return vla diff --git a/prismatic/models/materialize.py b/prismatic/models/materialize.py new file mode 100644 index 0000000000000000000000000000000000000000..df1ad612eceeaf1af2c8a58dc8242bceac778af5 --- /dev/null +++ b/prismatic/models/materialize.py @@ -0,0 +1,130 @@ +""" +materialize.py + +Factory class for initializing Vision Backbones, LLM Backbones, and VLMs from a set registry; provides and exports +individual functions for clear control flow. +""" + +from typing import Optional, Tuple + +from transformers import PreTrainedTokenizerBase + +from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone, PhiLLMBackbone +from prismatic.models.backbones.vision import ( + CLIPViTBackbone, + DinoCLIPViTBackbone, + DinoSigLIPViTBackbone, + DinoV2ViTBackbone, + ImageTransform, + IN1KViTBackbone, + SigLIPViTBackbone, + VisionBackbone, +) +from prismatic.models.vlms import PrismaticVLM + +# === Registries =>> Maps ID --> {cls(), kwargs} :: Different Registries for Vision Backbones, LLM Backbones, VLMs === +# fmt: off + +# === Vision Backbone Registry === +VISION_BACKBONES = { + # === 224px Backbones === + "clip-vit-l": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}}, + "siglip-vit-so400m": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, + "dinov2-vit-l": {"cls": DinoV2ViTBackbone, "kwargs": {"default_image_size": 224}}, + "in1k-vit-l": {"cls": IN1KViTBackbone, "kwargs": {"default_image_size": 224}}, + "dinosiglip-vit-so-224px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, + + # === Assorted CLIP Backbones === + 
"clip-vit-b": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}}, + "clip-vit-l-336px": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 336}}, + + # === Assorted SigLIP Backbones === + "siglip-vit-b16-224px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}}, + "siglip-vit-b16-256px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 256}}, + "siglip-vit-b16-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, + "siglip-vit-so400m-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, + + # === Fused Backbones === + "dinoclip-vit-l-336px": {"cls": DinoCLIPViTBackbone, "kwargs": {"default_image_size": 336}}, + "dinosiglip-vit-so-384px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 384}}, +} + + +# === Language Model Registry === +LLM_BACKBONES = { + # === LLaMa-2 Pure (Non-Chat) Backbones === + "llama2-7b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, + "llama2-13b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, + + # === LLaMa-2 Chat Backbones === + "llama2-7b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, + "llama2-13b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, + + # === Vicuna-v1.5 Backbones === + "vicuna-v15-7b": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, + "vicuna-v15-13b": {"cls": LLaMa2LLMBackbone, "kwargs": {}}, + + # === Mistral v0.1 Backbones === + "mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, "kwargs": {}}, + "mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}}, + + # === Phi-2 Backbone === + "phi-2-3b": {"cls": PhiLLMBackbone, "kwargs": {}}, +} + +# fmt: on + + +def get_vision_backbone_and_transform( + vision_backbone_id: str, image_resize_strategy: str +) -> Tuple[VisionBackbone, ImageTransform]: + """Instantiate a Vision Backbone, returning both the nn.Module wrapper class and default Image Transform.""" + if vision_backbone_id in VISION_BACKBONES: + vision_cfg = VISION_BACKBONES[vision_backbone_id] + 
vision_backbone: VisionBackbone = vision_cfg["cls"]( + vision_backbone_id, image_resize_strategy, **vision_cfg["kwargs"] + ) + image_transform = vision_backbone.get_image_transform() + return vision_backbone, image_transform + + else: + raise ValueError(f"Vision Backbone `{vision_backbone_id}` is not supported!") + + +def get_llm_backbone_and_tokenizer( + llm_backbone_id: str, + llm_max_length: int = 2048, + hf_token: Optional[str] = None, + inference_mode: bool = False, +) -> Tuple[LLMBackbone, PreTrainedTokenizerBase]: + if llm_backbone_id in LLM_BACKBONES: + llm_cfg = LLM_BACKBONES[llm_backbone_id] + llm_backbone: LLMBackbone = llm_cfg["cls"]( + llm_backbone_id, + llm_max_length=llm_max_length, + hf_token=hf_token, + inference_mode=inference_mode, + **llm_cfg["kwargs"], + ) + tokenizer = llm_backbone.get_tokenizer() + return llm_backbone, tokenizer + + else: + raise ValueError(f"LLM Backbone `{llm_backbone_id}` is not supported!") + + +def get_vlm( + model_id: str, + arch_specifier: str, + vision_backbone: VisionBackbone, + llm_backbone: LLMBackbone, + enable_mixed_precision_training: bool = True, +) -> PrismaticVLM: + """Lightweight wrapper around initializing a VLM, mostly for future-proofing (if one wants to add a new VLM).""" + return PrismaticVLM( + model_id, + vision_backbone, + llm_backbone, + enable_mixed_precision_training=enable_mixed_precision_training, + arch_specifier=arch_specifier, + ) diff --git a/prismatic/models/projectors.py b/prismatic/models/projectors.py new file mode 100644 index 0000000000000000000000000000000000000000..ea20dade18ea8863f683a4a00cd4365c92aac2d6 --- /dev/null +++ b/prismatic/models/projectors.py @@ -0,0 +1,49 @@ +"""Implementation of additional projectors for additional inputs to the VLA models.""" +import torch +import torch.nn as nn + + +class ProprioProjector(nn.Module): + """ + Projects proprio state inputs into the LLM's embedding space. 
+ """ + def __init__(self, llm_dim: int, proprio_dim: int) -> None: + super().__init__() + self.llm_dim = llm_dim + self.proprio_dim = proprio_dim + + self.fc1 = nn.Linear(self.proprio_dim, self.llm_dim, bias=True) + self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True) + self.act_fn1 = nn.GELU() + + def forward(self, proprio: torch.Tensor = None) -> torch.Tensor: + # proprio: (bsz, proprio_dim) + projected_features = self.fc1(proprio) + projected_features = self.act_fn1(projected_features) + projected_features = self.fc2(projected_features) + return projected_features + + +class NoisyActionProjector(nn.Module): + """ + [Diffusion] Projects noisy action inputs into the LLM's embedding space. + + Note that since each action is tokenized into 7 tokens in OpenVLA (rather + than having 1 token per action), each noisy action token will have dimension 1 + instead of 7. + """ + def __init__(self, llm_dim: int) -> None: + super().__init__() + self.llm_dim = llm_dim + self.action_token_dim = 1 + + self.fc1 = nn.Linear(self.action_token_dim, self.llm_dim, bias=True) + self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True) + self.act_fn1 = nn.GELU() + + def forward(self, noisy_actions: torch.Tensor = None) -> torch.Tensor: + # noisy_actions: (bsz, num_action_tokens=chunk_len*action_dim, 1) + projected_features = self.fc1(noisy_actions) + projected_features = self.act_fn1(projected_features) + projected_features = self.fc2(projected_features) + return projected_features diff --git a/prismatic/models/registry.py b/prismatic/models/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..eabe6acc5fd2d567ddfa6d4e5a3894c5e8ad19a5 --- /dev/null +++ b/prismatic/models/registry.py @@ -0,0 +1,691 @@ +""" +registry.py + +Exhaustive list of pretrained VLMs (with full descriptions / links to corresponding names and sections of paper). 
+""" + +# === Pretrained Model Registry === +# fmt: off +MODEL_REGISTRY = { + # === LLaVa v1.5 Reproductions === + "reproduction-llava-v15+7b": { + "model_id": "reproduction-llava-v15+7b", + "names": ["LLaVa v1.5 7B (Reproduction)"], + "description": { + "name": "LLaVa v1.5 7B (Reproduction)", + "optimization_procedure": "multi-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "reproduction-llava-v15+13b": { + "model_id": "reproduction-llava-v15+13b", + "names": ["LLaVa v1.5 13B (Reproduction)"], + "description": { + "name": "LLaVa v1.5 13B (Reproduction)", + "optimization_procedure": "multi-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 13B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + + # === Section 4.1 :: Optimization Procedure === + "one-stage+7b": { + "model_id": "one-stage+7b", + "names": [ + "One-Stage 7B", + "Single-Stage 7B", + "Frozen ViT (Single-Stage)", + "CLIP ViT-L 336px (Letterbox)", + "CLIP ViT-L 336px", + "Vicuña v1.5 7B", + "1 Epoch", + "Base", + ], + "description": { + "name": "Single-Stage 7B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "one-stage+13b": { + "model_id": "one-stage+13b", + "names": [ + "One-Stage 13B", + "Single-Stage 13B", + "Vicuña v1.5 13B", + ], + "description": { + "name": "Single-Stage 13B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 13B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + + "full-ft-multi-stage+7b": { + "model_id": 
"full-ft-multi-stage+7b", + "names": ["Finetune ViT (Multi-Stage)"], + "description": { + "name": "Finetune ViT (Multi-Stage)", + "optimization_procedure": "multi-stage-full-finetune", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "full-ft-one-stage+7b": { + "model_id": "full-ft-one-stage+7b", + "names": ["Finetune ViT (Single-Stage)"], + "description": { + "name": "Finetune ViT (Single-Stage)", + "optimization_procedure": "single-stage-full-finetune", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + + # === Section 4.2 :: Image Processing and Visual Representations === + "in1k-224px+7b": { + "model_id": "in1k-224px+7b", + "names": ["IN1K ViT-L 224px"], + "description": { + "name": "IN1K ViT-L 224px", + "optimization_procedure": "single-stage", + "visual_representation": "ImageNet-21K+1K ViT-L/16 @ 224px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + }, + }, + "dinov2-224px+7b": { + "model_id": "dinov2-224px+7b", + "names": ["DINOv2 ViT-L 224px"], + "description": { + "name": "DINOv2 ViT-L 224px", + "optimization_procedure": "single-stage", + "visual_representation": "DINOv2 ViT-L/14 @ 224px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + }, + }, + "clip-224px+7b": { + "model_id": "clip-224px+7b", + "names": ["CLIP ViT-L 224px"], + "description": { + "name": "CLIP ViT-L 224px", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 224px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + 
"train_epochs": 1, + }, + }, + "siglip-224px+7b": { + "model_id": "siglip-224px+7b", + "names": ["SigLIP ViT-SO 224px"], + "description": { + "name": "SigLIP ViT-SO 224px", + "optimization_procedure": "single-stage", + "visual_representation": "SigLIP ViT-SO/14 @ 224px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + }, + }, + + "clip-336px-resize-crop+7b": { + "model_id": "clip-336px-resize-crop+7b", + "names": ["CLIP ViT-L 336px (Resize Crop)"], + "description": { + "name": "CLIP ViT-L 336px (Resize Crop)", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Resize Crop", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "clip-336px-resize-naive+7b": { + "model_id": "clip-336px-resize-naive+7b", + "names": ["CLIP ViT-L 336px (Naive Resize)", "CLIP 336px (Naive Resize)"], + "description": { + "name": "CLIP ViT-L 336px (Naive Resize)", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Naive Resize", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "siglip-384px-letterbox+7b": { + "model_id": "siglip-384px-letterbox+7b", + "names": ["SigLIP ViT-SO 384px (Letterbox)", "SigLIP ViT-SO 384px"], + "description": { + "name": "SigLIP ViT-SO 384px (Letterbox)", + "optimization_procedure": "single-stage", + "visual_representation": "SigLIP ViT-SO/14 @ 384px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "siglip-384px-resize-crop+7b": { + "model_id": "siglip-384px-resize-crop+7b", + "names": ["SigLIP ViT-SO 384px (Resize Crop)"], + "description": { + "name": "SigLIP ViT-SO 384px (Resize Crop)", + "optimization_procedure": "single-stage", + 
"visual_representation": "SigLIP ViT-SO/14 @ 384px", + "image_processing": "Resize Crop", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "siglip-384px-resize-naive+7b": { + "model_id": "siglip-384px-resize-naive+7b", + "names": ["SigLIP ViT-SO 384px (Naive Resize)", "SigLIP 384px (Naive Resize)"], + "description": { + "name": "SigLIP ViT-SO 384px (Naive Resize)", + "optimization_procedure": "single-stage", + "visual_representation": "SigLIP ViT-SO/14 @ 384px", + "image_processing": "Naive Resize", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + + "dinoclip-336px-letterbox+7b": { + "model_id": "dinoclip-336px-letterbox+7b", + "names": ["DINOv2 + CLIP 336px (Letterbox)"], + "description": { + "name": "DINOv2 + CLIP 336px (Letterbox)", + "optimization_procedure": "single-stage", + "visual_representation": "DINOv2 ViT-L/14 + CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "dinoclip-336px-resize-naive+7b": { + "model_id": "dinoclip-336px-resize-naive+7b", + "names": ["DINOv2 + CLIP 336px (Naive Resize)"], + "description": { + "name": "DINOv2 + CLIP 336px (Naive Resize)", + "optimization_procedure": "single-stage", + "visual_representation": "DINOv2 ViT-L/14 + CLIP ViT-L/14 @ 336px", + "image_processing": "Naive Resize", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "dinosiglip-384px-letterbox+7b": { + "model_id": "dinosiglip-384px-letterbox+7b", + "names": ["DINOv2 + SigLIP 384px (Letterbox)"], + "description": { + "name": "DINOv2 + SigLIP 384px (Letterbox)", + "optimization_procedure": "single-stage", + "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-L/14 @ 384px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa 
v1.5 Instruct"], + "train_epochs": 1, + } + }, + "dinosiglip-384px-resize-naive+7b": { + "model_id": "dinosiglip-384px-resize-naive+7b", + "names": ["DINOv2 + SigLIP 384px (Naive Resize)"], + "description": { + "name": "DINOv2 + SigLIP 384px (Naive Resize)", + "optimization_procedure": "single-stage", + "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-L/14 @ 384px", + "image_processing": "Naive Resize", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + + # === Section 4.3 :: Language Models === + "llama2+7b": { + "model_id": "llama2+7b", + "names": ["Llama-2 7B"], + "description": { + "name": "Llama-2 7B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + }, + }, + "llama2+13b": { + "model_id": "llama2+13b", + "names": ["Llama-2 13B"], + "description": { + "name": "Llama-2 13B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Llama-2 13B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + }, + }, + + "vicuna-no-cotraining+7b": { + "model_id": "vicuna-no-cotraining+7b", + "names": ["Vicuña v1.5 7B (No Co-training)"], + "description": { + "name": "Vicuña v1.5 7B (No Co-training)", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Multimodal-Only"], + "train_epochs": 1, + }, + }, + "llama2-no-cotraining+7b": { + "model_id": "llama2-no-cotraining+7b", + "names": ["Llama-2 7B (No Co-training)"], + "description": { + "name": "Llama-2 7B (No Co-training)", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": 
"Letterbox", + "language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Multimodal-Only"], + "train_epochs": 1, + }, + }, + + # === Section 4.4 :: Scaling Properties === + "train-1.25-epochs+7b": { + "model_id": "train-1.25-epochs+7b", + "names": ["1.25 Epochs"], + "description": { + "name": "1.25 Epochs", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1.25, + } + }, + "train-1.5-epochs+7b": { + "model_id": "train-1.5-epochs+7b", + "names": ["1.5 Epochs"], + "description": { + "name": "1.5 Epochs", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1.5, + } + }, + "train-2-epochs+7b": { + "model_id": "train-2-epochs+7b", + "names": ["2 Epochs"], + "description": { + "name": "2 Epochs", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 2, + } + }, + "train-3-epochs+7b": { + "model_id": "train-3-epochs+7b", + "names": ["3 Epochs"], + "description": { + "name": "3 Epochs", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 3, + } + }, + + "llava-lvis4v+7b": { + "model_id": "llava-lvis4v+7b", + "names": ["Base + LVIS-4V"], + "description": { + "name": "Base + LVIS-4V", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct", 
"LVIS-Instruct-4V"], + "train_epochs": 1, + } + }, + "llava-lrv+7b": { + "model_id": "llava-lrv+7b", + "names": ["Base + LRV"], + "description": { + "name": "Base + LRV", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct", "LRV-Instruct"], + "train_epochs": 1, + } + }, + "llava-lvis4v-lrv+7b": { + "model_id": "llava-lvis4v-lrv+7b", + "names": ["Base + LVIS-4V + LRV"], + "description": { + "name": "Base + LVIS-4V + LRV", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Vicuña v1.5 7B", + "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], + "train_epochs": 1, + } + }, + + # === + + # === CLIP Prism Models === + "prism-clip-controlled+7b": { + "model_id": "prism-clip-controlled+7b", + "names": ["Prism-CLIP 7B (Controlled)"], + "description": { + "name": "CLIP Prism 7B (Controlled)", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "prism-clip-controlled+13b": { + "model_id": "prism-clip-controlled+13b", + "names": ["Prism-CLIP 13B (Controlled)"], + "description": { + "name": "CLIP Prism 13B (Controlled)", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 13B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "prism-clip+7b": { + "model_id": "prism-clip+7b", + "names": ["Prism-CLIP 7B"], + "description": { + "name": "CLIP Prism 7B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Naive Resize", + 
"language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], + "train_epochs": 2, + }, + }, + "prism-clip+13b": { + "model_id": "prism-clip+13b", + "names": ["Prism-CLIP 13B"], + "description": { + "name": "CLIP Prism 13B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 13B", + "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], + "train_epochs": 2, + }, + }, + + # === SigLIP Prism Models == + "prism-siglip-controlled+7b": { + "model_id": "prism-siglip-controlled+7b", + "names": ["Prism-SigLIP 7B (Controlled)"], + "description": { + "name": "SigLIP Prism 7B (Controlled)", + "optimization_procedure": "single-stage", + "visual_representation": "SigLIP ViT-SO/14 @ 384px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "prism-siglip-controlled+13b": { + "model_id": "prism-siglip-controlled+7b", + "names": ["Prism-SigLIP 13B (Controlled)"], + "description": { + "name": "SigLIP Prism 13B (Controlled)", + "optimization_procedure": "single-stage", + "visual_representation": "SigLIP ViT-SO/14 @ 384px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 13B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "prism-siglip+7b": { + "model_id": "prism-siglip+7b", + "names": ["Prism-SigLIP 7B"], + "description": { + "name": "SigLIP Prism 7B", + "optimization_procedure": "single-stage", + "visual_representation": "SigLIP ViT-SO/14 @ 384px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], + "train_epochs": 2, + } + }, + "prism-siglip+13b": { + "model_id": "prism-siglip+13b", + "names": ["Prism-SigLIP 13B"], + "description": { + "name": "SigLIP Prism 13B", + 
"optimization_procedure": "single-stage", + "visual_representation": "SigLIP ViT-SO/14 @ 384px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 13B", + "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], + "train_epochs": 2, + } + }, + + # === DINOSigLIP Prism Models === + "prism-dinosiglip-controlled+7b": { + "model_id": "prism-dinosiglip-controlled+7b", + "names": ["Prism-DINOSigLIP 7B (Controlled)", "Prism 7B (Controlled)"], + "description": { + "name": "DINOSigLIP Prism 7B (Controlled)", + "optimization_procedure": "single-stage", + "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "prism-dinosiglip-controlled+13b": { + "model_id": "prism-dinosiglip-controlled+13b", + "names": ["Prism-DINOSigLIP 13B (Controlled)", "Prism 13B (Controlled)"], + "description": { + "name": "DINOSigLIP Prism 13B (Controlled)", + "optimization_procedure": "single-stage", + "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 13B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "prism-dinosiglip+7b": { + "model_id": "prism-dinosiglip+7b", + "names": ["Prism-DINOSigLIP 7B"], + "description": { + "name": "DINOSigLIP Prism 7B", + "optimization_procedure": "single-stage", + "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 384px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], + "train_epochs": 2, + }, + }, + "prism-dinosiglip+13b": { + "model_id": "prism-dinosiglip+13b", + "names": ["Prism-DINOSigLIP 13B"], + "description": { + "name": "DINOSigLIP Prism 13B", + "optimization_procedure": "single-stage", + "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO/14 @ 
384px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 13B", + "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], + "train_epochs": 2, + }, + }, + + # === DINOSigLIP 224px Prism Models === + "prism-dinosiglip-224px-controlled+7b": { + "model_id": "prism-dinosiglip-224px-controlled+7b", + "names": ["Prism-DINOSigLIP 224px 7B (Controlled)"], + "description": { + "name": "DINOSigLIP 224px 7B (Controlled)", + "optimization_procedure": "single-stage", + "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO 14 @ 224px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "prism-dinosiglip-224px+7b": { + "model_id": "prism-dinosiglip-224px+7b", + "names": ["Prism-DINOSigLIP 224px 7B"], + "description": { + "name": "DINOSigLIP 224px 7B", + "optimization_procedure": "single-stage", + "visual_representation": "DINOv2 ViT-L/14 + SigLIP ViT-SO 14 @ 224px", + "image_processing": "Naive Resize", + "language_model": "Llama-2 7B", + "datasets": ["LLaVa v1.5 Instruct", "LVIS-Instruct-4V", "LRV-Instruct"], + "train_epochs": 2, + } + }, + + # === Additional LLM Backbones === + "llama2-chat+7b": { + "model_id": "llama2-chat+7b", + "names": ["Llama-2 Chat 7B"], + "description": { + "name": "Llama-2 Chat 7B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Llama-2 Chat 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "llama2-chat+13b": { + "model_id": "llama2-chat+13b", + "names": ["Llama-2 Chat 13B"], + "description": { + "name": "Llama-2 Chat 13B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Llama-2 Chat 13B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "mistral-v0.1+7b": { + "model_id": 
"mistral-v0.1+7b", + "names": ["Mistral v0.1 7B"], + "description": { + "name": "Mistral v0.1 7B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Mistral v0.1 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "mistral-instruct-v0.1+7b": { + "model_id": "mistral-instruct-v0.1+7b", + "names": ["Mistral Instruct v0.1 7B"], + "description": { + "name": "Mistral Instruct v0.1 7B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Mistral Instruct v0.1 7B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, + "phi-2+3b": { + "model_id": "phi-2+3b", + "names": ["Phi-2 3B"], + "description": { + "name": "Phi-2 3B", + "optimization_procedure": "single-stage", + "visual_representation": "CLIP ViT-L/14 @ 336px", + "image_processing": "Letterbox", + "language_model": "Phi-2 3B", + "datasets": ["LLaVa v1.5 Instruct"], + "train_epochs": 1, + } + }, +} + +# Build Global Registry (Model ID, Name) -> Metadata +GLOBAL_REGISTRY = {name: v for k, v in MODEL_REGISTRY.items() for name in [k] + v["names"]} + +# fmt: on diff --git a/prismatic/models/vlas/__init__.py b/prismatic/models/vlas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fd8e82c9b39c6b9f6d785121c87743187853c47a --- /dev/null +++ b/prismatic/models/vlas/__init__.py @@ -0,0 +1 @@ +from .openvla import OpenVLA diff --git a/prismatic/models/vlas/openvla.py b/prismatic/models/vlas/openvla.py new file mode 100644 index 0000000000000000000000000000000000000000..58ee0f9f3af2a5b362c98c45075d3b2adc07704b --- /dev/null +++ b/prismatic/models/vlas/openvla.py @@ -0,0 +1,131 @@ +""" +openvla.py + +PyTorch Module defining OpenVLA as a lightweight wrapper around a PrismaticVLM; defines custom logic around +discretizing actions with the ActionTokenizer. 
+""" + +from typing import Dict, List, Optional + +import numpy as np +import torch +from PIL import Image +from transformers import LlamaTokenizerFast + +from prismatic.models.vlms.prismatic import PrismaticVLM +from prismatic.overwatch import initialize_overwatch +from prismatic.vla.action_tokenizer import ActionTokenizer + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +class OpenVLA(PrismaticVLM): + def __init__( + self, + *args, + norm_stats: Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]], + action_tokenizer: ActionTokenizer, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.norm_stats = norm_stats + self.action_tokenizer = action_tokenizer + + @torch.inference_mode() + def predict_action( + self, image: Image, instruction: str, unnorm_key: Optional[str] = None, **kwargs: str + ) -> np.ndarray: + """ + Core function for VLA inference; maps input image and task instruction to continuous action (de-tokenizes). + + @param image: PIL Image as [height, width, 3] + @param instruction: Task instruction string + @param unnorm_key: Optional dataset name for retrieving un-normalizing statistics; if None, checks that model + was trained only on a single dataset, and retrieves those statistics. + + @return Unnormalized (continuous) action vector --> end-effector deltas. 
+ """ + image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer + + # Build VLA Prompt + prompt_builder = self.get_prompt_builder() + prompt_builder.add_turn(role="human", message=f"What action should the robot take to {instruction.lower()}?") + prompt_text = prompt_builder.get_prompt() + + # Prepare Inputs + input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device) + if isinstance(tokenizer, LlamaTokenizerFast): + # If the special empty token ('') does not already appear after the colon (':') token in the prompt + # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time + if not torch.all(input_ids[:, -1] == 29871): + input_ids = torch.cat( + (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1 + ) + else: + raise ValueError(f"Unsupported `tokenizer` type = {type(tokenizer)}") + + # Preprocess Image + pixel_values = image_transform(image) + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values[None, ...].to(self.device) + elif isinstance(pixel_values, dict): + pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()} + else: + raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") + + # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()` + autocast_dtype = self.llm_backbone.half_precision_dtype + with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training): + # fmt: off + generated_ids = super(PrismaticVLM, self).generate( + input_ids=input_ids, # Shape: [1, seq] + pixel_values=pixel_values, # Shape: [1, 3, res, res] or Dict[str, ...] 
+ max_new_tokens=self.get_action_dim(unnorm_key), + **kwargs + ) + # fmt: on + + # Extract predicted action tokens and translate into (normalized) continuous actions + predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :] + normalized_actions = self.action_tokenizer.decode_token_ids_to_actions(predicted_action_token_ids.cpu().numpy()) + + # Un-normalize Actions + action_norm_stats = self.get_action_stats(unnorm_key) + mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool)) + action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"]) + actions = np.where( + mask, + 0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low, + normalized_actions, + ) + + return actions + + @staticmethod + def _check_unnorm_key(norm_stats: Dict, unnorm_key: Optional[str]) -> str: + if unnorm_key is None: + assert len(norm_stats) == 1, ( + f"Your model was trained on more than one dataset, please pass a `unnorm_key` from the following " + f"options to choose the statistics used for un-normalizing actions: {norm_stats.keys()}" + ) + unnorm_key = next(iter(norm_stats.keys())) + + # Error Handling + assert ( + unnorm_key in norm_stats + ), f"The `unnorm_key` you chose is not in the set of available statistics; choose from: {norm_stats.keys()}" + + return unnorm_key + + def get_action_dim(self, unnorm_key: Optional[str] = None) -> int: + """Dimensionality of the policy's action space.""" + unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key) + + return len(self.norm_stats[unnorm_key]["action"]["q01"]) + + def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict: + """Action normalization statistics (e.g., `q01`, `q99`, optional `mask`) for the given dataset.""" + unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key) + + return self.norm_stats[unnorm_key]["action"] diff --git a/prismatic/models/vlms/__init__.py b/prismatic/models/vlms/__init__.py new file mode 100644 index
0000000000000000000000000000000000000000..5a31406e6f47bd5f3f75df046d8f64def2356a99 --- /dev/null +++ b/prismatic/models/vlms/__init__.py @@ -0,0 +1 @@ +from .prismatic import PrismaticVLM diff --git a/prismatic/models/vlms/base_vlm.py b/prismatic/models/vlms/base_vlm.py new file mode 100644 index 0000000000000000000000000000000000000000..3c1049501b1156c02acf116891df3296e1177c3f --- /dev/null +++ b/prismatic/models/vlms/base_vlm.py @@ -0,0 +1,108 @@ +""" +base_vlm.py + +Abstract class definition of a Vision-Language Model (VLM), with full annotations of class methods, utility functions, +and initialization logic. This is mostly to future-proof the codebase; while all our experiments instantiate +from PrismaticVLM, theoretically, this base class should be general enough to cover almost all models (e.g., IDEFICS, +PALI, Fuyu) in the future. + +We use Abstract base classes *sparingly* -- mostly as a way to encapsulate any redundant logic or nested inheritance +(e.g., dependence on nn.Module, HF PretrainedModel, etc.). For other abstract objects (e.g., Tokenizers/Transforms), +prefer Protocol definitions instead. 
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Callable, List, Optional + +import torch +import torch.nn as nn +from transformers import GenerationMixin, PretrainedConfig +from transformers.modeling_outputs import CausalLMOutputWithPast + +from prismatic.models.backbones.llm import LLMBackbone +from prismatic.models.backbones.llm.prompting import PromptBuilder +from prismatic.models.backbones.vision import VisionBackbone + + +# === Abstract Base Class for arbitrary Vision-Language Models === +class VLM(nn.Module, GenerationMixin, ABC): + def __init__( + self, + model_family: str, + model_id: str, + vision_backbone: VisionBackbone, + llm_backbone: LLMBackbone, + enable_mixed_precision_training: bool = True, + ) -> None: + super().__init__() + self.model_family, self.model_id = model_family, model_id + self.vision_backbone, self.llm_backbone = vision_backbone, llm_backbone + self.enable_mixed_precision_training = enable_mixed_precision_training + + # Instance Attributes for a generic VLM + self.all_module_keys, self.trainable_module_keys = None, None + + # === GenerationMixin Expected Attributes =>> *DO NOT MODIFY* === + self.generation_config = self.llm_backbone.llm.generation_config + self.main_input_name = "input_ids" + + @property + def device(self) -> torch.device: + """Borrowed from `transformers.modeling_utils.py` -- checks parameter device; assumes model on *ONE* device!""" + return next(self.parameters()).device + + @classmethod + @abstractmethod + def from_pretrained( + cls, + pretrained_checkpoint: Path, + model_family: str, + model_id: str, + vision_backbone: VisionBackbone, + llm_backbone: LLMBackbone, + **kwargs: str, + ) -> VLM: ... + + @abstractmethod + def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder: ... + + @abstractmethod + def freeze_backbones(self, stage: str) -> None: ... 
+ + @abstractmethod + def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None: ... + + @abstractmethod + def get_fsdp_wrapping_policy(self) -> Callable: ... + + @abstractmethod + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + multimodal_indices: Optional[torch.LongTensor] = None, + ) -> CausalLMOutputWithPast: ... + + # === GenerationMixin Expected Properties & Methods (DO NOT MODIFY) === + @staticmethod + def can_generate() -> bool: + return True + + @property + def config(self) -> PretrainedConfig: + return self.llm_backbone.llm.config + + # => Beam Search Utility + def _reorder_cache(self, past_key_values, beam_idx): + return self.llm_backbone.llm._reorder_cache(past_key_values, beam_idx) diff --git a/prismatic/models/vlms/prismatic.py b/prismatic/models/vlms/prismatic.py new file mode 100644 index 0000000000000000000000000000000000000000..f6d1a6aac97b724d9148dec818581a91790c285f --- /dev/null +++ b/prismatic/models/vlms/prismatic.py @@ -0,0 +1,621 @@ +""" +prismatic.py + +PyTorch Module defining a PrismaticVLM, our general interface for defining the various different VLMs in our work. + +Notes: + - For now, we don't subclass `transformers.PretrainedModel` (or CausalLM). Instead, we assume a very limited subset + of the {Model}ForCausalLM API that enables dispatch to the underlying LLM's `generate` utilities (feeding inputs + through our custom projection shim). 
+""" + +from __future__ import annotations + +from functools import partial +from pathlib import Path +from typing import Callable, Dict, List, Optional, Type, Union + +import torch +from PIL import Image +from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy +from transformers.modeling_outputs import CausalLMOutputWithPast + +from prismatic.models.backbones.llm import LLMBackbone +from prismatic.models.backbones.llm.prompting import PromptBuilder +from prismatic.models.backbones.vision import VisionBackbone +from prismatic.models.vlms.base_vlm import VLM +from prismatic.overwatch import initialize_overwatch +from prismatic.util.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) +IGNORE_INDEX = -100 + + +class PrismaticVLM(VLM): + def __init__( + self, + model_id: str, + vision_backbone: VisionBackbone, + llm_backbone: LLMBackbone, + enable_mixed_precision_training: bool = True, + arch_specifier: str = "gelu-mlp", + **kwargs, + ) -> None: + super().__init__( + "prismatic", + model_id, + vision_backbone, + llm_backbone, + enable_mixed_precision_training=enable_mixed_precision_training, + ) + + # Set Weight Initialization Seed for Projector Consistency + torch.manual_seed(vision_backbone.embed_dim) + + # Initialize Projection (Adapter) based on `arch_specifier` + self.arch_specifier = arch_specifier + if arch_specifier == "linear": + self.projector = LinearProjector(vision_backbone.embed_dim, llm_backbone.embed_dim) + elif arch_specifier.endswith("fused-gelu-mlp"): + self.projector = FusedMLPProjector(vision_backbone.embed_dim, llm_backbone.embed_dim) + elif arch_specifier.endswith("gelu-mlp"): + self.projector = MLPProjector(vision_backbone.embed_dim, llm_backbone.embed_dim) + else: + raise ValueError(f"PrismaticVLM with `{arch_specifier = }` is not supported!") + + # Trackers 
+ self.vision_backbone_requires_grad = False + + # Set Module Keys =>> used in Checkpoint Saving / Model Loading + self.all_module_keys = ["vision_backbone", "llm_backbone", "projector"] + self.trainable_module_keys = [] + + # === Generation Utilities === + # => For computing likelihoods --> get tokens corresponding to "True", "False" and "Yes", "No" + self.string2idx = {} + for trigger_string in ["True", "False", "Yes", "No"] + [chr(ord("A") + i) for i in range(26)]: + token_idx_list = self.llm_backbone.tokenizer.encode(trigger_string, add_special_tokens=False) + assert len(token_idx_list) == 1, f'String "{trigger_string}" is tokenized as more than one token!' + self.string2idx[trigger_string] = token_idx_list[0] + + @classmethod + def from_pretrained( + cls, + pretrained_checkpoint: Path, + model_id: str, + vision_backbone: VisionBackbone, + llm_backbone: LLMBackbone, + enable_mixed_precision_training: bool = True, + arch_specifier: str = "gelu-mlp", + freeze_weights: bool = True, + **kwargs, + ) -> PrismaticVLM: + """Initialize a PrismaticVLM from a pretrained checkpoint, freezing all weights, tailored for inference.""" + vlm = cls( + model_id, + vision_backbone, + llm_backbone, + enable_mixed_precision_training=enable_mixed_precision_training, + arch_specifier=arch_specifier, + **kwargs, + ) + + # Load from Checkpoint (Custom --> should load both *projector* and *llm* weights) + model_state_dict = torch.load(pretrained_checkpoint, map_location="cpu")["model"] + assert ( + "projector" in model_state_dict and "llm_backbone" in model_state_dict + ), "PrismaticVLM `from_pretrained` expects checkpoint with keys for `projector` AND `llm_backbone`!" 
+ + vlm.projector.load_state_dict(model_state_dict["projector"]) + vlm.llm_backbone.load_state_dict(model_state_dict["llm_backbone"]) + if "vision_backbone" in model_state_dict.keys(): + vlm.vision_backbone.load_state_dict(model_state_dict["vision_backbone"]) + + # Freeze Weights + if freeze_weights: + vlm.requires_grad_(False) + vlm.eval() + + return vlm + + def get_prompt_builder(self, system_prompt: Optional[str] = None) -> PromptBuilder: + prompt_initializer: Type[PromptBuilder] = self.llm_backbone.prompt_builder_fn + return prompt_initializer(self.model_family, system_prompt=system_prompt) + + def freeze_backbones(self, stage: str) -> None: + """ + This function sets `requires_grad_` on each of the component modules explicitly, depending on stage. + + We support two separate stages --> "align" and "finetune". + => "align" --> vision_backbone*, llm_backbone* are frozen; only the `projector` is trained. + => "finetune" --> vision_backbone* is frozen; both `projector` and `llm_backbone` are trained. 
+ + :param stage: Pretraining stage in < "align" | "finetune" | "full-finetune" | "vla-train" | "vla-full-train" > + """ + if stage == "align": + self.vision_backbone.requires_grad_(False) + self.llm_backbone.requires_grad_(False) + self.projector.requires_grad_(True) + + # Add to `self.trainable_module_keys` + self.trainable_module_keys = ["projector"] + + # Update Trackers + self.vision_backbone_requires_grad = False + + # Explicitly Log Frozen / Trainable Components + overwatch.info(f"[Frozen] 🥶 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) + overwatch.info(f"[Frozen] 🥶 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) + overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1) + + elif stage in {"finetune", "vla-train"}: + self.vision_backbone.requires_grad_(False) + self.llm_backbone.requires_grad_(True) + self.projector.requires_grad_(True) + + # Add to `self.trainable_module_keys` + self.trainable_module_keys = ["projector", "llm_backbone"] + + # Update Trackers + self.vision_backbone_requires_grad = False + + # Explicitly Log Frozen / Unfrozen Components + overwatch.info(f"[Frozen] 🥶 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) + overwatch.info(f"[TRAINABLE] 🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) + overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1) + + elif stage in {"full-finetune", "vla-full-train"}: + self.vision_backbone.dtype = torch.float32 + self.vision_backbone.requires_grad_(True) + self.llm_backbone.requires_grad_(True) + self.projector.requires_grad_(True) + + # Add to `self.trainable_module_keys` + self.trainable_module_keys = ["vision_backbone", "projector", "llm_backbone"] + + # Update Trackers + self.vision_backbone_requires_grad = True + + # Explicitly Log Frozen / Unfrozen Components + overwatch.info(f"[TRAINABLE] 🔥 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) + 
overwatch.info(f"[TRAINABLE] 🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) + overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1) + + elif stage in {"last-layer-finetune", "vla-last-layer-train"}: + self.vision_backbone.requires_grad_(False) + self.projector.requires_grad_(False) + self.llm_backbone.requires_grad_(False) + + # Unfreeze final LLM layer + for module in self.llm_backbone.last_layer_finetune_modules: + module.requires_grad_(True) + + # Add to `self.trainable_module_keys` + self.trainable_module_keys = ["llm_backbone"] + + # Update Trackers + self.vision_backbone_requires_grad = False + + # Explicitly Log Frozen / Unfrozen Components + # fmt: off + overwatch.info(f"[Frozen] 🥶 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) # noqa: E501 + overwatch.info(f"[Frozen, except last layer] 🥶🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) # noqa: E501 + overwatch.info(f"[Frozen] 🥶 =>> Projector `{self.arch_specifier}`", ctx_level=1) + # fmt: on + + elif stage in {"vla-sandwich-train"}: + self.vision_backbone.dtype = torch.float32 + self.vision_backbone.requires_grad_(True) + self.projector.requires_grad_(True) + self.llm_backbone.requires_grad_(False) + + # Unfreeze final LLM layer + for module in self.llm_backbone.last_layer_finetune_modules: + module.requires_grad_(True) + + # Add to `self.trainable_module_keys` + self.trainable_module_keys = ["vision_backbone", "projector", "llm_backbone"] + + # Update Trackers + self.vision_backbone_requires_grad = True + + # Explicitly Log Frozen / Unfrozen Components + # fmt: off + overwatch.info(f"[TRAINABLE] 🔥 =>> Vision Backbone `{self.vision_backbone.identifier}`", ctx_level=1) # noqa: E501 + overwatch.info(f"[Frozen, except last layer] 🥶🔥 =>> LLM Backbone `{self.llm_backbone.identifier}`", ctx_level=1) # noqa: E501 + overwatch.info(f"[TRAINABLE] 🔥 =>> Projector `{self.arch_specifier}`", ctx_level=1) + # fmt: on + + else: + 
raise ValueError(f"Stage `{stage}` is not supported! See `freeze_backbones()` for the valid stages.") + + overwatch.debug("##################################################") + overwatch.debug("##### Trainable Network Parameters: #####") + overwatch.debug("##################################################") + for name, param in self.named_parameters(): + if param.requires_grad: + overwatch.debug(name) + + def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None: + """Load weights from checkpoint (if required by the given stage).""" + assert stage in {"align", "finetune", "full-finetune"}, f"Stage {stage} is not supported!" + + # If we're running a `no-align` architecture, we're good! + if self.arch_specifier.startswith("no-align"): + overwatch.info( + f"PrismaticVLM with `{self.arch_specifier = }` does not require pretrained weights!", ctx_level=1 + ) + return + + # Otherwise, handle stage-specific logic! + if stage == "align": + overwatch.info("Stage `align` does not require pretrained weights =>> Starting Training", ctx_level=1) + return + + # Otherwise, load from `pretrained_checkpoint` or match on `run_dir` (s/+stage-finetune/+stage-align/g) + overwatch.info("Stage `finetune` requires `align` pretrained weights", ctx_level=1) + + # Config specifies path to a checkpoint to load + if pretrained_checkpoint is not None: + overwatch.info(f"Loading from Provided Checkpoint `{pretrained_checkpoint}`", ctx_level=1) + model_state_dict = torch.load(pretrained_checkpoint)["model"] + self.projector.load_state_dict(model_state_dict["projector"]) + + return + + # [Contract] If no `pretrained_checkpoint`, assume `align` lives in the run directory; string substitution!
+ model, scale, _, seed = run_dir.name.split("+") + align_dirs = [ + d + for d in run_dir.parent.iterdir() + if (d.name.startswith(f"{model}+{scale}") and d.name.endswith(f"+stage-align+{seed}")) + ] + assert len(align_dirs) == 1, "Multiple or No Valid Pretrained Directories Exist -- Double Check `runs`!" + if (pretrained_checkpoint := (align_dirs[0] / "checkpoints" / "latest-checkpoint.pt")).exists(): + overwatch.info(f"Loading from Discovered Checkpoint `{pretrained_checkpoint}`", ctx_level=1) + model_state_dict = torch.load(pretrained_checkpoint)["model"] + self.projector.load_state_dict(model_state_dict["projector"]) + else: + raise ValueError(f"Could not find valid `align` checkpoint at {pretrained_checkpoint}!") + + def get_fsdp_wrapping_policy(self) -> Callable: + """Return an FSDP _or_policy over the policies returned by each individual backbone (and our VLM policy).""" + vision_fsdp_wrapping_policy = self.vision_backbone.get_fsdp_wrapping_policy() + llm_fsdp_wrapping_policy = self.llm_backbone.get_fsdp_wrapping_policy() + + # Get Prismatic Wrapping Policy =>> just a module wrapping policy around `self.projector` + prismatic_fsdp_wrapping_policy = partial( + _module_wrap_policy, + module_classes={LinearProjector, MLPProjector, FusedMLPProjector}, + ) + + # Return union (_or_) over constituent policies + # => Note: there is *not* a fall-through policy; any module that isn't covered by the above constituents will + # automatically be folded into the root VLM FSDP instance. 
+ return partial( + _or_policy, + policies=[ + vision_fsdp_wrapping_policy, + llm_fsdp_wrapping_policy, + prismatic_fsdp_wrapping_policy, + ], + ) + + # Note =>> We're not explicitly subclassing `PreTrainedModel` because we don't need the bloat; however, `forward()` + # *must* match the signature of a `{Model}ForCausalLM` so that we can inherit from `GenerationMixin` + + # ruff: noqa: C901 + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + multimodal_indices: Optional[torch.LongTensor] = None, + ) -> CausalLMOutputWithPast: + """Run a forward pass through the VLM, returning a CausalLMOutputWithPast instance (contains loss).""" + + # Handle Inference (leverage cache, short-circuit on just LLM forward) + if input_ids.shape[1] == 1 and past_key_values is not None: + # We're leveraging the cache, so just redirect to `self.llm_backbone` with `input_ids` and `past_key_values` + output = self.llm_backbone( + input_ids=input_ids, + attention_mask=None, + position_ids=None, + past_key_values=past_key_values, + inputs_embeds=None, + labels=None, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return output + + elif input_ids.shape[1] == 1 or pixel_values is None: + raise RuntimeError("Invalid `forward()` call!") + + # Handle Multimodal Indices is None --> pretend like the batch is fully multimodal (always image + text)! 
+ if multimodal_indices is None: + multimodal_indices = torch.arange(len(input_ids), dtype=torch.long, device=input_ids.device) + + # Handle Multimodal Indices is Empty (len == 0) --> simple unimodal forward + elif len(multimodal_indices) == 0: + return self.llm_backbone( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=None, + past_key_values=past_key_values, + inputs_embeds=None, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Run Visual Feature Extraction + with torch.set_grad_enabled(self.vision_backbone_requires_grad): + if isinstance(pixel_values, dict): + patch_features = self.vision_backbone({k: pixel_values[k][multimodal_indices] for k in pixel_values}) + else: + patch_features = self.vision_backbone(pixel_values[multimodal_indices]) + + # Projection Logic :: [bsz, num_patches, llm_embed_dim] =>> num_patches = (2 *) (256 + 1) for ViT-L + CLS + projected_patch_embeddings = self.projector(patch_features) + projected_patch_attention_mask = None + if attention_mask is not None: + projected_patch_attention_mask = torch.full( + (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]), + True, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + # Get Input Embeddings from LLM Backbone :: [bsz, input_seq_len, llm_embed_dim] + input_embeddings = self.llm_backbone.embed_input_ids(input_ids) + + # Build Multimodal Embeddings (and build resulting attention mask) + multimodal_embeddings = torch.cat( + [ + input_embeddings[multimodal_indices, :1, :], + projected_patch_embeddings, + input_embeddings[multimodal_indices, 1:, :], + ], + dim=1, + ) + multimodal_attention_mask = None + if attention_mask is not None: + multimodal_attention_mask = torch.cat( + [ + attention_mask[multimodal_indices, :1], + projected_patch_attention_mask, + attention_mask[multimodal_indices, 1:], + ], + dim=1, + ) + + # 
[Contract] We assume the first token of `labels` (associated with <BOS>) is already marked as "IGNORE" + # => We'll ignore the per-token outputs for each of the patch embeddings as well! + multimodal_labels = None + if labels is not None: + projected_patch_labels = torch.full( + (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]), + IGNORE_INDEX, + dtype=labels.dtype, + device=labels.device, + ) + multimodal_labels = torch.cat( + [labels[multimodal_indices, :1], projected_patch_labels, labels[multimodal_indices, 1:]], dim=1 + ) + + # === Add Unimodal Handling === + + # Create Fused Embeddings, Attention Mask, and Labels by Merging with "unimodal" Inputs (if applicable) + unimodal_indices = torch.tensor( + [idx for idx in range(len(input_ids)) if idx not in multimodal_indices], + dtype=torch.long, + device=multimodal_indices.device, + ) + + # No "unimodal" data --> Fused == Multimodal + if len(unimodal_indices) == 0: + fused_embeddings = multimodal_embeddings + fused_attention_mask = multimodal_attention_mask + fused_labels = multimodal_labels + + else: + # Otherwise --> Merge w/ unimodal data + + # This doesn't matter --> but in the "normal" case this is the embedding of the <PAD> token + # => NOTE :: Verified that `zeros/randn/empty/<PAD> embedding` all return the same result!
+ unimodal_embeddings_pad = torch.zeros( + (len(unimodal_indices), projected_patch_embeddings.shape[1], input_embeddings.shape[2]), + dtype=input_embeddings.dtype, + device=input_embeddings.device, + ) + unimodal_attention_pad = torch.full( + (len(unimodal_indices), projected_patch_embeddings.shape[1]), + False, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + unimodal_labels_pad = torch.full( + (len(unimodal_indices), projected_patch_embeddings.shape[1]), + IGNORE_INDEX, + dtype=labels.dtype, + device=labels.device, + ) + + unimodal_embeddings = torch.cat([input_embeddings[unimodal_indices], unimodal_embeddings_pad], dim=1) + unimodal_attention_mask = torch.cat([attention_mask[unimodal_indices], unimodal_attention_pad], dim=1) + unimodal_labels = torch.cat([labels[unimodal_indices], unimodal_labels_pad], dim=1) + + # Create "Fused" Tensors by Stacking Multimodal & Unimodal + fused_embeddings = torch.vstack([multimodal_embeddings, unimodal_embeddings]) + fused_attention_mask = torch.vstack([multimodal_attention_mask, unimodal_attention_mask]) + fused_labels = torch.vstack([multimodal_labels, unimodal_labels]) + + # Run LLM Forward --> returns CausalLMOutputWithPast! 
+ return self.llm_backbone( + input_ids=None, + attention_mask=fused_attention_mask, + position_ids=None, + past_key_values=past_key_values, + inputs_embeds=fused_embeddings, + labels=fused_labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # === GenerationMixin Methods === + # => Note: The following methods override the functionality of `transformers.GenerationMixin`; these expect the + # contract in each of the function signatures, and also expect our `forward` function to roughly take + # the same arguments as the underlying LLM (see `LlamaModelForCausalLM` as an example) + + def prepare_inputs_for_generation( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + **kwargs: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Borrowed from `LlamaForCausalLM` --> in general, just handles caching logic during generation.""" + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + # Make sure `pixel_values` are preserved in `model_inputs` + model_inputs.update( + { + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "past_key_values": past_key_values, + "use_cache": use_cache, + } + ) + + return model_inputs + + @torch.inference_mode() + def generate_batch( + self, + pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]], + texts: List[str], + return_string_probabilities: Optional[List[str]] = None, + **kwargs: str, + ) -> Union[List[str], 
List[List[float]]]: + # For now, only support generation with a batch size of 1 for simplicity + tokenizer = self.llm_backbone.tokenizer + + # Prepare Inputs + batch_input_ids = [ + tokenizer(text, truncation=True, return_tensors="pt").input_ids.to(self.device) for text in texts + ] + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values[None, ...].to(self.device) + elif isinstance(pixel_values, dict): + pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()} + else: + raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") + + # Create Output Lists + gen_texts, gen_probabilities = [], [] + + # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()` + autocast_dtype = self.llm_backbone.half_precision_dtype + with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training): + for idx, input_ids in enumerate(batch_input_ids): + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values[idx] + elif isinstance(pixel_values, dict): + pixel_values = {k: pixel_values[k][idx] for k in pixel_values} + else: + raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") + + # Handle `return_string_probabilities` + if return_string_probabilities is None: + full_out_ids = super().generate(input_ids=input_ids, pixel_values=pixel_values, **kwargs) + gen_ids = full_out_ids[0, input_ids.shape[1] :] + + # Decode `gen_ids` and strip any tokens + gen_texts.append(tokenizer.decode(gen_ids, skip_special_tokens=True).strip()) + + else: + full_out_dict = super().generate( + input_ids=input_ids, + pixel_values=pixel_values, + output_scores=True, + return_dict_in_generate=True, + **kwargs, + ) + + # Generation pattern should usually be [TOKEN] for True/False and Yes/No Generations + gen_ids = full_out_dict.sequences[0, input_ids.shape[1] :] + + # [Debug] Verify that the first token generated is in `self.string2idx.values()` + # assert 
gen_ids[0] in self.string2idx.values(), "Generated ID not in mapping!" + + # Decode `gen_ids` and strip any <EOS> tokens + gen_texts.append(tokenizer.decode(gen_ids, skip_special_tokens=True).strip()) + + # Get all token probabilities --> softmax over logits + token_probs = torch.softmax(full_out_dict.scores[0][0], dim=0) + + # Get *normalized* probabilities for all values in `return_string_probabilities` + slice_idxs = torch.tensor([self.string2idx[s] for s in return_string_probabilities]) + string_probs_unnormalized = token_probs[slice_idxs] + string_probs = string_probs_unnormalized / string_probs_unnormalized.sum() + gen_probabilities.append(string_probs.cpu().numpy().tolist()) + + return gen_texts if return_string_probabilities is None else gen_probabilities + + @torch.inference_mode() + def generate(self, image: Image, prompt_text: str, **kwargs: str) -> str: + # For now, only support generation with a batch size of 1 for simplicity + image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer + + # Prepare Inputs + input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device) + pixel_values = image_transform(image) + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values[None, ...].to(self.device) + elif isinstance(pixel_values, dict): + pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()} + else: + raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") + + # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()` + autocast_dtype = self.llm_backbone.half_precision_dtype + with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training): + # fmt: off + generated_ids = super().generate( + input_ids=input_ids, # Shape: [1, seq] + pixel_values=pixel_values, # Shape: [1, 3, res, res] or Dict[str, Shape[1, 3, res, res]] + **kwargs + ) + # fmt: on + + generated_text =
tokenizer.decode(generated_ids[0, input_ids.shape[1] :], skip_special_tokens=True).strip() + + return generated_text diff --git a/prismatic/overwatch/__init__.py b/prismatic/overwatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6897a047fc2741f7e434bcdaa78f6a14c473fec9 --- /dev/null +++ b/prismatic/overwatch/__init__.py @@ -0,0 +1 @@ +from .overwatch import initialize_overwatch diff --git a/prismatic/overwatch/overwatch.py b/prismatic/overwatch/overwatch.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c40e65a695cc9287e1bcb6fef062904df5aace --- /dev/null +++ b/prismatic/overwatch/overwatch.py @@ -0,0 +1,147 @@ +""" +overwatch.py + +Utility class for creating a centralized/standardized logger (built on Rich) and accelerate handler. +""" + +import logging +import logging.config +import os +from contextlib import nullcontext +from logging import LoggerAdapter +from typing import Any, Callable, ClassVar, Dict, MutableMapping, Tuple, Union + +# Overwatch Default Format String +RICH_FORMATTER, DATEFMT = "| >> %(message)s", "%m/%d [%H:%M:%S]" + +# Set Logging Configuration +LOG_CONFIG = { + "version": 1, + "disable_existing_loggers": True, + "formatters": {"simple-console": {"format": RICH_FORMATTER, "datefmt": DATEFMT}}, + "handlers": { + "console": { + "class": "rich.logging.RichHandler", + "formatter": "simple-console", + "markup": True, + "rich_tracebacks": True, + "show_level": True, + "show_path": True, + "show_time": True, + } + }, + "root": {"level": "INFO", "handlers": ["console"]}, +} +logging.config.dictConfig(LOG_CONFIG) + + +# === Custom Contextual Logging Logic === +class ContextAdapter(LoggerAdapter): + CTX_PREFIXES: ClassVar[Dict[int, str]] = {**{0: "[*] "}, **{idx: "|=> ".rjust(4 + (idx * 4)) for idx in [1, 2, 3]}} + + def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Tuple[str, MutableMapping[str, Any]]: + ctx_level = kwargs.pop("ctx_level", 0) + return 
f"{self.CTX_PREFIXES[ctx_level]}{msg}", kwargs + + +class DistributedOverwatch: + def __init__(self, name: str) -> None: + """Initializer for an Overwatch object that wraps logging & `accelerate.PartialState`.""" + from accelerate import PartialState + + # Note that PartialState is always safe to initialize regardless of `accelerate launch` or `torchrun` + # =>> However, might be worth actually figuring out if we need the `accelerate` dependency at all! + self.logger, self.distributed_state = ContextAdapter(logging.getLogger(name), extra={}), PartialState() + + # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) + self.debug = self.logger.debug + self.info = self.logger.info + self.warning = self.logger.warning + self.error = self.logger.error + self.critical = self.logger.critical + + # Logging Defaults =>> only Log `INFO` on Main Process, `ERROR` on others! + self.logger.setLevel(logging.INFO if self.distributed_state.is_main_process else logging.ERROR) + + @property + def rank_zero_only(self) -> Callable[..., Any]: + return self.distributed_state.on_main_process + + @property + def local_zero_only(self) -> Callable[..., Any]: + return self.distributed_state.on_local_main_process + + @property + def rank_zero_first(self) -> Callable[..., Any]: + return self.distributed_state.main_process_first + + @property + def local_zero_first(self) -> Callable[..., Any]: + return self.distributed_state.local_main_process_first + + def is_rank_zero(self) -> bool: + return self.distributed_state.is_main_process + + def rank(self) -> int: + return self.distributed_state.process_index + + def local_rank(self) -> int: + return self.distributed_state.local_process_index + + def world_size(self) -> int: + return self.distributed_state.num_processes + + +class PureOverwatch: + def __init__(self, name: str) -> None: + """Initializer for an Overwatch object that just wraps logging.""" + self.logger = ContextAdapter(logging.getLogger(name), 
extra={}) + + # Logger Delegation (for convenience; would be nice to just compose & dynamic dispatch eventually) + self.debug = self.logger.debug + self.info = self.logger.info + self.warning = self.logger.warning + self.error = self.logger.error + self.critical = self.logger.critical + + # Logging Defaults =>> INFO + self.logger.setLevel(logging.INFO) + + @staticmethod + def get_identity_ctx() -> Callable[..., Any]: + def identity(fn: Callable[..., Any]) -> Callable[..., Any]: + return fn + + return identity + + @property + def rank_zero_only(self) -> Callable[..., Any]: + return self.get_identity_ctx() + + @property + def local_zero_only(self) -> Callable[..., Any]: + return self.get_identity_ctx() + + @property + def rank_zero_first(self) -> Callable[..., Any]: + return nullcontext + + @property + def local_zero_first(self) -> Callable[..., Any]: + return nullcontext + + @staticmethod + def is_rank_zero() -> bool: + return True + + @staticmethod + def rank() -> int: + return 0 + + @staticmethod + def world_size() -> int: + return 1 + + +def initialize_overwatch(name: str) -> Union[DistributedOverwatch, PureOverwatch]: + return DistributedOverwatch(name) if int(os.environ.get("WORLD_SIZE", -1)) != -1 else PureOverwatch(name) diff --git a/prismatic/preprocessing/__init__.py b/prismatic/preprocessing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b62598ef246df852419c118a3dc40a6ebddf4bd6 --- /dev/null +++ b/prismatic/preprocessing/__init__.py @@ -0,0 +1,2 @@ +from .download import convert_to_jpg, download_extract +from .materialize import get_dataset_and_collator diff --git a/prismatic/preprocessing/datasets/__init__.py b/prismatic/preprocessing/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a642948d2d042def8edd1848053ec7846fd0009 --- /dev/null +++ b/prismatic/preprocessing/datasets/__init__.py @@ -0,0 +1 @@ +from .datasets import AlignDataset, FinetuneDataset diff --git 
a/prismatic/preprocessing/datasets/datasets.py b/prismatic/preprocessing/datasets/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..35f866eda36c17e95df861063b2a41f171b68e1a --- /dev/null +++ b/prismatic/preprocessing/datasets/datasets.py @@ -0,0 +1,200 @@ +""" +datasets.py + +PyTorch Dataset Definitions for Prismatic models; supports processing for both the `align` and `finetune` stages, with +utilities for formatting conversations during the `finetune` stage subject to the given LLM backbone's expected +formatting (e.g., SYS_PROMPT + USER: ... ASSISTANT: ... for Vicuña v1.5 Chat models). + +We currently only support Map-style Datasets; assumes that all files (annotations, images) are on local disk, and that +random access image reading is relatively cheap/fast. +""" + +import copy +import json +from pathlib import Path +from typing import Dict, List, Tuple, Type + +import torch +from PIL import Image +from torch.utils.data import Dataset +from transformers import CodeGenTokenizerFast, LlamaTokenizerFast, PreTrainedTokenizerBase + +from prismatic.models.backbones.llm.prompting import PromptBuilder +from prismatic.models.backbones.vision import ImageTransform + +# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) +IGNORE_INDEX = -100 + + +class AlignDataset(Dataset[Dict[str, torch.Tensor]]): + def __init__( + self, + chat_json: Path, + image_dir: Path, + image_transform: ImageTransform, + tokenizer: PreTrainedTokenizerBase, + ) -> None: + super().__init__() + self.chat_json, self.image_dir = chat_json, image_dir + self.image_transform, self.tokenizer = image_transform, tokenizer + self.dataset_type = "align" + + # Create Prompt Template + self.prompt_template = "{caption}" + self.tokenizer.eos_token + + # Load Chat JSON + with open(self.chat_json, "r") as f: + self.examples = json.load(f) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """ + Following the *actual* code executed from the LLaVa codebase, during 
the "align" phase, we actually discard + the "prompt" from the human, and instead directly predict the caption from the image. + + As a concrete example given the "raw data" for the first example: + `example = self.examples[0]["conversations"]` = { + [ + {"from": "human", "value": "Render a clear and concise summary of the photo.\n"}, + {"from": "gpt", "value": "select luxury furniture 3 - inch gel memory foam mattress topper"} + ] + } + + Return =>> self.tokenizer(" select luxury furniture 3 - inch gel memory foam mattress topper\n") + + :param idx: Index to retrieve from the dataset. + + :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor} + """ + image_path, conversation = Path(self.examples[idx]["image"]), self.examples[idx]["conversations"] + assert (len(conversation) == 2) and ("<image>" not in conversation[-1]["value"]), "Unexpected text!" + + # Format Caption --> {caption}{eos_token} + caption = self.prompt_template.format(caption=conversation[-1]["value"].strip()) + + # We treat image patches as "tokens = [p1 p2 p3, ...]"; we need to specify ordering of text/patch tokens. + # => Critically, we find that inserting *after* the BOS token leads to the strongest performance! + # - input_ids = " p1 p2 p3 ... \n" + # - labels = "IGNORE IGNORE ..." (copy `input_ids`, replacing the BOS token and p{1...K} with IGNORE) + # + # IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL! 
+ input_ids = self.tokenizer(caption, truncation=True, return_tensors="pt").input_ids[0] + labels = copy.deepcopy(input_ids) + + # Set the BOS token's label to IGNORE_INDEX (since we're inserting the image patches right after) + labels[0] = IGNORE_INDEX + + # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor]) + pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB")) + + return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels) + + def get_modality_lengths(self, n_image_patches: int) -> List[Tuple[bool, int]]: + """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example.""" + modality_lengths = [] + for example in self.examples: + is_multimodal = "image" in example + n_words = sum([len(turn["value"].replace("<image>", "").split()) for turn in example["conversations"]]) + modality_lengths.append((is_multimodal, (n_image_patches + n_words) if is_multimodal else n_words)) + return modality_lengths + + def __len__(self) -> int: + return len(self.examples) + + +class FinetuneDataset(Dataset[Dict[str, torch.Tensor]]): + def __init__( + self, + instruct_json: Path, + image_dir: Path, + image_transform: ImageTransform, + tokenizer: PreTrainedTokenizerBase, + prompt_builder_fn: Type[PromptBuilder], + ) -> None: + super().__init__() + self.instruct_json, self.image_dir = instruct_json, image_dir + self.image_transform, self.tokenizer = image_transform, tokenizer + self.prompt_builder_fn = prompt_builder_fn + self.dataset_type = "finetune" + + # Load Instruct JSON + with open(self.instruct_json, "r") as f: + self.examples = json.load(f) + + # === Unimodal + Multimodal Handling === + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """ + Unlike the *align* stage handling, for the *finetune* stage, we actually need to handle multiple "turns" of + dialog grounded in a single image. 
+ + To do this, we leverage the `prompt_builder_fn` which instantiates a PromptBuilder object. By calling the + methods for adding turns and getting a prompt, we ensure proper formatting and consistency for each example. + + :param idx: Index to retrieve from the dataset. + + :return: Dictionary of {"pixel_values": torch.Tensor, "input_ids": torch.Tensor, "labels": torch.Tensor} + """ + conversation = self.examples[idx]["conversations"] + + # Create Prompt Builder --> add each message sequentially + prompt_builder, input_ids, labels = self.prompt_builder_fn(model_family="prismatic"), [], [] + for turn_idx, turn in enumerate(conversation): + # Get "effective" string added to prompt --> handle whitespace for tokenizer type! + msg = prompt_builder.add_turn(turn["from"], turn["value"]) + + # Llama Tokenizer (Fast) adds extra character if a string ends in whitespace --> strip if non-empty! + if isinstance(self.tokenizer, LlamaTokenizerFast): + msg = msg.rstrip() + + # Phi-2 Tokenizer == CodeGenTokenizer (Fast) -- no special handling! + elif isinstance(self.tokenizer, CodeGenTokenizerFast): + pass + + else: + raise ValueError(f"Tokenizer of type `{type(self.tokenizer)}` is not explicitly handled!") + + # Tokenize Input IDs + turn_input_ids = self.tokenizer(msg, add_special_tokens=turn_idx == 0).input_ids + + # [CRITICAL] We do not want to take the loss for the "USER: " prompts =>> just the responses! + turn_labels = ( + [IGNORE_INDEX for _ in range(len(turn_input_ids))] if (turn_idx % 2) == 0 else list(turn_input_ids) + ) + + # Add to Trackers + input_ids.extend(turn_input_ids) + labels.extend(turn_labels) + + # Tensorize =>> Set the token's label to IGNORE_INDEX (since we're inserting the image patches after) + # - IMPORTANT => IF WE'RE USING HF LLM.forward(... labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL! 
+ input_ids, labels = torch.tensor(input_ids), torch.tensor(labels) + + # Handle Truncation (if necessary) + input_ids, labels = input_ids[: self.tokenizer.model_max_length], labels[: self.tokenizer.model_max_length] + + # === Handle "unimodal" (language-only) vs. "multimodal" === + if "image" in self.examples[idx]: + image_path = Path(self.examples[idx]["image"]) + + # Set the token's label to IGNORE_INDEX (since we're inserting the image patches right after) + labels[0] = IGNORE_INDEX + + # Process Image --> get "pixel_values" (will either be a torch.Tensor OR a Dict[str,torch.Tensor]) + pixel_values = self.image_transform(Image.open(self.image_dir / image_path).convert("RGB")) + + return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels) + + else: + # No image --> return `pixel_values` = None; Collator will do the smart batch handling for us! + return dict(pixel_values=None, input_ids=input_ids, labels=labels) + + def get_modality_lengths(self) -> List[Tuple[bool, int]]: + """Get a list of modalities (unimodal / text-only vs. multimodal) and length of conversations per example.""" + modality_lengths = [] + for example in self.examples: + is_multimodal = "image" in example + n_words = sum([len(turn["value"].split()) for turn in example["conversations"]]) + modality_lengths.append((is_multimodal, n_words)) + return modality_lengths + + def __len__(self) -> int: + return len(self.examples) diff --git a/prismatic/preprocessing/download.py b/prismatic/preprocessing/download.py new file mode 100644 index 0000000000000000000000000000000000000000..cff294489e8465471be3da3a07bb4000bf4b7a63 --- /dev/null +++ b/prismatic/preprocessing/download.py @@ -0,0 +1,207 @@ +""" +download.py + +Utility functions for downloading and extracting various datasets to (local) disk. 
+""" + +import os +import shutil +from pathlib import Path +from typing import Dict, List, TypedDict +from zipfile import ZipFile + +import requests +from PIL import Image +from rich.progress import BarColumn, DownloadColumn, MofNCompleteColumn, Progress, TextColumn, TransferSpeedColumn +from tqdm import tqdm + +from prismatic.overwatch import initialize_overwatch + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +# === Dataset Registry w/ Links === +# fmt: off +DatasetComponent = TypedDict( + "DatasetComponent", + {"name": str, "extract": bool, "extract_type": str, "url": str, "do_rename": bool}, + total=False +) + +DATASET_REGISTRY: Dict[str, List[DatasetComponent]] = { + # === LLaVa v1.5 Dataset(s) === + + # Note =>> This is the full suite of datasets included in the LLaVa 1.5 "finetuning" stage; all the LLaVa v1.5 + # models are finetuned on this split. We use this dataset for all experiments in our paper. + "llava-laion-cc-sbu-558k": [ + { + "name": "chat.json", # Contains the "chat" traces :: {"human" => , "gpt" => } + "extract": False, + "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/blip_laion_cc_sbu_558k.json", + "do_rename": True, + }, + { + "name": "images", # Contains the LLaVa Processed Images (jpgs, 224x224 resolution) + "extract": True, + "extract_type": "directory", + "url": "https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/resolve/main/images.zip", + "do_rename": False, + } + ], + + "llava-v1.5-instruct": [ + { + "name": "llava_v1_5_mix665k.json", + "extract": False, + "url": ( + "https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/llava_v1_5_mix665k.json" + ), + "do_rename": True, + }, + { + "name": "coco/train2017", # Visual Instruct Tuning images are all sourced from COCO Train 2017 + "extract": True, + "extract_type": "directory", + "url": "http://images.cocodataset.org/zips/train2017.zip", + "do_rename": True, + }, + { + "name": 
"gqa/images", + "extract": True, + "extract_type": "directory", + "url": "https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip", + "do_rename": True, + }, + { + "name": "ocr_vqa/images", + "extract": True, + "extract_type": "directory", + "url": "https://huggingface.co/datasets/qnguyen3/ocr_vqa/resolve/main/ocr_vqa.zip", + "do_rename": True, + }, + { + "name": "textvqa/train_images", + "extract": True, + "extract_type": "directory", + "url": "https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip", + "do_rename": True, + }, + { + "name": "vg/VG_100K", + "extract": True, + "extract_type": "directory", + "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip", + "do_rename": True, + }, + { + "name": "vg/VG_100K_2", + "extract": True, + "extract_type": "directory", + "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip", + "do_rename": True, + }, + ] +} +# fmt: on + + +def convert_to_jpg(image_dir: Path) -> None: + """Handling for OCR-VQA Images specifically; iterates through directory, converts all GIFs/PNGs.""" + overwatch.info(f"Converting all Images in `{image_dir}` to JPG") + + for image_fn in tqdm(list(image_dir.iterdir())): + if image_fn.suffix in {".jpg", ".jpeg"} or (jpg_fn := image_dir / f"{image_fn.stem}.jpg").exists(): + continue + + if image_fn.suffix == ".gif": + gif = Image.open(image_fn) + gif.seek(0) + gif.convert("RGB").save(jpg_fn) + elif image_fn.suffix == ".png": + Image.open(image_fn).convert("RGB").save(jpg_fn) + else: + raise ValueError(f"Unexpected image format `{image_fn.suffix}`") + + +def download_with_progress(url: str, download_dir: Path, chunk_size_bytes: int = 1024) -> Path: + """Utility function for downloading files from the internet, with a handy Rich-based progress bar.""" + overwatch.info(f"Downloading {(dest_path := download_dir / Path(url).name)} from `{url}`", ctx_level=1) + if dest_path.exists(): + return dest_path + + # Otherwise --> fire an HTTP Request, with `stream = True` + 
response = requests.get(url, stream=True) + + # Download w/ Transfer-Aware Progress + # => Reference: https://github.com/Textualize/rich/blob/master/examples/downloader.py + with Progress( + TextColumn("[bold]{task.description} - {task.fields[fname]}"), + BarColumn(bar_width=None), + "[progress.percentage]{task.percentage:>3.1f}%", + "•", + DownloadColumn(), + "•", + TransferSpeedColumn(), + transient=True, + ) as dl_progress: + # Fall back to 0 if the server omits `Content-Length` (the old "None" default made `int()` raise ValueError) + dl_tid = dl_progress.add_task( + "Downloading", fname=dest_path.name, total=int(response.headers.get("content-length", 0)) + ) + with open(dest_path, "wb") as f: + for data in response.iter_content(chunk_size=chunk_size_bytes): + dl_progress.advance(dl_tid, f.write(data)) + + return dest_path + + +def extract_with_progress(archive_path: Path, download_dir: Path, extract_type: str, cleanup: bool = False) -> Path: + """Utility function for extracting compressed archives, with a handy Rich-based progress bar.""" + assert archive_path.suffix == ".zip", "Only `.zip` compressed archives are supported for now!" + overwatch.info(f"Extracting {archive_path.name} to `{download_dir}`", ctx_level=1) + + # Extract w/ Progress + with Progress( + TextColumn("[bold]{task.description} - {task.fields[aname]}"), + BarColumn(bar_width=None), + "[progress.percentage]{task.percentage:>3.1f}%", + "•", + MofNCompleteColumn(), + transient=True, + ) as ext_progress: + with ZipFile(archive_path) as zf: + ext_tid = ext_progress.add_task("Extracting", aname=archive_path.name, total=len(members := zf.infolist())) + extract_path = Path(zf.extract(members[0], download_dir)) + if extract_type == "file": + assert len(members) == 1, f"Archive `{archive_path}` with extract type `{extract_type}` has > 1 member!" 
+ elif extract_type == "directory": + for member in members[1:]: + zf.extract(member, download_dir) + ext_progress.advance(ext_tid) + else: + raise ValueError(f"Extract type `{extract_type}` for archive `{archive_path}` is not defined!") + + # Cleanup (if specified) + if cleanup: + archive_path.unlink() + + return extract_path + + +def download_extract(dataset_id: str, root_dir: Path) -> None: + """Download all files for a given dataset (querying registry above), extracting archives if necessary.""" + os.makedirs(download_dir := root_dir / "download" / dataset_id, exist_ok=True) + + # Download Files => Single-Threaded, with Progress Bar + dl_tasks = [d for d in DATASET_REGISTRY[dataset_id] if not (download_dir / d["name"]).exists()] + for dl_task in dl_tasks: + dl_path = download_with_progress(dl_task["url"], download_dir) + + # Extract Files (if specified) --> Note (assumes ".zip" ONLY!) + if dl_task["extract"]: + dl_path = extract_with_progress(dl_path, download_dir, dl_task["extract_type"]) + dl_path = dl_path.parent if dl_path.is_file() else dl_path + + # Rename Path --> dl_task["name"] + if dl_task["do_rename"]: + shutil.move(dl_path, download_dir / dl_task["name"]) diff --git a/prismatic/preprocessing/materialize.py b/prismatic/preprocessing/materialize.py new file mode 100644 index 0000000000000000000000000000000000000000..b84b0d5c1cbf0650efbac20e3700a8ab3d372091 --- /dev/null +++ b/prismatic/preprocessing/materialize.py @@ -0,0 +1,69 @@ +""" +materialize.py + +Factory class for initializing pretraining datasets on a per-VLM basis; provides and exports individual functions for +clear control flow. 
+""" + +from typing import Tuple, Type + +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizerBase + +from prismatic.conf import DatasetConfig +from prismatic.models.backbones.llm.prompting import PromptBuilder +from prismatic.models.backbones.vision import ImageTransform +from prismatic.preprocessing.datasets import AlignDataset, FinetuneDataset +from prismatic.util.data_utils import PaddedCollatorForLanguageModeling + +# Dataset Initializers =>> Maps Stage --> cls() +DATASET_INITIALIZER = {"align": AlignDataset, "finetune": FinetuneDataset, "full-finetune": FinetuneDataset} + + +def get_dataset_and_collator( + stage: str, + dataset_cfg: DatasetConfig, + image_transform: ImageTransform, + tokenizer: PreTrainedTokenizerBase, + prompt_builder_fn: Type[PromptBuilder], + default_image_resolution: Tuple[int, int, int], + padding_side: str = "right", +) -> Tuple[Dataset, PaddedCollatorForLanguageModeling]: + dataset_cls = DATASET_INITIALIZER[stage] + dataset_root_dir = dataset_cfg.dataset_root_dir + collator = PaddedCollatorForLanguageModeling( + tokenizer.model_max_length, tokenizer.pad_token_id, default_image_resolution, padding_side=padding_side + ) + + # Switch on `stage` + if stage == "align": + annotation_json, image_dir = dataset_cfg.align_stage_components + dataset = dataset_cls( + dataset_root_dir / annotation_json, dataset_root_dir / image_dir, image_transform, tokenizer + ) + return dataset, collator + + elif stage == "finetune": + annotation_json, image_dir = dataset_cfg.finetune_stage_components + dataset = dataset_cls( + dataset_root_dir / annotation_json, + dataset_root_dir / image_dir, + image_transform, + tokenizer, + prompt_builder_fn=prompt_builder_fn, + ) + return dataset, collator + + elif stage == "full-finetune": + annotation_json, image_dir = dataset_cfg.finetune_stage_components + dataset = dataset_cls( + dataset_root_dir / annotation_json, + dataset_root_dir / image_dir, + image_transform, + tokenizer, + 
prompt_builder_fn=prompt_builder_fn, + ) + return dataset, collator + + else: + raise ValueError(f"Stage `{stage}` is not supported!") diff --git a/prismatic/training/__init__.py b/prismatic/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58c7f8c8bf8ef7e9c8507eae82d30055e04fae25 --- /dev/null +++ b/prismatic/training/__init__.py @@ -0,0 +1,2 @@ +from .materialize import get_train_strategy +from .metrics import Metrics, VLAMetrics diff --git a/prismatic/training/materialize.py b/prismatic/training/materialize.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9f364dbd7d4b908fe21ba3381ae2305b053f83 --- /dev/null +++ b/prismatic/training/materialize.py @@ -0,0 +1,66 @@ +""" +materialize.py + +Factory class defining functions for instantiating various Training Strategies, supporting different VLMs, backbones, +and strategy configurations. +""" + +from typing import Callable, Optional + +import torch + +from prismatic.models.vlms import PrismaticVLM +from prismatic.training.strategies import FSDPStrategy, TrainingStrategy + +# Registry =>> Maps ID --> {cls(), kwargs} :: supports FSDP for now, but DDP handler is also implemented! 
+TRAIN_STRATEGIES = { + "fsdp-shard-grad-op": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "shard-grad-op"}}, + "fsdp-full-shard": {"cls": FSDPStrategy, "kwargs": {"sharding_strategy": "full-shard"}}, +} + + +def get_train_strategy( + train_strategy: str, + vlm: PrismaticVLM, + device_id: int, + stage: str, + epochs: int, + max_steps: Optional[int], + global_batch_size: int, + per_device_batch_size: int, + learning_rate: float, + weight_decay: float, + max_grad_norm: float, + lr_scheduler_type: str, + warmup_ratio: float, + enable_gradient_checkpointing: bool = True, + enable_mixed_precision_training: bool = True, + reduce_in_full_precision: bool = False, + mixed_precision_dtype: torch.dtype = torch.bfloat16, + worker_init_fn: Optional[Callable[[int], None]] = None, +) -> TrainingStrategy: + if train_strategy in TRAIN_STRATEGIES: + strategy_cfg = TRAIN_STRATEGIES[train_strategy] + strategy = strategy_cfg["cls"]( + vlm=vlm, + device_id=device_id, + stage=stage, + epochs=epochs, + max_steps=max_steps, + global_batch_size=global_batch_size, + per_device_batch_size=per_device_batch_size, + learning_rate=learning_rate, + weight_decay=weight_decay, + max_grad_norm=max_grad_norm, + lr_scheduler_type=lr_scheduler_type, + warmup_ratio=warmup_ratio, + enable_gradient_checkpointing=enable_gradient_checkpointing, + enable_mixed_precision_training=enable_mixed_precision_training, + reduce_in_full_precision=reduce_in_full_precision, + mixed_precision_dtype=mixed_precision_dtype, + worker_init_fn=worker_init_fn, + **strategy_cfg["kwargs"], + ) + return strategy + else: + raise ValueError(f"Train Strategy `{train_strategy}` is not supported!") diff --git a/prismatic/training/metrics.py b/prismatic/training/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..9dc86ed13889a6b94dca0ebf2db89cf9823d12e6 --- /dev/null +++ b/prismatic/training/metrics.py @@ -0,0 +1,348 @@ +""" +metrics.py + +Utility classes defining a Metrics container and multiple 
Trackers to enable model/stage-specific logging to various +endpoints (e.g., JSONL local logs, Weights & Biases). +""" + +import time +from collections import defaultdict, deque +from pathlib import Path +from typing import Any, Dict, Optional, Protocol, Tuple, Union + +import jsonlines +import numpy as np +import torch +import wandb + +from prismatic.overwatch import initialize_overwatch + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +# === Define Tracker Interface === +class Tracker(Protocol): + def write_hyperparameters(self) -> None: ... + + def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: ... + + def finalize(self) -> None: ... + + +# === Individual Tracker Definitions === +class JSONLinesTracker: + def __init__(self, run_id: str, run_dir: Path, hparams: Dict[str, Any]) -> None: + self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams + + @overwatch.rank_zero_only + def write_hyperparameters(self) -> None: + with jsonlines.open(self.run_dir / "run-metrics.jsonl", mode="w", sort_keys=True) as js_tracker: + js_tracker.write({"run_id": self.run_id, "hparams": self.hparams}) + + @overwatch.rank_zero_only + def write(self, _: int, metrics: Dict[str, Union[int, float]]) -> None: + with jsonlines.open(self.run_dir / f"{self.run_id}.jsonl", mode="a", sort_keys=True) as js_tracker: + js_tracker.write(metrics) + + def finalize(self) -> None: + return + + +class WeightsBiasesTracker: + def __init__( + self, + run_id: str, + run_dir: Path, + hparams: Dict[str, Any], + project: str = "prismatic", + entity: Optional[str] = None, + group: str = "align", + ) -> None: + self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams + + # Get W&B-Specific Initialization Parameters + self.project, self.entity, self.group, self.wandb_dir = project, entity, group, self.run_dir + + # Call W&B.init() + self.initialize() + + @overwatch.rank_zero_only + def initialize(self) -> None: 
+ wandb.init( + name=self.run_id, + dir=self.wandb_dir, + config=self.hparams, + project=self.project, + entity=self.entity, + group=self.group, + ) + + @overwatch.rank_zero_only + def write_hyperparameters(self) -> None: + wandb.config = self.hparams + + @overwatch.rank_zero_only + def write(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: + wandb.log(metrics, step=global_step) + + @staticmethod + def finalize() -> None: + if overwatch.is_rank_zero(): + wandb.finish() + + # A job gets 210 seconds to get its affairs in order + time.sleep(210) + + +# === Core Metrics Container :: Initializes Trackers => Compiles/Pushes Metrics === + + +class Metrics: + def __init__( + self, + active_trackers: Tuple[str, ...], + run_id: str, + run_dir: Path, + hparams: Dict[str, Any], + stage: str, + wandb_project: str = "prismatic", + wandb_entity: Optional[str] = None, + grad_accumulation_steps: int = 1, + window_size: int = 128, + ) -> None: + self.run_id, self.run_dir, self.hparams, self.stage = run_id, run_dir, hparams, stage + + # Initialize Trackers + self.trackers = [] + for tracker_type in active_trackers: + if tracker_type == "jsonl": + tracker = JSONLinesTracker(run_id, run_dir, hparams) + elif tracker_type == "wandb": + tracker = WeightsBiasesTracker( + run_id, run_dir, hparams, project=wandb_project, entity=wandb_entity, group=self.stage + ) + else: + raise ValueError(f"Tracker with type `{tracker_type}` is not supported!") + + # Add Hyperparameters --> add to `self.trackers` + tracker.write_hyperparameters() + self.trackers.append(tracker) + + # Create Universal Metrics Buffers + self.global_step, self.start_time, self.step_start_time = 0, time.time(), time.time() + self.state = { + "loss_raw": deque(maxlen=grad_accumulation_steps), + "loss": deque(maxlen=window_size), + "step_time": deque(maxlen=window_size), + "lr": [], + } + + def log(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: + for tracker in self.trackers: + 
tracker.write(global_step, metrics) + + def get_status(self, loss: Optional[torch.Tensor] = None) -> str: + lr = self.state["lr"][-1] if len(self.state["lr"]) > 0 else 0 + if loss is None: + return f"=>> [Global Step] {self.global_step:06d} =>> LR :: {lr:.6f}" + + # Otherwise, embed `loss` in status report! + return f"=>> [Global Step] {self.global_step:06d} =>> LR :: {lr:.6f} -- Loss :: {loss:.4f}" + + def commit( + self, *, global_step: Optional[int] = None, lr: Optional[float] = None, update_step_time: bool = False, **kwargs + ) -> None: + """Update all metrics in `self.state` by iterating through special positional arguments & kwargs.""" + if global_step is not None: + self.global_step = global_step + + # For all other variables --> only track on rank zero! + if not overwatch.is_rank_zero(): + return + + # Special Positional Arguments + if lr is not None: + self.state["lr"].append(lr) + + if update_step_time: + self.state["step_time"].append(time.time() - self.step_start_time) + self.step_start_time = time.time() + + # Generic Keyword Arguments + for key, value in kwargs.items(): + if key == "loss": + loss_val = value.detach() + self.state["loss_raw"].append(loss_val) + self.state["loss"].append(loss_val) + else: + self.state[key].append(value.detach()) + + @overwatch.rank_zero_only + def push(self) -> str: + # Note :: Raw Loss is an Average over Gradient Accumulation Steps --> No Smoothing! 
+ loss_raw = torch.stack(list(self.state["loss_raw"])).mean().item() + loss = torch.stack(list(self.state["loss"])).mean().item() + step_time, lr = np.mean(list(self.state["step_time"])), self.state["lr"][-1] + status = self.get_status(loss) + + # Fire to Trackers + prefix = self.stage.capitalize() + self.log( + self.global_step, + metrics={ + f"{prefix}/Step": self.global_step, + f"{prefix}/Loss": loss, + f"{prefix}/Loss (Raw)": loss_raw, + f"{prefix}/Learning Rate": lr, + f"{prefix}/Step Time": step_time, + }, + ) + return status + + def finalize(self) -> None: + for tracker in self.trackers: + tracker.finalize() + + +class VLAMetrics: + def __init__( + self, + active_trackers: Tuple[str, ...], + run_id: str, + run_dir: Path, + hparams: Dict[str, Any], + wandb_project: str = "openvla", + wandb_entity: Optional[str] = "stanford-voltron", + grad_accumulation_steps: int = 1, + window_size: int = 1, + resume_step: Optional[int] = None, + resume_epoch: Optional[int] = None, + ) -> None: + self.run_id, self.run_dir, self.hparams = run_id, run_dir, hparams + + # Initialize Trackers + self.trackers = [] + for tracker_type in active_trackers: + if tracker_type == "jsonl": + tracker = JSONLinesTracker(run_id, run_dir, hparams) + elif tracker_type == "wandb": + tracker = WeightsBiasesTracker( + run_id, run_dir, hparams, project=wandb_project, entity=wandb_entity, group="vla-train" + ) + else: + raise ValueError(f"Tracker with type `{tracker_type}` is not supported!") + + # Add Hyperparameters --> add to `self.trackers` + tracker.write_hyperparameters() + self.trackers.append(tracker) + + # Create Universal Metrics Buffers + self.global_step = 0 if resume_step is None else resume_step + self.epoch = 0 if resume_epoch is None else resume_epoch + self.start_time, self.step_start_time = time.time(), time.time() + self.state = { + "loss_raw": deque(maxlen=grad_accumulation_steps), + "loss": deque(maxlen=window_size), + "l1_loss": deque(maxlen=window_size), + "action_accuracy": 
deque(maxlen=window_size), + "step_time": deque(maxlen=window_size), + "lr": [], + } + + # Created metrics buffers for individual tracked datasets + self.dataset_trackers = defaultdict(lambda: VLAMetrics([], "", "", {})) + + def log(self, global_step: int, metrics: Dict[str, Union[int, float]]) -> None: + for tracker in self.trackers: + tracker.write(global_step, metrics) + + def get_status(self, loss: Optional[torch.Tensor] = None) -> str: + lr = self.state["lr"][-1] if len(self.state["lr"]) > 0 else 0 + if loss is None: + return f"=>> [Epoch {self.epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f}" + + # Otherwise, embed `loss` in status report! + return f"=>> [Epoch {self.epoch:03d}] Global Step {self.global_step:06d} =>> LR :: {lr:.6f} - Loss :: {loss:.4f}" + + def commit( + self, + *, + global_step: Optional[int] = None, + epoch: Optional[int] = None, + lr: Optional[float] = None, + update_step_time: bool = False, + **kwargs, + ) -> None: + """Update all metrics in `self.state` by iterating through special positional arguments & kwargs.""" + if global_step is not None: + self.global_step = global_step + + if epoch is not None: + self.epoch = epoch + + # For all other variables --> only track on rank zero! 
+ if not overwatch.is_rank_zero(): + return + + # Special Positional Arguments + if lr is not None: + self.state["lr"].append(lr) + + if update_step_time: + self.state["step_time"].append(time.time() - self.step_start_time) + self.step_start_time = time.time() + + # Generic Keyword Arguments + for key, value in kwargs.items(): + if key == "loss": + loss_val = value.detach() + self.state["loss_raw"].append(loss_val) + self.state["loss"].append(loss_val) + else: + self.state[key].append(value.detach()) + + def commit_for_dataset(self, dataset_name: str, **kwargs) -> None: + self.dataset_trackers[dataset_name].commit(**kwargs) + + @overwatch.rank_zero_only + def push(self) -> str: + # Note :: Raw Loss is an Average over Gradient Accumulation Steps --> No Smoothing! + loss_raw = torch.stack(list(self.state["loss_raw"])).mean().item() + loss = torch.stack(list(self.state["loss"])).mean().item() + l1_loss = torch.stack(list(self.state["l1_loss"])).mean().item() + action_accuracy = torch.stack(list(self.state["action_accuracy"])).mean().item() + step_time, lr = np.mean(list(self.state["step_time"])), self.state["lr"][-1] + status = self.get_status(loss) + + # Get metrics per dataset + dataset_metrics = {} + for ds, tracker in self.dataset_trackers.items(): + dataset_metrics.update( + { + f"{ds}/L1 Loss": torch.stack(list(tracker.state["l1_loss"])).mean().item(), + f"{ds}/Action Token Accuracy": torch.stack(list(tracker.state["action_accuracy"])).mean().item(), + } + ) + + # Fire to Trackers + prefix = "VLA Train" + self.log( + self.global_step, + metrics={ + f"{prefix}/Step": self.global_step, + f"{prefix}/Epoch": self.epoch, + f"{prefix}/Loss": loss, + f"{prefix}/L1 Loss": l1_loss, + f"{prefix}/Action Token Accuracy": action_accuracy, + f"{prefix}/Loss (Raw)": loss_raw, + f"{prefix}/Learning Rate": lr, + f"{prefix}/Step Time": step_time, + **dataset_metrics, + }, + ) + return status + + def finalize(self) -> None: + for tracker in self.trackers: + tracker.finalize() diff
--git a/prismatic/training/strategies/__init__.py b/prismatic/training/strategies/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d73eb1069c982ed3969ba3af56479c0359051a1b --- /dev/null +++ b/prismatic/training/strategies/__init__.py @@ -0,0 +1,3 @@ +from .base_strategy import TrainingStrategy +from .ddp import DDPStrategy +from .fsdp import FSDPStrategy diff --git a/prismatic/training/strategies/base_strategy.py b/prismatic/training/strategies/base_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..ba4fc9428417cbbe232cd35417de5c4bbfb8e6cd --- /dev/null +++ b/prismatic/training/strategies/base_strategy.py @@ -0,0 +1,417 @@ +""" +base_strategy.py + +Abstract class definition of a (distributed) training strategy, with full annotations of class methods, utility +functions, and initialization logic. + +Training Strategies (DDP, FSDP-Grad, FSDP-Full) tend to have a lot of repeated components; this class does a lot of +heavy lifting. 
+""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Callable, Optional + +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset +from tqdm import tqdm +from transformers.modeling_outputs import CausalLMOutputWithPast + +from prismatic.models.vlms import PrismaticVLM +from prismatic.overwatch import initialize_overwatch +from prismatic.training.metrics import Metrics, VLAMetrics +from prismatic.training.train_utils import ( + compute_actions_l1_loss, + compute_token_accuracy, + get_current_action_mask, + get_next_actions_mask, +) +from prismatic.util import check_bloat16_supported +from prismatic.util.batching_utils import SplitModalitySampler +from prismatic.util.data_utils import PaddedCollatorForActionPrediction, PaddedCollatorForLanguageModeling +from prismatic.vla.action_tokenizer import ActionTokenizer + +# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) +from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, NUM_ACTIONS_CHUNK, IGNORE_INDEX +NEWLINE_INDEX = 13 # '\n' +STOP_INDEX = 2 # '</s>' + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +# === Abstract Base Class for an arbitrary Training Strategy === +class TrainingStrategy(ABC): + def __init__( + self, + vlm: PrismaticVLM, + device_id: int, + stage: str, + epochs: int, + max_steps: Optional[int], + global_batch_size: int, + per_device_batch_size: int, + learning_rate: float, + weight_decay: float, + max_grad_norm: float, + lr_scheduler_type: str, + warmup_ratio: float, + enable_gradient_checkpointing: bool = True, + enable_mixed_precision_training: bool = True, + reduce_in_full_precision: bool = False, + mixed_precision_dtype: torch.dtype = torch.bfloat16, + worker_init_fn: Optional[Callable[[int], None]] = None, + **_: str, + ) -> None: + self.vlm, self.device_id, self.stage = vlm, device_id,
stage + + # Get relevant VLM instance parameters before they get (potentially) wrapped + self.all_module_keys, self.trainable_module_keys = self.vlm.all_module_keys, self.vlm.trainable_module_keys + self.llm_transformer_layer_cls = self.vlm.llm_backbone.transformer_layer_cls + + # Optimization Parameters + self.epochs, self.max_steps = epochs, max_steps + self.global_batch_size, self.per_device_batch_size = global_batch_size, per_device_batch_size + + self.learning_rate, self.weight_decay, self.max_grad_norm = learning_rate, weight_decay, max_grad_norm + self.lr_scheduler_type, self.warmup_ratio = lr_scheduler_type, warmup_ratio + + # Generic Strategy Parameters + self.enable_gradient_checkpointing = enable_gradient_checkpointing + self.enable_mixed_precision_training = enable_mixed_precision_training + self.reduce_in_full_precision = reduce_in_full_precision + self.mixed_precision_dtype = mixed_precision_dtype + + # DataLoader Parameters + self.worker_init_fn = worker_init_fn + + # Optimizers & Scheduler (initialized in `run_setup`) + self.optimizer, self.lr_scheduler = None, None + + # Lightweight Validation + assert ( + self.global_batch_size % self.per_device_batch_size == 0 + ), "Per-device batch size must evenly divide global batch size!" + self.grad_accumulation_steps = self.global_batch_size // self.per_device_batch_size // overwatch.world_size() + if self.enable_mixed_precision_training: + assert self.mixed_precision_dtype == torch.bfloat16, "Only BF16 mixed precision training is supported!" + assert check_bloat16_supported(), "BFloat16 is not supported on this hardware; unset `mixed_precision`" + + @abstractmethod + def save_checkpoint( + self, + run_dir: Path, + global_step: int, + epoch: int, + train_loss: Optional[float] = None, + only_trainable: bool = True, + ) -> None: ... + + @abstractmethod + def run_setup(self, run_dir: Path, n_train_examples: int) -> None: ... + + @abstractmethod + def clip_grad_norm(self) -> None: ... 
+ + def run_training( + self, + dataset: Dataset, + collator: PaddedCollatorForLanguageModeling, + metrics: Metrics, + stage: str = "finetune", + batch_construction_strategy: str = "split-modality", + seed: int = 7, + ) -> None: + """Run the training loop for the given `dataset` and `collator`; log losses, results to `metrics`""" + if "finetune" in stage and batch_construction_strategy == "split-modality": + # Instantiate the split-modality sampler; if you want to extend with other batch construction schemes, + # (e.g., grouping by length) =>> can easily add them here! + modality_lengths = dataset.get_modality_lengths() + sampler = SplitModalitySampler( + dataset, + modality_lengths, + global_batch_size=self.global_batch_size, + num_replicas=overwatch.world_size(), + rank=overwatch.rank(), + seed=seed, + drop_last=False, + ) + + else: + sampler = DistributedSampler( + dataset, + num_replicas=overwatch.world_size(), + rank=overwatch.rank(), + shuffle=True, + seed=seed, + drop_last=False, + ) + + # Create a DataLoader with the initialized sampler, per-device-bsz, and collator + dataloader = DataLoader( + dataset, + batch_size=self.per_device_batch_size, + sampler=sampler, + collate_fn=collator, + num_workers=2, + worker_init_fn=self.worker_init_fn, + ) + + # Max Steps vs. 
Epochs Computation + steps_per_epoch = len(dataloader) // self.grad_accumulation_steps + if self.max_steps is not None and steps_per_epoch < self.max_steps: + # Just set `epochs` to some large number --> we'll short-circuit based on steps anyway + self.epochs = 100 + + # === Train === + status = metrics.get_status() + with tqdm( + total=( + (self.epochs * (len(dataloader) // self.grad_accumulation_steps)) + if self.max_steps is None + else self.max_steps + ), + desc=status, + leave=False, + disable=not overwatch.is_rank_zero(), + ) as progress: + for epoch in range(self.epochs): + self.vlm.train() + sampler.set_epoch(epoch) + + # Zero-Gradients (just in case) + self.optimizer.zero_grad() + + # Note that we'll unpack batch (and let AMP/FSDP do its thing) in the VLM.forward() call + # => Basically, if we're using mixed precision (or not), autocast()/FSDP will move to device! + for train_idx, batch in enumerate(dataloader): + # [Contract] self.vlm.forward() must automatically compute `loss` and return! + with torch.autocast( + "cuda", + dtype=self.mixed_precision_dtype, + enabled=self.enable_mixed_precision_training, + ): + output: CausalLMOutputWithPast = self.vlm( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + pixel_values=batch["pixel_values"], + labels=batch["labels"], + multimodal_indices=batch["multimodal_indices"], + ) + loss = output.loss + + # Commit Loss (Prior to Gradient Accumulation Normalization) + metrics.commit(loss=loss) + + # Normalize Loss to account for Gradient Accumulation --> Backward! + # [IMPORTANT] Technically speaking, doing gradient accumulation in this way is "incorrect"; this is + # because in general, each batch has a *different number of masked out tokens* (because + # we're instruct-tuning). Taking the mean over two unbalanced means != the right thing! + # + # HOWEVER -- at least at the 7B scale, the "naive" approach is just as performant as + # the "correct" implementation, without adding extra complexity. 
+ # + # That being said =>> at the 13B scale, *no matter what we tried*, ANY gradient accumulation is just + # really bad for downstream performance. Initial investigation shows that BF16 accumulation + # just really tanks in precision... and we don't have a good/clean way to fix this. Would love for + # someone to PR and fix this (and I'd greatly appreciate it!!!) + normalized_loss = loss / self.grad_accumulation_steps + normalized_loss.backward() + + # Step =>> Only if Done w/ Gradient Accumulation + if (train_idx + 1) % self.grad_accumulation_steps == 0: + metrics.commit(update_step_time=True) + + # Clip Gradients --> this is custom, per-strategy because of DDP vs. FSDP locality-assumptions + self.clip_grad_norm() + + # Optimizer & LR Scheduler Step + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + # Push Metrics + metrics.commit(global_step=metrics.global_step + 1, lr=self.lr_scheduler.get_last_lr()[0]) + status = metrics.push() + + # Check for Termination & Save Final Checkpoint (in case `max_steps` is not None) + if self.max_steps is not None and metrics.global_step >= self.max_steps: + self.save_checkpoint(metrics.run_dir, metrics.global_step, epoch, loss.item()) + dist.barrier() + + return + + # Update Progress Bar + progress.update() + progress.set_description(status) + + # Save checkpoint at the end of each epoch (if `self.max_steps` is None) + if self.max_steps is None: + self.save_checkpoint(metrics.run_dir, metrics.global_step, epoch, loss.item()) + dist.barrier() + + # === VLA Training === + + def run_vla_training( + self, + vla_dataset: IterableDataset, + collator: PaddedCollatorForActionPrediction, + action_tokenizer: ActionTokenizer, + metrics: VLAMetrics, + save_interval: int = 2500, + save_full_model: bool = True, + ) -> None: + """Run the VLA training loop for the given `dataset` and `collator`; log losses, action metrics to `metrics`.""" + assert isinstance(vla_dataset, IterableDataset), "VLA training expects an
IterableDataset!" + assert self.grad_accumulation_steps == 1, "VLA training does not support gradient accumulation!" + + # Create a DataLoader =>> Set `num_workers` to 0; RLDS loader handles parallelism! + dataloader = DataLoader( + vla_dataset, + batch_size=self.per_device_batch_size, + sampler=None, + collate_fn=collator, + num_workers=0, + worker_init_fn=self.worker_init_fn, + ) + + # === Train === + status = metrics.get_status() + with tqdm( + total=(self.epochs * len(dataloader)) if self.max_steps is None else self.max_steps, + desc=status, + leave=False, + disable=not overwatch.is_rank_zero(), + ) as progress: + self.vlm.train() + + # Zero Gradients (just in case) + self.optimizer.zero_grad() + + # [Contract] DataLoader wraps RLDS Loader (`.as_numpy_iterator()` =>> implicit `.repeat()`) + # => This means looping over the DataLoader is basically "infinite" (so no outer loop over epochs). + # Slightly breaks default PyTorch semantics, which is why we adaptively compute `epoch` below. + for batch in dataloader: + # Note that we'll unpack batch (and let AMP/FSDP do its thing) in the VLM.forward() call + # => Basically, if we're using mixed precision (or not), autocast()/FSDP will move to device! + with torch.autocast( + "cuda", dtype=self.mixed_precision_dtype, enabled=self.enable_mixed_precision_training + ): + # [Contract] self.vlm.forward() must automatically compute `loss` and return! + output: CausalLMOutputWithPast = self.vlm( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + pixel_values=batch["pixel_values"], + labels=batch["labels"], + ) + loss = output.loss + + # Commit Loss =>> Backward!
+ metrics.commit(loss=loss) + loss.backward() + + # Get predicted and ground-truth token IDs + predicted_token_ids = output.logits[:, self.vlm.vision_backbone.num_patches : -1].argmax(dim=2) + ground_truth_token_ids = batch["labels"][:, 1:].to(predicted_token_ids.device) + + ####################################################################### + # === Compute Current Action Token Accuracy & L1 Loss === + ####################################################################### + + # Get current action mask: Target the first ACTION_DIM non-ignore tokens + current_action_mask = get_current_action_mask(ground_truth_token_ids) + + # Compute Accuracy + action_accuracy = compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask=current_action_mask) + + # Compute L1 Loss on Predicted (Continuous) Actions + action_l1_loss = compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=current_action_mask) + + ####################################################################### + # === Compute Next Actions Token Accuracy & L1 Loss === + ####################################################################### + + # Get next actions mask: Target all tokens after the first ACTION_DIM non-ignore tokens (excluding the last token, which is the stop token) + next_actions_mask = get_next_actions_mask(ground_truth_token_ids) + + # Compute Accuracy + next_actions_accuracy = compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask) + + # Compute L1 Loss on Predicted (Continuous) Actions + next_actions_l1_loss = compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask) + + ####################################################################### + # === Log === + ####################################################################### + + # Commit Metrics + metrics.commit( + action_accuracy=action_accuracy, + l1_loss=action_l1_loss, + 
next_actions_accuracy=next_actions_accuracy, + next_actions_l1_loss=next_actions_l1_loss, + update_step_time=True, + ) + + # Compute metrics per dataset --> only on rank_zero since we don't log them on other workers anyways + if overwatch.is_rank_zero(): + datasets = set(batch["dataset_names"]) + if len(datasets) > 1: + for ds in datasets: + ds_mask = torch.tensor([elem == ds for elem in batch["dataset_names"]]) + action_accuracy_ds = ((predicted_token_ids == ground_truth_token_ids) & current_action_mask)[ds_mask].sum().float() / current_action_mask[ds_mask].sum().float() + pred_continuous_actions_ds = torch.tensor( + action_tokenizer.decode_token_ids_to_actions( + predicted_token_ids[ds_mask][current_action_mask[ds_mask]].cpu().numpy() + ) + ) + continuous_actions_gt_ds = torch.tensor( + action_tokenizer.decode_token_ids_to_actions( + ground_truth_token_ids[ds_mask][current_action_mask[ds_mask]].cpu().numpy() + ) + ) + action_l1_loss_ds = torch.nn.functional.l1_loss( + pred_continuous_actions_ds, continuous_actions_gt_ds + ) + metrics.commit_for_dataset( + dataset_name=ds.decode(), + action_accuracy=action_accuracy_ds, + l1_loss=action_l1_loss_ds, + next_actions_accuracy=next_actions_accuracy, + next_actions_l1_loss=next_actions_l1_loss, + ) + + # === Gradient Step === + + # Clip Gradients --> this is custom, per-strategy because of DDP vs.
FSDP locality assumptions + self.clip_grad_norm() + + # Optimizer & LR Scheduler Step + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + # Compute epoch value using number of completed gradient steps + epoch = (metrics.global_step + 1) // (len(vla_dataset) // self.global_batch_size) + + # Push Metrics + metrics.commit(global_step=metrics.global_step + 1, epoch=epoch, lr=self.lr_scheduler.get_last_lr()[0]) + status = metrics.push() + + # Check for Save Interval or Max Steps & Save Checkpoint + if (terminate := (self.max_steps is not None and metrics.global_step >= self.max_steps)) or ( + (metrics.global_step % save_interval) == 0 + ): + self.save_checkpoint( + metrics.run_dir, metrics.global_step, epoch, loss.item(), only_trainable=not save_full_model + ) + dist.barrier() + + if terminate: + return + + # Update Progress Bar + progress.update() + progress.set_description(status) diff --git a/prismatic/training/strategies/ddp.py b/prismatic/training/strategies/ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..be6c1dd20ef1d315eba1aaf77a94b196ea38af45 --- /dev/null +++ b/prismatic/training/strategies/ddp.py @@ -0,0 +1,128 @@ +""" +ddp.py + +Core class definition for a strategy implementing Torch native Distributed Data Parallel Training; note that on most +GPU hardware and LLM backbones >= 5-7B parameters, DDP training will OOM, which is why we opt for FSDP. 
+""" + +import shutil +from pathlib import Path +from typing import Optional + +import torch +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import AdamW +from transformers.optimization import get_constant_schedule, get_cosine_schedule_with_warmup + +from prismatic.overwatch import initialize_overwatch +from prismatic.training.strategies.base_strategy import TrainingStrategy + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +class DDPStrategy(TrainingStrategy): + @overwatch.rank_zero_only + def save_checkpoint( + self, + run_dir: Path, + global_step: int, + epoch: int, + train_loss: Optional[float] = None, + only_trainable: bool = True, + ) -> None: + """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default.""" + assert isinstance(self.vlm, DDP), "save_checkpoint assumes VLM is already wrapped in DDP!" + + # Splinter State Dictionary by Top-Level Submodules (or subset, if `only_trainable`) + model_state_dicts = { + mkey: getattr(self.vlm.module, mkey).state_dict() + for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys) + } + optimizer_state_dict = self.optimizer.state_dict() + + # Set Checkpoint Path =>> Embed *minimal* training statistics! 
+ checkpoint_dir = run_dir / "checkpoints" + if train_loss is None: + checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt" + else: + checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt" + + # Save Checkpoint & Copy Latest to `latest-checkpoint.pt` + torch.save({"model": model_state_dicts, "optimizer": optimizer_state_dict}, checkpoint_path) + shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt") + + def run_setup(self, run_dir: Path, n_train_examples: int) -> None: + # Gradient Checkpointing Setup + if self.enable_gradient_checkpointing: + # For Gradient Checkpointing --> we make the assumption that the "bulk" of activation memory is taken up + # by the LLM; because we also make the explicit assumption that each LLM is derived from a HF + # pretrained model, the only thing we *need* to do (technically) is call `gradient_checkpoint_enable` + # on `self.llm_backbone`. + # + # What does it actually do? 
--> runs the *generic* custom_forward + torch.utils.checkpoint.checkpoint logic + # => github.com/huggingface/transformers/.../models/llama/modeling_llama.py#L692-L706 + # + # Additional Reference (to better understand gradient checkpointing in PyTorch writ large) + # => github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb + overwatch.info("Enabling Gradient Checkpointing on LLM Backbone", ctx_level=1) + self.vlm.llm_backbone.gradient_checkpointing_enable() + + # Move to Device =>> Note parameters are in full precision (*mixed precision* will only autocast as appropriate) + overwatch.info("Placing Entire VLM (Vision Backbone, LLM Backbone, Projector Weights) on GPU", ctx_level=1) + self.vlm.to(self.device_id) + + # Wrap with Distributed Data Parallel + # => Note: By default, wrapping naively with DDP(self.vlm) will initialize a *separate* buffer on GPU that + # is the same size/dtype as the model parameters; this will *double* GPU memory! + # - stackoverflow.com/questions/68949954/model-takes-twice-the-memory-footprint-with-distributed-data-parallel + overwatch.info("Wrapping VLM with Distributed Data Parallel", ctx_level=1) + self.vlm = DDP(self.vlm, device_ids=[self.device_id], gradient_as_bucket_view=True) + + # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs` + # => Optimizer should only operate on parameters that are *unfrozen* / trainable! 
+ trainable_params = [param for param in self.vlm.parameters() if param.requires_grad] + if self.max_steps is None: + num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size + else: + num_training_steps = self.max_steps + + if self.lr_scheduler_type == "linear-warmup+cosine-decay": + # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05) + num_warmup_steps = int(num_training_steps * self.warmup_ratio) + + assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!" + self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay) + self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps) + for param_group in self.optimizer.param_groups: + param_group["lr"] = 0.0 + + elif self.lr_scheduler_type == "constant": + num_warmup_steps = 0 + + assert self.weight_decay == 0, "DDP training does not currently support `weight_decay` > 0!" + self.optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay) + self.lr_scheduler = get_constant_schedule(self.optimizer) + + else: + raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!") + + # Finalize Setup =>> Log + overwatch.info( + "DDP Strategy =>> Finalized Training Setup:\n" + f" |-> Global (Effective) Batch Size = {self.global_batch_size}\n" + f" |-> Per-Device Batch Size = {self.per_device_batch_size}\n" + f" |-> Distributed World Size = {overwatch.world_size()}\n" + f" |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n" + f" |-> LLM Backbone Gradient Checkpointing = {self.enable_gradient_checkpointing}\n" + f" |-> Use Native AMP = {self.enable_mixed_precision_training} ({self.mixed_precision_dtype})\n\n" + f" |-> Default AdamW LR = {self.learning_rate}\n" + f" |-> AdamW Weight Decay = {self.weight_decay}\n" + f" |-> LR Scheduler Type = {self.lr_scheduler_type}\n" + f" |-> LR 
Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n" + f" |-> Dataset Size = {n_train_examples} Examples\n" + f" |-> Max Steps = {num_training_steps}\n" + ) + + def clip_grad_norm(self) -> None: + torch.nn.utils.clip_grad_norm_(self.vlm.parameters(), max_norm=self.max_grad_norm) diff --git a/prismatic/training/strategies/fsdp.py b/prismatic/training/strategies/fsdp.py new file mode 100644 index 0000000000000000000000000000000000000000..9af28f474908af1bbb048a28968c986629ecc5a5 --- /dev/null +++ b/prismatic/training/strategies/fsdp.py @@ -0,0 +1,270 @@ +""" +fsdp.py + +Core class definition for a strategy implementing Torch native Fully Sharded Data Parallel Training (with support for +fine-grained control over wrapping policies and mixed precision per component). +""" + +import math +from collections import OrderedDict +from functools import partial +from pathlib import Path +from typing import Callable, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + apply_activation_checkpointing, + checkpoint_wrapper, +) +from torch.distributed.fsdp import ( + FullStateDictConfig, + MixedPrecision, + ShardingStrategy, + StateDictType, +) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.optim import AdamW +from transformers.optimization import get_constant_schedule, get_cosine_schedule_with_warmup + +from prismatic.models.vlms import PrismaticVLM +from prismatic.overwatch import initialize_overwatch +from prismatic.training.strategies.base_strategy import TrainingStrategy + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +class FSDPStrategy(TrainingStrategy): + def __init__( + self, + vlm: PrismaticVLM, + device_id: int, + stage: str, + epochs: int, + max_steps: Optional[int], + global_batch_size: int, + per_device_batch_size: int, + learning_rate: 
float, + weight_decay: float, + max_grad_norm: float, + lr_scheduler_type: str, + warmup_ratio: float, + enable_gradient_checkpointing: bool = True, + enable_mixed_precision_training: bool = True, + reduce_in_full_precision: bool = False, + mixed_precision_dtype: torch.dtype = torch.bfloat16, + worker_init_fn: Optional[Callable[[int], None]] = None, + sharding_strategy: str = "shard-grad-op", + state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT, + ) -> None: + super().__init__( + vlm=vlm, + device_id=device_id, + stage=stage, + epochs=epochs, + max_steps=max_steps, + global_batch_size=global_batch_size, + per_device_batch_size=per_device_batch_size, + learning_rate=learning_rate, + weight_decay=weight_decay, + max_grad_norm=max_grad_norm, + lr_scheduler_type=lr_scheduler_type, + warmup_ratio=warmup_ratio, + enable_gradient_checkpointing=enable_gradient_checkpointing, + enable_mixed_precision_training=enable_mixed_precision_training, + reduce_in_full_precision=reduce_in_full_precision, + mixed_precision_dtype=mixed_precision_dtype, + worker_init_fn=worker_init_fn, + ) + + # FSDP-Specific Parameters + if sharding_strategy == "shard-grad-op": + self.fsdp_sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2 + elif sharding_strategy == "full-shard": + self.fsdp_sharding_strategy = ShardingStrategy.HYBRID_SHARD + else: + raise ValueError(f"FSDP Sharding Strategy {sharding_strategy} is not supported!") + + assert state_dict_type == StateDictType.FULL_STATE_DICT, "Sharded state saving is not yet implemented!" 
+ self.fsdp_state_dict_type = state_dict_type + self.fsdp_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + + def save_checkpoint( + self, + run_dir: Path, + global_step: int, + epoch: int, + train_loss: Optional[float] = None, + only_trainable: bool = True, + ) -> None: + """Save a checkpoint to the `run_dir` only containing the state_dicts for trainable parameters by default.""" + assert isinstance(self.vlm, FSDP), "FSDPStrategy.save_checkpoint assumes VLM is already wrapped in FSDP!" + + # Summon Full State Dictionary =>> Reconstitute from Shards + with FSDP.state_dict_type(self.vlm, self.fsdp_state_dict_type, self.fsdp_save_policy): + full_vlm_state_dict = self.vlm.state_dict() + model_state_dicts = { + mkey: OrderedDict() for mkey in (self.trainable_module_keys if only_trainable else self.all_module_keys) + } + + # Iterate through `full_vlm_state_dict` and split `mkey.{full_dotted_path}` -> `mkey: {full_dotted_path}` + for key, param in full_vlm_state_dict.items(): + for mkey in model_state_dicts: + if key.startswith(mprefix := f"{mkey}."): + model_state_dicts[mkey][key.removeprefix(mprefix)] = param + + # Save on rank zero *only* + if overwatch.is_rank_zero(): + checkpoint_dir = run_dir / "checkpoints" + if train_loss is None: + checkpoint_path = checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss=inf.pt" + else: + checkpoint_path = ( + checkpoint_dir / f"step-{global_step:06d}-epoch-{epoch:02d}-loss={train_loss:.4f}.pt" + ) + + # Save Checkpoint & Copy Latest to `latest-checkpoint.pt` + torch.save({"model": model_state_dicts}, checkpoint_path) + + # TODO (siddk) :: This breaks w/ Sagemaker default permissions (root vs. )... skip? 
+ # shutil.copy(checkpoint_path, checkpoint_dir / "latest-checkpoint.pt") + + def run_setup(self, run_dir: Path, n_train_examples: int) -> None: + # Iteratively Assemble FSDP Wrapping Policy by fetching the wrapping policies for each backbone/constituent + vlm_fsdp_wrapping_policy = self.vlm.get_fsdp_wrapping_policy() + + # Assemble the Default FSDP Mixed Precision Policy + if self.enable_mixed_precision_training and self.mixed_precision_dtype == torch.bfloat16: + # MixedPrecision `param_dtype` specifies *compute* dtype (for forward/backward only) + # => Reference: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision + reduce_buffer_dtype = torch.bfloat16 if not self.reduce_in_full_precision else torch.float32 + fsdp_precision_policy = MixedPrecision( + param_dtype=torch.bfloat16, reduce_dtype=reduce_buffer_dtype, buffer_dtype=reduce_buffer_dtype + ) + + # When running FSDP with a frozen vision backbone --> move to half precision! + if self.stage not in {"full-finetune", "vla-full-train", "vla-sandwich-train"}: + overwatch.info("Casting Vision Backbone to *Half Precision* via `.to(dtype=...)`") + self.vlm.vision_backbone.to(dtype=self.vlm.vision_backbone.half_precision_dtype) + + else: + # If we're not using mixed precision, everything is in default full precision! 
+ fsdp_precision_policy = MixedPrecision( + param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32 + ) + + # => note that FSDP will automatically take care of device placement (similar to `autocast`) + self.vlm = FSDP( + self.vlm, + auto_wrap_policy=vlm_fsdp_wrapping_policy, + mixed_precision=fsdp_precision_policy, + sharding_strategy=self.fsdp_sharding_strategy, + device_id=torch.cuda.current_device(), + limit_all_gathers=True, + use_orig_params=True, + ) + + # Gradient Checkpoint Setup + if self.enable_gradient_checkpointing: + # For Gradient Checkpointing under FSDP --> we make the same assumption as in the DDP/other strategies; the + # bulk of activation memory is taken up by the LLM activations. However, unlike other strategies, we + # cannot rely on the HF Transformers default `gradient_checkpointing_enable()` --> FSDP breaks semantics! + # + # Instead, we need to write our own *NO-REENTRANT* wrapper, and apply it to the LLM's Transformer Layer. + non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT) + + def check_fn(submodule: nn.Module) -> bool: + return isinstance(submodule, self.llm_transformer_layer_cls) + + # Note that the terms "activation checkpointing" and "gradient checkpointing" are synonymous! + apply_activation_checkpointing(self.vlm, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn) + + # Barrier =>> Sharding takes a minute? + dist.barrier() + + # Create Optimizer and LR Scheduler =>> note that most of the LR Schedulers we use require `max_steps/epochs` + # => Optimizer should only operate on parameters that are *unfrozen* / trainable! 
+ n_train_examples = math.ceil(n_train_examples / self.global_batch_size) * self.global_batch_size + if self.max_steps is None: + num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size + else: + num_training_steps = self.max_steps + + if self.lr_scheduler_type == "linear-warmup+cosine-decay": + # Set warmup steps (floor) based on `warmup_ratio` (should be 0.03 - 0.05) + num_warmup_steps = int(num_training_steps * self.warmup_ratio) + + # Default AdamW w/ specified LR & Linear Warmup / Cosine Decay & Weight Decay + # => Create Parameter Groups --> bias terms, normalization layer parameters shouldn't be decayed! + decay, no_decay = [], [] + for name, param in self.vlm.named_parameters(): + if not param.requires_grad: + continue + + # Check on any parameters with fewer than 2 dimensions or with "bias" in the name + if param.ndim <= 1 or name.endswith(".bias"): + no_decay.append(param) + else: + decay.append(param) + + # Build Parameter Groups + groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}] + + # Create Optimizer & LR Scheduler + self.optimizer = AdamW(groups, lr=self.learning_rate) + self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps, num_training_steps) + for param_group in self.optimizer.param_groups: + param_group["lr"] = 0.0 + + elif self.lr_scheduler_type == "constant": + num_warmup_steps = 0 + + # Default AdamW w/ specified LR & Linear Warmup / Cosine Decay & Weight Decay + # => Create Parameter Groups --> bias terms, normalization layer parameters shouldn't be decayed! 
+ decay, no_decay = [], [] + for name, param in self.vlm.named_parameters(): + if not param.requires_grad: + continue + + # Check on any parameters with fewer than 2 dimensions or with "bias" in the name + if param.ndim <= 1 or name.endswith(".bias"): + no_decay.append(param) + else: + decay.append(param) + + # Build Parameter Groups + groups = [{"params": decay, "weight_decay": self.weight_decay}, {"params": no_decay, "weight_decay": 0.0}] + + # Create Optimizer & LR Scheduler + self.optimizer = AdamW(groups, lr=self.learning_rate) + self.lr_scheduler = get_constant_schedule(self.optimizer) + + else: + raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!") + + # Finalize Setup =>> Log! + overwatch.info( + "FSDP Full-Shard Strategy =>> Finalized Training Setup:\n" + f" |-> Global (Effective) Batch Size = {self.global_batch_size}\n" + f" |-> Per-Device Batch Size = {self.per_device_batch_size}\n" + f" |-> Distributed World Size = {overwatch.world_size()}\n" + f" |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n" + f" |-> LLM Backbone FSDP Gradient Checkpointing = {self.enable_gradient_checkpointing}\n" + f" |-> Use FSDP Mixed Precision = {self.enable_mixed_precision_training}\n" + f" |-> Parameter Precision = {fsdp_precision_policy.param_dtype}\n" + f" |-> Reduction Precision = {fsdp_precision_policy.reduce_dtype}\n" + f" |-> Buffer Precision = {fsdp_precision_policy.buffer_dtype}\n\n" + f" |-> Default AdamW LR = {self.learning_rate}\n" + f" |-> AdamW Weight Decay = {self.weight_decay}\n" + f" |-> LR Scheduler Type = {self.lr_scheduler_type}\n" + f" |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n" + f" |-> Dataset Size = {n_train_examples} Examples\n" + f" |-> Max Steps = {num_training_steps}\n" + ) + + def clip_grad_norm(self) -> None: + # Note =>> FSDP uses a custom `clip_grad_norm_` function; requires *uniform grad dtype* + 
self.vlm.clip_grad_norm_(max_norm=self.max_grad_norm) diff --git a/prismatic/training/train_utils.py b/prismatic/training/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0c546885d26e834ef4d5a6e14bcd6bdbb731c8cc --- /dev/null +++ b/prismatic/training/train_utils.py @@ -0,0 +1,56 @@ +"""Utils for training/fine-tuning scripts.""" + +import torch + +from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX + + +def get_current_action_mask(token_ids): + # Mark positions that are *not* IGNORE_INDEX (i.e., unmasked label tokens) + newline_positions = token_ids != IGNORE_INDEX + + # Cumulative count of unmasked tokens along the sequence dimension + cumsum = torch.cumsum(newline_positions, dim=1) + + # Keep positions within the first ACTION_DIM unmasked tokens (the current action) + mask = (1 <= cumsum) & (cumsum <= ACTION_DIM) + + # Extract the action part only + action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX + mask = action_tokens_only_mask * mask + + return mask + + +def get_next_actions_mask(token_ids): + # Mark positions that are *not* IGNORE_INDEX (i.e., unmasked label tokens) + newline_positions = token_ids != IGNORE_INDEX + + # Cumulative count of unmasked tokens along the sequence dimension + cumsum = torch.cumsum(newline_positions, dim=1) + + # Keep positions after the first ACTION_DIM unmasked tokens (the subsequent actions) + mask = cumsum > ACTION_DIM + + # Extract the action part only + action_tokens_only_mask = token_ids > ACTION_TOKEN_BEGIN_IDX + mask = action_tokens_only_mask * mask + + return mask + + +def compute_token_accuracy(predicted_token_ids, ground_truth_token_ids, mask): + correct_preds = (predicted_token_ids == ground_truth_token_ids) & mask + accuracy = correct_preds.sum().float() / mask.sum().float() + return accuracy + + +def compute_actions_l1_loss(action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask): + pred_continuous_actions = torch.tensor( + action_tokenizer.decode_token_ids_to_actions(predicted_token_ids[mask].cpu().numpy()) + ) + true_continuous_actions = torch.tensor( + 
action_tokenizer.decode_token_ids_to_actions(ground_truth_token_ids[mask].cpu().numpy()) + ) + l1_loss = torch.nn.functional.l1_loss(pred_continuous_actions, true_continuous_actions) + return l1_loss diff --git a/prismatic/util/__init__.py b/prismatic/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3473f952d5fd1ddabcd6e0e372a74f4db1f407c3 --- /dev/null +++ b/prismatic/util/__init__.py @@ -0,0 +1 @@ +from .torch_utils import check_bloat16_supported, set_global_seed diff --git a/prismatic/util/batching_utils.py b/prismatic/util/batching_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5610348e2f5ad5406f71023e014105c98ce5eeff --- /dev/null +++ b/prismatic/util/batching_utils.py @@ -0,0 +1,212 @@ +""" +batching_utils.py + +Core definitions of (Distributed) Samplers for VLM finetuning; provides functionality for construction and allocating +"split-modality" batches as described in the LLaVa paper; this makes sure that a given device/batch is either entirely +(vision, language) or (language-only) data, which leads to sizeable efficiency gains. +""" + +import math +from typing import Iterator, List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import Dataset, Sampler + + +# High-Fidelity Bitwise Reproduction of the LLaVa Codebase Sampler Strategy + Per-Rank Allocation Scheme (following +# the default batching behavior of HF's Trainer Class --> derived from `accelerate`). 
+# +# =>> Reference: https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L60 +# =>> Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L603 +class SplitModalitySampler(Sampler): + def __init__( + self, + dataset: Dataset, + modality_lengths: List[Tuple[bool, int]], + global_batch_size: int, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + seed: int = 0, + drop_last: bool = False, + ) -> None: + super().__init__() + self.num_replicas = num_replicas if num_replicas is not None else dist.get_world_size() + self.rank = rank if rank is not None else dist.get_rank() + self.seed, self.epoch = seed, 0 + + # Custom Parameters + self.dataset, self.modality_lengths, self.drop_last = dataset, modality_lengths, drop_last + self.global_batch_size = global_batch_size + + # For our purposes, `drop_last` is always False! + assert not self.drop_last, "SplitModalitySampler must set `drop_last = False`!" + self.total_size = math.ceil(len(self.dataset) / self.global_batch_size) * self.global_batch_size + self.num_samples = self.total_size // self.num_replicas + + @staticmethod + def reindex_batch(batch_idxs: List[int], idx2lengths: List[int], n_buckets: int) -> List[List[int]]: + """Re-indexes a batch in a way that is conducive to DistributedSampler + grouping by seqlen per rank.""" + assert len(batch_idxs) % n_buckets == 0, "Batch length is not divisible by `num_replicas`!" 
+ + # Establish initial buckets, capacities, and max number of elements per bucket + n_examples_per_bucket = len(batch_idxs) // n_buckets + bucket_indices = [[] for _ in range(n_buckets)] + bucket_lengths = [0 for _ in range(n_buckets)] + + # Note that `batch_idxs` is already sorted by corresponding length (in descending order) + for idx in batch_idxs: + shortest_bucket_idx = bucket_lengths.index(min(bucket_lengths)) + bucket_indices[shortest_bucket_idx].append(idx) + + # Update `bucket_lengths` --> set length to infinity if at capacity! + bucket_lengths[shortest_bucket_idx] += idx2lengths[idx] + if len(bucket_indices[shortest_bucket_idx]) == n_examples_per_bucket: + bucket_lengths[shortest_bucket_idx] = float("inf") + + return bucket_indices + + def get_modality_and_length_grouped_indices(self, generator: torch.Generator) -> List[int]: + """ + Returns a list of indices so that each slice of `global_batch_size` consecutive indices corresponds to elements + of the same modality with each sub-sequence of `per_replica_batch_size` (the batch size each unique device sees + during distributed training) is roughly grouped by sequence length (for training efficiency). 
+ """ + multimodal_indices, multimodal_lengths = zip( + *[(idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if is_multimodal] + ) + + # Handle Special Case --> no "unimodal" inputs + unimodal_split = [ + (idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if not is_multimodal + ] + if len(unimodal_split) == 0: + unimodal_indices, unimodal_lengths = [], [] + else: + unimodal_indices, unimodal_lengths = zip(*unimodal_split) + + # Create a permutation of indices for each of the multimodal and unimodal data + mm_shuffled_idxs = torch.randperm(len(multimodal_indices), generator=generator) + uni_shuffled_idxs = torch.randperm(len(unimodal_indices), generator=generator) + + # We're going to be running sorting/grouping relative to `self.global_batch_size` and `self.num_replicas` + g_bsz = self.global_batch_size + + # Break each of the permutations into batches of length `global_batch_size` + mm_batch_idxs = [mm_shuffled_idxs[i : i + g_bsz].tolist() for i in range(0, len(mm_shuffled_idxs), g_bsz)] + uni_batch_idxs = [uni_shuffled_idxs[i : i + g_bsz].tolist() for i in range(0, len(uni_shuffled_idxs), g_bsz)] + + # If "last" batch is not of length `g_bsz` --> PAD by stealing indices from the first batch! + if len(mm_batch_idxs[-1]) < g_bsz: + n_missing = g_bsz - len(mm_batch_idxs[-1]) + mm_batch_idxs[-1].extend(mm_batch_idxs[0][:n_missing]) + + if len(uni_batch_idxs) > 0 and len(uni_batch_idxs[-1]) < g_bsz: + n_missing = g_bsz - len(uni_batch_idxs[-1]) + uni_batch_idxs[-1].extend(uni_batch_idxs[0][:n_missing]) + + # Now we're going to sort each batch by length --> this will aid in grouping by length by rank (efficiency!) 
+ mm_sorted_batch_idxs = [sorted(b, key=lambda i: multimodal_lengths[i], reverse=True) for b in mm_batch_idxs] + uni_sorted_batch_idxs = [sorted(b, key=lambda i: unimodal_lengths[i], reverse=True) for b in uni_batch_idxs] + + # IMPORTANT :: At this point, for each modality, we have a list of "batches" (made up of indices) where indices + # are sorted by example sequence length *within* each batch. To make this more concrete, consider the following: + # => World Size (`num_replicas`) = 2 + # => Global Batch Size (`g_bsz`) = 4 + # => `multimodal_indices` = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + # `multimodal_lengths` = [20, 90, 21, 22, 91, 18, 89, 19, 93, 88, 92, 17] + # + # At this point in the code, `mm_sorted_batch_idxs` might then look like the following (length in parentheses): + # => `mm_sorted_batch_idxs`: [ + # [4 (91), 3 (22), 0 (20), 5 (18)] => Batch 1 + # [6 (89), 9 (88), 7 (19), 11 (17)] => Batch 2 + # [8 (93), 10 (92), 1 (90), 2 (21)] => Batch 3 + # ] + # + # In practice: `g_bsz` is large (= 128), and for contiguous mini-batch "slices", length variance is low. + + # PROBLEM :: We want to split these "global batches" into equal-sized pieces, so that each "replica" (GPU) + # sees a "mini-batch" of roughly the same sequence lengths; this is super useful for efficient training. + + # HOWEVER :: The default "access pattern" for splitting a large batch into mini-batches by a DistributedSampler + # is akin to a "take every k" where `k` is equal to the number of replicas (GPUs) you're training on. Or, in + # Python notation --> `rank_k_indices = flatten(mm_sorted_batch_idxs)[k::num_replicas]`. 
+ # + # Naively translating this to our example means each GPU (in our world of 2 total) sees the following indices + # (grouped by "mini-batch" = `g_bsz / num_replicas` = 2 for convenience): + # => `rank_0_indices`: [ [4 (91), 0 (20)] =>> [6 (89), 7 (19)] =>> [8 (93), 1 (90)] ] + # => `rank_1_indices`: [ [3 (22), 5 (18)] =>> [9 (88), 11 (17)] =>> [10 (92), 2 (21)] ] + # + # We get lucky sometimes, but for the most part, each "mini-batch" has VASTLY DIFFERENT lengths! Bad! + + # FIX :: If we "undo" the access pattern with the following code and re-arrange the way we allocate batches + # inside the __iter__ method below, we can allocate indices appropriately. Running the following code gives us + # the following indices (grouped by "mini-batch" again for convenience): + # => `rank_0_indices`: [ [4 (91), 3 (22)] =>> [6 (89), 9 (88)] =>> [8 (93), 10 (92)] ] + # => `rank_1_indices`: [ [5 (18), 0 (20)] =>> [11 (17), 7 (19)] =>> [2 (21), 1 (90)] ] + # + # Much better! As `g_bsz` and `dataset` grow, we're more often than not getting *decent* groupings! + mm_length_bucketed_idxs = [ + self.reindex_batch(batch, multimodal_lengths, self.num_replicas) for batch in mm_sorted_batch_idxs + ] + uni_length_bucketed_idxs = [ + self.reindex_batch(batch, unimodal_lengths, self.num_replicas) for batch in uni_sorted_batch_idxs + ] + + # Note :: Because of the initial `randperm` --> we're indexing both sets from 0 (we're clobbering the range) + # => Flatten indices --> index into original `{modality}_indices` then re-batch! 
+ mm_output_idxs = [idx for batch in mm_length_bucketed_idxs for bucket in batch for idx in bucket] + mm_reindexed = [multimodal_indices[idx] for idx in mm_output_idxs] + mm_batches = [mm_reindexed[i : i + g_bsz] for i in range(0, len(mm_reindexed), g_bsz)] + + uni_output_idxs = [idx for batch in uni_length_bucketed_idxs for bucket in batch for idx in bucket] + uni_reindexed = [unimodal_indices[idx] for idx in uni_output_idxs] + uni_batches = [uni_reindexed[i : i + g_bsz] for i in range(0, len(uni_reindexed), g_bsz)] + + # Finally, randomly permute the multimodal & unimodal batches, merging into a single stream of indices + merged_batches = mm_batches + uni_batches + merge_idxs = torch.randperm(len(merged_batches), generator=generator) + all_batches = [merged_batches[idx] for idx in merge_idxs] + + # [Quality of Life] Shift "max length" batch to index 0 --> if we OOM, it happens immediately! + all_lengths = [length + ((_n_patches := 24 * 24) if is_mm else 0) for is_mm, length in self.modality_lengths] + all_batches_max_lengths = [] + for batch in all_batches: + all_batches_max_lengths.append(max([all_lengths[idx] for idx in batch])) + + # Identify Batch with "max length" --> Swap into Index 0 + longest_batch_idx = np.argmax(all_batches_max_lengths) + all_batches[0], all_batches[longest_batch_idx] = all_batches[longest_batch_idx], all_batches[0] + + # Flatten & Return all Indices + indices = [idx for batch in all_batches for idx in batch] + return indices + + def __iter__(self) -> Iterator: + """Deterministically shuffle, then split indices by modality and length.""" + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = self.get_modality_and_length_grouped_indices(g) + assert len(set(indices)) == len(self.modality_lengths) == len(self.dataset), "Oops!" 
+ assert (len(indices) % self.global_batch_size == 0) and (len(indices) % self.num_replicas) == 0, "Oops" + + # Note :: We compute per-replica batch size as a function of `global_batch` and `num_replicas` to ensure that + # gradient accumulation doesn't affect what indices are assigned a given rank. + per_replica_batch_size = self.global_batch_size // self.num_replicas + + # Tensorize & Unravel --> rather than yielding via a `take_every` --> we want to partition a global batch + # across replicas by assigning each a contiguous sub-sequence. + indices_t = torch.as_tensor(indices) + per_replica_batch_indices_t = indices_t.reshape(-1, per_replica_batch_size) + replica_indices_t = per_replica_batch_indices_t[self.rank :: self.num_replicas] + + replica_indices = replica_indices_t.flatten().tolist() + return iter(replica_indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + """To be called *between* epochs, prior to DataLoader instantiation; ensures random order across epochs.""" + self.epoch = epoch diff --git a/prismatic/util/data_utils.py b/prismatic/util/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..141dbc7417d9afa4256b65e4859981465065752b --- /dev/null +++ b/prismatic/util/data_utils.py @@ -0,0 +1,156 @@ +""" +data_utils.py + +General utilities and classes for facilitating data loading and collation. 
+""" + +from dataclasses import dataclass +from typing import Callable, Dict, Sequence, Tuple + +import numpy as np +import torch +from torch.nn.utils.rnn import pad_sequence + +# HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) +IGNORE_INDEX = -100 + + +def tree_map(fn: Callable, tree: dict) -> dict: + """Maps a function over a nested dictionary.""" + return {k: tree_map(fn, v) if isinstance(v, dict) else fn(v) for k, v in tree.items()} + + +def tree_map_with_key(fn: Callable, tree: dict, keys: Sequence = ()) -> dict: + """Maps a function over a nested dictionary.""" + return { + k: tree_map_with_key(fn, v, (*keys, k)) if isinstance(v, dict) else fn((*keys, k), v) for k, v in tree.items() + } + + +@dataclass +class PaddedCollatorForLanguageModeling: + model_max_length: int + pad_token_id: int + default_image_resolution: Tuple[int, int, int] + padding_side: str = "right" + pixel_values_dtype: torch.dtype = torch.float32 + + def __post_init__(self) -> None: + self.dummy_pixel_values = torch.zeros(self.default_image_resolution, dtype=self.pixel_values_dtype) + + def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) + pixel_values = [instance["pixel_values"] for instance in instances] + + # For now, we only support Tokenizers with `padding_side = "right"` during Training (but plan to extend!) + # => Handle padding via RNN Utils => `pad_sequence` + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id) + labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) + + # Truncate (if necessary) + input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length] + + # Get `attention_mask` by checking for `pad_token_id` + attention_mask = input_ids.ne(self.pad_token_id) + + # === Handle "unimodal" (language-only) vs. 
"multimodal" === + + # Some examples are "language-only" --> build a Tensor of `multimodal_indices` that we can slice into easily + multimodal_indices = torch.tensor( + [idx for idx in range(len(pixel_values)) if pixel_values[idx] is not None], dtype=torch.long + ) + + # Stack all `pixel_values` --> depending on type (torch.Tensor, or Dict[str, torch.Tensor]) & presence of None + if len(multimodal_indices) == 0: + pixel_values = torch.stack([self.dummy_pixel_values for _ in range(len(input_ids))]) + elif isinstance(pv_example := pixel_values[multimodal_indices[0]], torch.Tensor): + pixel_values = torch.stack( + [ + pixel_values[idx] if idx in multimodal_indices else self.dummy_pixel_values + for idx in range(len(input_ids)) + ] + ) + elif isinstance(pv_example, dict): + pixel_values = { + k: torch.stack( + [ + pixel_values[idx][k] if idx in multimodal_indices else self.dummy_pixel_values + for idx in range(len(input_ids)) + ] + ) + for k in pv_example + } + else: + raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") + + return dict( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + multimodal_indices=multimodal_indices, + ) + + +@dataclass +class PaddedCollatorForActionPrediction: + model_max_length: int + pad_token_id: int + padding_side: str = "right" + pixel_values_dtype: torch.dtype = torch.float32 + + def __call__(self, instances: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) + pixel_values = [instance["pixel_values"] for instance in instances] + if "dataset_name" in instances[0]: + dataset_names = [instance["dataset_name"] for instance in instances] + else: + dataset_names = None + + # For now, we only support Tokenizers with `padding_side = "right"` during training + # => Handle padding via RNN Utils => `pad_sequence` + assert self.padding_side == "right", 
f"Invalid Tokenizer `{self.padding_side = }`" + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id) + labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) + + # Truncate (if necessary) + input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length] + + # Get `attention_mask` by checking for `pad_token_id` + attention_mask = input_ids.ne(self.pad_token_id) + + # [Contract] For VLA Training =>> No "Unimodal" Data! + assert all([pv is not None for pv in pixel_values]), "Invalid VLA Example with `pixel_values = None`!" + + # Stack all `pixel_values` --> depending on type is torch.Tensor or Dict[str, torch.Tensor] + if isinstance(pixel_values[0], torch.Tensor): + if "pixel_values_wrist" in instances[0]: + pixel_values_wrist = [instance["pixel_values_wrist"] for instance in instances] + pixel_values = torch.cat((torch.stack(pixel_values), torch.stack(pixel_values_wrist)), dim=1) + else: + pixel_values = torch.stack(pixel_values) + else: + raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}") + + # Stack all actions + actions = [torch.from_numpy(np.copy(instance["actions"])) for instance in instances] + actions = torch.stack(actions) + + # Stack proprio + if "proprio" in instances[0]: + proprio = [instance["proprio"] for instance in instances] + proprio = torch.Tensor(np.squeeze(np.stack(proprio))) + else: + proprio = None + + output = dict( + pixel_values=pixel_values, + proprio=proprio, + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + actions=actions, + ) + if dataset_names is not None: + output["dataset_names"] = dataset_names + return output diff --git a/prismatic/util/nn_utils.py b/prismatic/util/nn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2b3f6150f2914fde0b1cb80bfb3ad981ad9181ed --- /dev/null +++ b/prismatic/util/nn_utils.py @@ -0,0 +1,53 @@ +""" +nn_utils.py + +Utility functions and PyTorch 
submodule definitions. +""" + +import torch +import torch.nn as nn + + +# === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] === +class LinearProjector(nn.Module): + def __init__(self, vision_dim: int, llm_dim: int) -> None: + super().__init__() + self.projector = nn.Linear(vision_dim, llm_dim, bias=True) + + def forward(self, img_patches: torch.Tensor) -> torch.Tensor: + return self.projector(img_patches) + + +class MLPProjector(nn.Module): + def __init__(self, vision_dim: int, llm_dim: int, mlp_type: str = "gelu-mlp") -> None: + super().__init__() + if mlp_type == "gelu-mlp": + self.projector = nn.Sequential( + nn.Linear(vision_dim, llm_dim, bias=True), + nn.GELU(), + nn.Linear(llm_dim, llm_dim, bias=True), + ) + else: + raise ValueError(f"Projector with `{mlp_type = }` is not supported!") + + def forward(self, img_patches: torch.Tensor) -> torch.Tensor: + return self.projector(img_patches) + + +class FusedMLPProjector(nn.Module): + def __init__(self, fused_vision_dim: int, llm_dim: int, mlp_type: str = "fused-gelu-mlp") -> None: + super().__init__() + self.initial_projection_dim = fused_vision_dim * 4 + if mlp_type == "fused-gelu-mlp": + self.projector = nn.Sequential( + nn.Linear(fused_vision_dim, self.initial_projection_dim, bias=True), + nn.GELU(), + nn.Linear(self.initial_projection_dim, llm_dim, bias=True), + nn.GELU(), + nn.Linear(llm_dim, llm_dim, bias=True), + ) + else: + raise ValueError(f"Fused Projector with `{mlp_type = }` is not supported!") + + def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor: + return self.projector(fused_img_patches) diff --git a/prismatic/util/torch_utils.py b/prismatic/util/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8fc08f6ba74afee76ace76897c8de148ede0611f --- /dev/null +++ b/prismatic/util/torch_utils.py @@ -0,0 +1,95 @@ +""" +torch_utils.py + +General utilities for randomness, mixed precision training, and miscellaneous 
checks in PyTorch. + +Random `set_global_seed` functionality is taken directly from PyTorch-Lightning: + > Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py + +This is pretty important to get right if we're ever randomly generating our masks (or prefix dropout) inside our +Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime +we inject randomness from non-PyTorch sources (e.g., numpy, random)! + > Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/ + +Terminology + -> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogeneous! + -> Rank :: Integer index of current process in the total world size + -> Local Rank :: Local index on given node in [0, Devices per Node) +""" + +import os +import random +from typing import Callable, Optional + +import numpy as np +import torch + +# === Randomness === + + +def set_global_seed(seed: int, get_worker_init_fn: bool = False) -> Optional[Callable[[int], None]]: + """Sets seed for all randomness libraries (mostly random, numpy, torch) and produces a `worker_init_fn`""" + assert np.iinfo(np.uint32).min < seed < np.iinfo(np.uint32).max, "Seed outside the np.uint32 bounds!" + + # Set Seed as an Environment Variable + os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + return worker_init_function if get_worker_init_fn else None + + +def worker_init_function(worker_id: int) -> None: + """ + Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo: + > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 + + Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that + you can run iterative splitting on to get new (predictable) randomness. 
+ + :param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question. + """ + # Get current `rank` (if running distributed) and `process_seed` + global_rank, process_seed = int(os.environ["LOCAL_RANK"]), torch.initial_seed() + + # Back out the "base" (original) seed - the per-worker seed is set in PyTorch: + # > https://pytorch.org/docs/stable/data.html#data-loading-randomness + base_seed = process_seed - worker_id + + # "Magic" code --> basically creates a seed sequence that mixes different "sources" and seeds every library... + seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank]) + + # Use 128 bits (4 x 32-bit words) to represent seed --> generate_state(k) produces a `k` element array! + np.random.seed(seed_seq.generate_state(4)) + + # Spawn distinct child sequences for PyTorch (reseed) and stdlib random + torch_seed_seq, random_seed_seq = seed_seq.spawn(2) + + # Torch Manual seed takes 64 bits (so just specify a dtype of uint64 + torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0]) + + # Use 128 Bits for `random`, but express as integer instead of as an array + random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum() + random.seed(random_seed) + + +# === BFloat16 Support === + + +def check_bloat16_supported() -> bool: + try: + import packaging.version + import torch.cuda.nccl as nccl + import torch.distributed as dist + + return ( + (torch.version.cuda is not None) + and torch.cuda.is_bf16_supported() + and (packaging.version.parse(torch.version.cuda).release >= (11, 0)) + and dist.is_nccl_available() + and (nccl.version() >= (2, 10)) + ) + + except Exception: + return False diff --git a/prismatic/vla/__init__.py b/prismatic/vla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2af7062f3a1c94d41b4734c89358b416862999 --- /dev/null +++ b/prismatic/vla/__init__.py @@ -0,0 +1 @@ +from .materialize import 
get_vla_dataset_and_collator diff --git a/prismatic/vla/action_tokenizer.py b/prismatic/vla/action_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..fc1841a714f40ba677a1493782da23db4f9d4f4b --- /dev/null +++ b/prismatic/vla/action_tokenizer.py @@ -0,0 +1,72 @@ +""" +action_tokenizer.py + +Extension class; wraps base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions. +""" + +from typing import List, Union + +import numpy as np +from transformers import PreTrainedTokenizerBase + + +class ActionTokenizer: + def __init__( + self, tokenizer: PreTrainedTokenizerBase, bins: int = 256, min_action: int = -1, max_action: int = 1 + ) -> None: + """ + Discretizes continuous robot actions into N bins per dimension and maps to the least used tokens. + + NOTE =>> by default, assumes a BPE-style tokenizer akin to the LlamaTokenizer, where *the least used tokens* + appear at the end of the vocabulary! + + :param tokenizer: Base LLM/VLM tokenizer to extend. + :param bins: Number of bins for each continuous value; we'll adopt a uniform binning strategy. + :param min_action: Minimum action value (for clipping, setting lower bound on bin interval). + :param max_action: Maximum action value (for clipping, setting upper bound on bin interval). + """ + self.tokenizer, self.n_bins, self.min_action, self.max_action = tokenizer, bins, min_action, max_action + + # Create Uniform Bins + Compute Bin Centers + self.bins = np.linspace(min_action, max_action, self.n_bins) + self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0 + + # [Contract] Set "action_token_begin_idx" based on `self.tokenizer.vocab_size - (self.n_bins + 1)` + # =>> Assumes we're always overwriting the final `n_bins` tokens of the vocabulary! 
+ self.action_token_begin_idx: int = int(self.tokenizer.vocab_size - (self.n_bins + 1)) + + def __call__(self, action: np.ndarray) -> Union[str, List[str]]: + """Clip & bin actions to *the last `n_bins` tokens* of the vocabulary (e.g., tokenizer.vocab[-256:]).""" + action = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action)) + discretized_action = np.digitize(action, self.bins) + + # Handle single element vs. batch + if len(discretized_action.shape) == 1: + return self.tokenizer.decode(list(self.tokenizer.vocab_size - discretized_action)) + else: + return self.tokenizer.batch_decode((self.tokenizer.vocab_size - discretized_action).tolist()) + + def decode_token_ids_to_actions(self, action_token_ids: np.ndarray) -> np.ndarray: + """ + Returns continuous actions for discrete action token IDs. + + NOTE =>> Because of the way the actions are discretized w.r.t. the bins (and not the bin centers), the + digitization returns bin indices between [1, # bins], inclusive, when there are actually only + (# bins - 1) bin intervals. + + Therefore, if the digitization returns the last possible index, we map this to the last bin interval. + + EXAMPLE =>> Let's say self._bins has 256 values. Then self._bin_centers has 255 values. Digitization returns + indices between [1, 256]. We subtract 1 from all indices so that they are between [0, 255]. There + is still one index (i==255) that would cause an out-of-bounds error if used to index into + self._bin_centers. Therefore, if i==255, we subtract 1 from it so that it just becomes the index of + the last bin center. We implement this simply via clipping between [0, 255 - 1]. 
+ """ + discretized_actions = self.tokenizer.vocab_size - action_token_ids + discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1) + + return self.bin_centers[discretized_actions] + + @property + def vocab_size(self) -> int: + return self.n_bins diff --git a/prismatic/vla/constants.py b/prismatic/vla/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..308c1e11468d9a2bb330d2d580b296a67c564676 --- /dev/null +++ b/prismatic/vla/constants.py @@ -0,0 +1,97 @@ +""" +Important constants for VLA training and evaluation. + +Attempts to automatically identify the correct constants to set based on the Python command used to launch +training or evaluation. If it is unclear, defaults to using the LIBERO simulation benchmark constants. +""" +import sys +from enum import Enum + +# Llama 2 token constants +IGNORE_INDEX = -100 +ACTION_TOKEN_BEGIN_IDX = 31743 +STOP_INDEX = 2 # '</s>' + + +# Defines supported normalization schemes for action and proprioceptive state.
+class NormalizationType(str, Enum): + # fmt: off + NORMAL = "normal" # Normalize to Mean = 0, Stdev = 1 + BOUNDS = "bounds" # Normalize to Interval = [-1, 1] + BOUNDS_Q99 = "bounds_q99" # Normalize [quantile_01, ..., quantile_99] --> [-1, ..., 1] + # fmt: on + + +# Define constants for each robot platform +LIBERO_CONSTANTS = { + "NUM_ACTIONS_CHUNK": 8, + "ACTION_DIM": 7, + "PROPRIO_DIM": 8, + "ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS_Q99, +} + +ALOHA_CONSTANTS = { + "NUM_ACTIONS_CHUNK": 25, + "ACTION_DIM": 14, + "PROPRIO_DIM": 14, + "ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS, +} + +BRIDGE_CONSTANTS = { + "NUM_ACTIONS_CHUNK": 5, + "ACTION_DIM": 7, + "PROPRIO_DIM": 7, + "ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS_Q99, +} + +CONFLICT_MANISKILL_CONSTANTS = { + "NUM_ACTIONS_CHUNK": 8, # same chunk size as LIBERO + "ACTION_DIM": 8, # 7 Panda joints + 1 gripper open (0/1) + "PROPRIO_DIM": 8, # 7 Panda joints + 1 gripper width + "ACTION_PROPRIO_NORMALIZATION_TYPE": NormalizationType.BOUNDS_Q99, +} + + +# Function to detect robot platform from command line arguments +def detect_robot_platform(): + cmd_args = " ".join(sys.argv).lower() + + if "conflict_maniskill" in cmd_args: + return "CONFLICT_MANISKILL" + elif "libero" in cmd_args: + return "LIBERO" + elif "aloha" in cmd_args: + return "ALOHA" + elif "bridge" in cmd_args: + return "BRIDGE" + else: + # Default to LIBERO if unclear + return "LIBERO" + + +# Determine which robot platform to use +ROBOT_PLATFORM = detect_robot_platform() + +# Set the appropriate constants based on the detected platform +if ROBOT_PLATFORM == "CONFLICT_MANISKILL": + constants = CONFLICT_MANISKILL_CONSTANTS +elif ROBOT_PLATFORM == "LIBERO": + constants = LIBERO_CONSTANTS +elif ROBOT_PLATFORM == "ALOHA": + constants = ALOHA_CONSTANTS +elif ROBOT_PLATFORM == "BRIDGE": + constants = BRIDGE_CONSTANTS + +# Assign constants to global variables +NUM_ACTIONS_CHUNK = constants["NUM_ACTIONS_CHUNK"] 
+ACTION_DIM = constants["ACTION_DIM"] +PROPRIO_DIM = constants["PROPRIO_DIM"] +ACTION_PROPRIO_NORMALIZATION_TYPE = constants["ACTION_PROPRIO_NORMALIZATION_TYPE"] + +# Print which robot platform constants are being used (for debugging) +print(f"Using {ROBOT_PLATFORM} constants:") +print(f" NUM_ACTIONS_CHUNK = {NUM_ACTIONS_CHUNK}") +print(f" ACTION_DIM = {ACTION_DIM}") +print(f" PROPRIO_DIM = {PROPRIO_DIM}") +print(f" ACTION_PROPRIO_NORMALIZATION_TYPE = {ACTION_PROPRIO_NORMALIZATION_TYPE}") +print("If needed, manually set the correct constants in `prismatic/vla/constants.py`!") diff --git a/prismatic/vla/datasets/__init__.py b/prismatic/vla/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dd620793f354ff7889151456dfdc4d5136b6edcd --- /dev/null +++ b/prismatic/vla/datasets/__init__.py @@ -0,0 +1 @@ +from .datasets import DummyDataset, EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset diff --git a/prismatic/vla/datasets/datasets.py b/prismatic/vla/datasets/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..06cadbf947c8320727c001e77ac614241d9a0031 --- /dev/null +++ b/prismatic/vla/datasets/datasets.py @@ -0,0 +1,261 @@ +""" +datasets.py + +Lightweight PyTorch Dataset Definition for wrapping RLDS TFDS Pipeline; just defines transform from RLDS default +format to OpenVLA, IterableDataset shim. 
+""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Tuple, Type + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset, IterableDataset +from transformers import PreTrainedTokenizerBase + +from prismatic.models.backbones.llm.prompting import PromptBuilder +from prismatic.models.backbones.vision import ImageTransform +from prismatic.util.data_utils import tree_map +from prismatic.vla.action_tokenizer import ActionTokenizer +from prismatic.vla.constants import ACTION_DIM, ACTION_PROPRIO_NORMALIZATION_TYPE, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX +from prismatic.vla.datasets.rlds import make_interleaved_dataset, make_single_dataset +from prismatic.vla.datasets.rlds.oxe import OXE_NAMED_MIXTURES, get_oxe_dataset_kwargs_and_weights + +@dataclass +class RLDSBatchTransform: + action_tokenizer: ActionTokenizer + base_tokenizer: PreTrainedTokenizerBase + image_transform: ImageTransform + prompt_builder_fn: Type[PromptBuilder] + predict_stop_token: bool = True + use_wrist_image: bool = False + use_proprio: bool = False + + def __call__(self, rlds_batch: Dict[str, Any]) -> Dict[str, Any]: + """Converts a RLDS batch to the format expected by the OpenVLA collator/models.""" + dataset_name, current_action = rlds_batch["dataset_name"], rlds_batch["action"][0] + img = Image.fromarray(rlds_batch["observation"]["image_primary"][0]) + lang = rlds_batch["task"]["language_instruction"].decode().lower() + actions = rlds_batch["action"] + + # Construct Chat-based Prompt =>> Input is default query + language instruction, output are the action tokens + prompt_builder = self.prompt_builder_fn("openvla") + + # Get future action chunk + future_actions = rlds_batch["action"][1:] + future_actions_string = ''.join(self.action_tokenizer(future_actions)) + + # Get action chunk string + current_action_string = self.action_tokenizer(current_action) + action_chunk_string = 
current_action_string + future_actions_string + action_chunk_len = len(action_chunk_string) + + conversation = [ + {"from": "human", "value": f"What action should the robot take to {lang}?"}, + {"from": "gpt", "value": action_chunk_string}, + ] + for turn in conversation: + prompt_builder.add_turn(turn["from"], turn["value"]) + + # Tokenize (w/ `base_tokenizer`) + input_ids = self.base_tokenizer(prompt_builder.get_prompt(), add_special_tokens=True).input_ids + labels = list(input_ids) + + # Tensorize =>> Run Image Transform to get `pixel_values` =>> Return + # =>> IMPORTANT :: IF WE'RE USING HF LLM.forward(..., labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL! + input_ids, labels = torch.tensor(input_ids), torch.tensor(labels) + pixel_values = self.image_transform(img) + + # [CRITICAL] We do not want to take the loss for anything but the predicted action tokens! + labels[: -(action_chunk_len + 1)] = IGNORE_INDEX + if not self.predict_stop_token: + labels[-1] = IGNORE_INDEX + + return_dict = dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels, dataset_name=dataset_name, actions=actions) + + # Add additional inputs + if self.use_wrist_image: + all_wrist_pixels = [] + for k in rlds_batch["observation"].keys(): + if "wrist" in k: + img_wrist = Image.fromarray(rlds_batch["observation"][k][0]) + pixel_values_wrist = self.image_transform(img_wrist) + all_wrist_pixels.append(pixel_values_wrist) + return_dict["pixel_values_wrist"] = torch.cat(all_wrist_pixels, dim=0) + if self.use_proprio and "proprio" in rlds_batch["observation"]: + proprio = rlds_batch["observation"]["proprio"] + return_dict["proprio"] = proprio + + return return_dict + + +class RLDSDataset(IterableDataset): + def __init__( + self, + data_root_dir: Path, + data_mix: str, + batch_transform: RLDSBatchTransform, + resize_resolution: Tuple[int, int], + shuffle_buffer_size: int = 256_000, + train: bool = True, + image_aug: bool = False, + ) -> None: + """Lightweight wrapper around RLDS TFDS 
Pipeline for use with PyTorch/OpenVLA Data Loaders.""" + self.data_root_dir, self.data_mix, self.batch_transform = data_root_dir, data_mix, batch_transform + + # Configure RLDS Dataset(s) + if self.data_mix in OXE_NAMED_MIXTURES: + mixture_spec = OXE_NAMED_MIXTURES[self.data_mix] + else: + # Assume that passed "mixture" name is actually a single dataset -- create single-dataset "mix" + mixture_spec = [(self.data_mix, 1.0)] + + # fmt: off + if "aloha" in self.data_mix: + load_camera_views = ("primary", "left_wrist", "right_wrist") + else: + load_camera_views = ("primary", "wrist") + + per_dataset_kwargs, weights = get_oxe_dataset_kwargs_and_weights( + self.data_root_dir, + mixture_spec, + load_camera_views=load_camera_views, + load_depth=False, + load_proprio=True, + load_language=True, + action_proprio_normalization_type=ACTION_PROPRIO_NORMALIZATION_TYPE, + ) + rlds_config = dict( + traj_transform_kwargs=dict( + window_size=1, # If we wanted to feed / predict more than one step + future_action_window_size=NUM_ACTIONS_CHUNK-1, # For action chunking + skip_unlabeled=True, # Skip trajectories without language labels + goal_relabeling_strategy="uniform", # Goals are currently unused + ), + frame_transform_kwargs=dict( + resize_size=resize_resolution, + num_parallel_calls=16, # For CPU-intensive ops (decoding, resizing, etc.) 
+ ), + dataset_kwargs_list=per_dataset_kwargs, + shuffle_buffer_size=shuffle_buffer_size, + sample_weights=weights, + balance_weights=True, + traj_transform_threads=len(mixture_spec), + traj_read_threads=len(mixture_spec), + train=train, + ) + + # If applicable, enable image augmentations + if image_aug: + rlds_config["frame_transform_kwargs"].update({"image_augment_kwargs" : dict( + random_resized_crop=dict(scale=[0.9, 0.9], ratio=[1.0, 1.0]), + random_brightness=[0.2], + random_contrast=[0.8, 1.2], + random_saturation=[0.8, 1.2], + random_hue=[0.05], + augment_order=[ + "random_resized_crop", + "random_brightness", + "random_contrast", + "random_saturation", + "random_hue", + ], + )}), + # fmt: on + + # Initialize RLDS Dataset + self.dataset, self.dataset_length, self.dataset_statistics = self.make_dataset(rlds_config) + + def make_dataset(self, rlds_config): + return make_interleaved_dataset(**rlds_config) + + def __iter__(self) -> Dict[str, Any]: + for rlds_batch in self.dataset.as_numpy_iterator(): + yield self.batch_transform(rlds_batch) + + def __len__(self) -> int: + return self.dataset_length + + # === Explicitly Unused === + def __getitem__(self, idx: int) -> None: + raise NotImplementedError("IterableDataset does not implement map-style __getitem__; see __iter__ instead!") + + +class EpisodicRLDSDataset(RLDSDataset): + """Returns full episodes as list of steps instead of individual transitions (useful for visualizations).""" + + def make_dataset(self, rlds_config): + per_dataset_kwargs = rlds_config["dataset_kwargs_list"] + assert len(per_dataset_kwargs) == 1, "Only support single-dataset `mixes` for episodic datasets." 
+ + return make_single_dataset( + per_dataset_kwargs[0], + train=rlds_config["train"], + traj_transform_kwargs=rlds_config["traj_transform_kwargs"], + frame_transform_kwargs=rlds_config["frame_transform_kwargs"], + ) + + def __iter__(self) -> Dict[str, Any]: + for rlds_batch in self.dataset.as_numpy_iterator(): + out = [ + self.batch_transform(tree_map(lambda x: x[i], rlds_batch)) # noqa: B023 + for i in range(rlds_batch["action"].shape[0]) + ] + yield out + + +class DummyDataset(Dataset): + def __init__( + self, + action_tokenizer: ActionTokenizer, + base_tokenizer: PreTrainedTokenizerBase, + image_transform: ImageTransform, + prompt_builder_fn: Type[PromptBuilder], + ) -> None: + self.action_tokenizer = action_tokenizer + self.base_tokenizer = base_tokenizer + self.image_transform = image_transform + self.prompt_builder_fn = prompt_builder_fn + + # Note =>> We expect the dataset to store statistics for action de-normalization. Specifically, we store the + # per-dimension 1st and 99th action quantile. The values below correspond to "no normalization" for simplicity. + self.dataset_statistics = { + "dummy_dataset": { + "action": {"q01": np.zeros((7,), dtype=np.float32), "q99": np.ones((7,), dtype=np.float32)} + } + } + + def __len__(self): + # TODO =>> Replace with number of elements in your dataset! 
+ return 10000 + + def __getitem__(self, idx): + # TODO =>> Load image, action and instruction from disk -- we use dummy values + image = Image.fromarray(np.asarray(np.random.rand(224, 224, 3) * 255.0, dtype=np.uint8)) + action = np.asarray(np.random.rand(7), dtype=np.float32) + instruction = "do something spectacular" + + # Add instruction to VLA prompt + prompt_builder = self.prompt_builder_fn("openvla") + conversation = [ + {"from": "human", "value": f"What action should the robot take to {instruction}?"}, + {"from": "gpt", "value": self.action_tokenizer(action)}, + ] + for turn in conversation: + prompt_builder.add_turn(turn["from"], turn["value"]) + + # Tokenize (w/ `base_tokenizer`) + input_ids = self.base_tokenizer(prompt_builder.get_prompt(), add_special_tokens=True).input_ids + labels = list(input_ids) + + # Tensorize =>> Run Image Transform to get `pixel_values` =>> Return + # =>> IMPORTANT :: IF WE'RE USING HF .forward(..., labels=labels), SHIFTING HAPPENS _INSIDE_ MODEL! + input_ids, labels = torch.tensor(input_ids), torch.tensor(labels) + pixel_values = self.image_transform(image) + + # [CRITICAL] We do not want to take the loss for anything but the predicted action tokens! 
+ labels[: -(len(action) + 1)] = IGNORE_INDEX + + return dict(pixel_values=pixel_values, input_ids=input_ids, labels=labels) diff --git a/prismatic/vla/datasets/rlds/__init__.py b/prismatic/vla/datasets/rlds/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d19440506f5ca53a1f6005e2b072174c743ec546 --- /dev/null +++ b/prismatic/vla/datasets/rlds/__init__.py @@ -0,0 +1 @@ +from .dataset import make_interleaved_dataset, make_single_dataset diff --git a/prismatic/vla/datasets/rlds/conflict_maniskill_dataset_builder.py b/prismatic/vla/datasets/rlds/conflict_maniskill_dataset_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..6dd3a0bfda8c64aea042533bc7b566dfe54a46e8 --- /dev/null +++ b/prismatic/vla/datasets/rlds/conflict_maniskill_dataset_builder.py @@ -0,0 +1,200 @@ +""" +conflict_maniskill_dataset_builder.py + +TFDS dataset builder that converts LeRobot v2.1 parquet data (ManiSkill Panda conflict dataset) +into RLDS-compatible TFRecord shards readable by openvla-oft's RLDSDataset pipeline. 
+ +Data format (LeRobot v2.1 parquet): + - image: dict with 'bytes' key → PNG-encoded 256×256 RGB + - wrist_image: dict with 'bytes' key → PNG-encoded 256×256 RGB + - state: float32[8] = [joint0..6 (rad), gripper_width (m)] + - actions: float32[8] = [joint0..6 (rad), gripper_open (0=closed, 1=open)] + - task_index: int64 → maps to task string via tasks.jsonl + +Usage (run once before launching finetune): + python conflict_maniskill_dataset_builder.py \ + --data_root /home/jtremblay/yu/conflict_data/color_object/300/huggingface_data/color_object/conflict \ + --output_dir /home/jtremblay/yu/rlds_data +""" + +import argparse +import io +import json +import os +from pathlib import Path +from typing import Any, Dict, Generator, Iterator, List, Tuple + +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds +from PIL import Image + + +# --------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + +def _load_tasks(meta_dir: Path) -> Dict[int, str]: + tasks = {} + with open(meta_dir / "tasks.jsonl") as f: + for line in f: + obj = json.loads(line.strip()) + tasks[obj["task_index"]] = obj["task"] + return tasks + + +def _load_episodes(meta_dir: Path) -> List[Dict]: + episodes = [] + with open(meta_dir / "episodes.jsonl") as f: + for line in f: + episodes.append(json.loads(line.strip())) + return episodes + + +def _parquet_path(data_dir: Path, episode_index: int, chunks_size: int = 1000) -> Path: + chunk = episode_index // chunks_size + return data_dir / f"chunk-{chunk:03d}" / f"episode_{episode_index:06d}.parquet" + + +def _decode_png_bytes(png_bytes: bytes) -> np.ndarray: + img = Image.open(io.BytesIO(png_bytes)).convert("RGB") + return np.array(img, dtype=np.uint8) + + +# --------------------------------------------------------------------------- +# TFDS builder +# --------------------------------------------------------------------------- + 
+class ConflictManiskill(tfds.core.GeneratorBasedBuilder): + """TFDS builder for ManiSkill Panda conflict dataset (LeRobot v2.1 → RLDS).""" + + VERSION = tfds.core.Version("1.0.0") + RELEASE_NOTES = {"1.0.0": "Initial release."} + + # These are set at build-time via overridden builder_config or patched externally. + # We expose them as class-level attrs so the factory function below can set them. + _data_root: Path = None # path to dataset root (contains data/ and meta/) + + def _info(self) -> tfds.core.DatasetInfo: + return tfds.core.DatasetInfo( + builder=self, + description="ManiSkill Panda conflict dataset in RLDS format.", + features=tfds.features.FeaturesDict( + { + "steps": tfds.features.Dataset( + { + "observation": tfds.features.FeaturesDict( + { + "image": tfds.features.Image( + shape=(256, 256, 3), dtype=tf.uint8, encoding_format="png" + ), + "wrist_image": tfds.features.Image( + shape=(256, 256, 3), dtype=tf.uint8, encoding_format="png" + ), + "state": tfds.features.Tensor(shape=(8,), dtype=tf.float32), + } + ), + "action": tfds.features.Tensor(shape=(8,), dtype=tf.float32), + "language_instruction": tfds.features.Text(), + "is_first": tf.bool, + "is_last": tf.bool, + "is_terminal": tf.bool, + } + ) + } + ), + supervised_keys=None, + homepage="", + ) + + def _split_generators(self, dl_manager: tfds.download.DownloadManager): + data_root = self.__class__._data_root + assert data_root is not None, "_data_root must be set before calling build()" + return {"train": self._generate_examples(data_root)} + + def _generate_examples(self, data_root: Path) -> Generator[Tuple[str, Dict], None, None]: + meta_dir = data_root / "meta" + data_dir = data_root / "data" + + tasks = _load_tasks(meta_dir) + episodes = _load_episodes(meta_dir) + + # read chunks_size from info.json + with open(meta_dir / "info.json") as f: + info = json.load(f) + chunks_size = info.get("chunks_size", 1000) + + import pandas as pd + + for ep_meta in episodes: + ep_idx = 
ep_meta["episode_index"] + task_str = ep_meta["tasks"][0] if ep_meta.get("tasks") else tasks.get(0, "") + parquet_path = _parquet_path(data_dir, ep_idx, chunks_size) + + df = pd.read_parquet(str(parquet_path)) + n_steps = len(df) + + steps = [] + for i, row in df.iterrows(): + img_arr = _decode_png_bytes(row["image"]["bytes"]) + wrist_arr = _decode_png_bytes(row["wrist_image"]["bytes"]) + state = np.array(row["state"], dtype=np.float32) + action = np.array(row["actions"], dtype=np.float32) + + steps.append( + { + "observation": { + "image": img_arr, + "wrist_image": wrist_arr, + "state": state, + }, + "action": action, + "language_instruction": task_str, + "is_first": (i == df.index[0]), + "is_last": (i == df.index[-1]), + "is_terminal": (i == df.index[-1]), + } + ) + + yield f"ep_{ep_idx:06d}", {"steps": steps} + + +# --------------------------------------------------------------------------- +# Standalone conversion script +# --------------------------------------------------------------------------- + +def build_rlds_dataset(data_root: str, output_dir: str) -> None: + """Convert parquet LeRobot dataset to TFDS RLDS shards under output_dir/conflict_maniskill/.""" + data_root_path = Path(data_root) + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + # Patch class-level attribute before instantiating builder + ConflictManiskill._data_root = data_root_path + + builder = ConflictManiskill(data_dir=str(output_path)) + builder.download_and_prepare( + download_config=tfds.download.DownloadConfig( + manual_dir=str(data_root_path), + ) + ) + print(f"\nRLDS dataset written to: {builder.data_dir}") + print("Pass the following args to finetune.py:") + print(f" --data_root_dir {output_path}") + print(f" --dataset_name conflict_maniskill") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert LeRobot parquet data to RLDS TFDS format.") + parser.add_argument( + "--data_root", + required=True, + help="Path to 
dataset root dir containing data/ and meta/ (e.g. .../color_object/conflict)", + ) + parser.add_argument( + "--output_dir", + required=True, + help="Output directory for RLDS shards (e.g. /home/jtremblay/yu/rlds_data)", + ) + args = parser.parse_args() + build_rlds_dataset(args.data_root, args.output_dir) diff --git a/prismatic/vla/datasets/rlds/dataset.py b/prismatic/vla/datasets/rlds/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f07215a2ddcd9d348b95d7caa8693985e8dc1c98 --- /dev/null +++ b/prismatic/vla/datasets/rlds/dataset.py @@ -0,0 +1,585 @@ +""" +dataset.py + +Core interface script for configuring and initializing RLDS datasets. +""" + +import copy +import inspect +import json +from functools import partial +from typing import Callable, Dict, List, Optional, Tuple, Union + +import dlimp as dl +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds + +from prismatic.overwatch import initialize_overwatch +from prismatic.vla.constants import ACTION_DIM, ACTION_PROPRIO_NORMALIZATION_TYPE, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX +from prismatic.vla.datasets.rlds import obs_transforms, traj_transforms +from prismatic.vla.datasets.rlds.utils import goal_relabeling, task_augmentation +from prismatic.vla.datasets.rlds.utils.data_utils import ( + allocate_threads, + get_dataset_statistics, + normalize_action_and_proprio, + pprint_data_mixture, + tree_map, +) + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +# Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch) +tf.config.set_visible_devices([], "GPU") + + +# ruff: noqa: B006 +def make_dataset_from_rlds( + name: str, + data_dir: str, + *, + train: bool, + standardize_fn: Optional[Callable[[dict], dict]] = None, + shuffle: bool = True, + image_obs_keys: Dict[str, Optional[str]] = {}, + depth_obs_keys: Dict[str, Optional[str]] = {}, + 
state_obs_keys: List[Optional[str]] = (), + language_key: Optional[str] = None, + action_proprio_normalization_type: ACTION_PROPRIO_NORMALIZATION_TYPE, + dataset_statistics: Optional[Union[dict, str]] = None, + absolute_action_mask: Optional[List[bool]] = None, + action_normalization_mask: Optional[List[bool]] = None, + num_parallel_reads: int = tf.data.AUTOTUNE, + num_parallel_calls: int = tf.data.AUTOTUNE, +) -> Tuple[dl.DLataset, dict]: + """ + This function is responsible for loading a specific RLDS dataset from storage and getting it into a standardized + format. Yields a dataset of trajectories. Does not include CPU-intensive operations. + + If `standardize_fn` is provided, it will be applied to each trajectory. This function should get the trajectory + into a standard format, which includes the keys "observation" and "action". Entry "observation" should be a + dictionary containing some number of additional keys, which will be extracted into an even more standardized format + according to the "*_obs_keys" arguments. + + The `image_obs_keys` and `depth_obs_keys` arguments are mappings from new names to old names, or None in place of an + old name to insert padding. For example, if after `standardize_fn`, your "observation" dict has RGB images called + "workspace" and "wrist", and `image_obs_keys={"primary": "workspace", "secondary": None, "wrist": "wrist"}`, then + the resulting dataset will have an "observation" dict containing the keys "image_primary", "image_secondary", and + "image_wrist", where "image_primary" corresponds to "workspace", "image_secondary" is a padding image, and + "image_wrist" corresponds to "wrist". + + Entry `state_obs_keys` is a list of 1-dimensional proprioceptive keys to concatenate into a single array, which will + be placed in the "proprio" key of the "observation" dict. A single padding element (zero) will be inserted for each + None entry. + + The dataset will also include a "task" dict. 
If `language_key` is provided, then the "task" dict will contain the + key "language_instruction", extracted from `traj[language_key]`. + + Args: + name (str): The name of the RLDS dataset (usually "name" or "name:version"). + data_dir (str): The path to the data directory. + train (bool): Whether to use the training or validation split. + shuffle (bool, optional): Whether to shuffle the file read order (does NOT fully shuffle the dataset, since one + file usually contains many trajectories)! + standardize_fn (Callable[[dict], dict], optional): A function that, if provided, will be the first + thing applied to each trajectory. + image_obs_keys (Mapping[str, str|None]): Mapping from {new: old} indicating which RGB images to extract from the + "observation" dict. `new_obs = {f"image_{new}": old_obs[old] for new, old in image_obs_keys.items()}`. + If a value of `old` is None, inserts a padding image instead (empty string). + depth_obs_keys (Mapping[str, str|None]): Same as `image_obs_keys`, but for depth images. Keys will be + prefixed with "depth_" instead of "image_". + state_obs_keys (Sequence[str|None]): List of 1-dimensional proprioception keys to be extracted from the + "observation" dict, concatenated, and mapped to "proprio". Inserts 1 element of padding for each None entry. + language_key (str, optional): If provided, the "task" dict will contain the key "language_instruction", + extracted from `traj[language_key]`. + action_proprio_normalization_type (str, optional): The type of normalization to perform on the action, + proprio, or both. Can be "normal" (mean 0, std 1) or "bounds" (normalized to [-1, 1]). + dataset_statistics: (dict|str, optional): dict (or path to JSON file) that contains dataset statistics + for normalization. If `action_proprio_normalization_type` is "normal", this should contain "mean" and + "std" keys. If `action_proprio_normalization_type` is "bounds", this should contain "min" and "max" + keys. 
May also provide "num_transitions" and "num_trajectories" keys for downstream usage (e.g., for + `make_interleaved_dataset`). If not provided, the statistics will be computed on the fly. + absolute_action_mask (Sequence[bool], optional): By default, all action dimensions are assumed to be + relative. This is important for when `future_action_window_size > 0`: actions that are taken + from beyond the end of the trajectory (or beyond the goal timestep when goal relabeling is used) + need to be made "neutral" to indicate that the task has been completed. For relative actions, + "neutral" means zero, but for absolute actions, "neutral" means repeating the last valid action. + This mask, if provided, indicates which action dimensions are absolute. + action_normalization_mask (Sequence[bool], optional): If provided, indicates which action dimensions + should be normalized. For example, you might not want to normalize the gripper action dimension if + it's always exactly 0 or 1. By default, all action dimensions are normalized. + num_parallel_reads (int): number of parallel read workers. Default to AUTOTUNE. + num_parallel_calls (int): number of parallel calls for traj_map operations. Default to AUTOTUNE. 
+ Returns: + Dataset of trajectories where each step has the following fields: + - observation: + - image_{name1, name2, ...} # RGB image observations + - depth_{name1, name2, ...} # depth image observations + - proprio # 1-dimensional array of proprioceptive observations + - timestep # timestep of each frame + - task: + - language_instruction # language instruction, present if `language_key` is provided + - action # action vector + - dataset_name # name of the dataset + """ + REQUIRED_KEYS = {"observation", "action"} + if language_key is not None: + REQUIRED_KEYS.add(language_key) + + def restructure(traj): + # apply a standardization function, if provided + if standardize_fn is not None: + traj = standardize_fn(traj) + + if not all(k in traj for k in REQUIRED_KEYS): + raise ValueError( + f"Trajectory is missing keys: {REQUIRED_KEYS - set(traj.keys())}. " "Did you write a `standardize_fn`?" + ) + + # extracts images, depth images and proprio from the "observation" dict + traj_len = tf.shape(traj["action"])[0] + old_obs = traj["observation"] + new_obs = {} + for new, old in image_obs_keys.items(): + if old is None: + new_obs[f"image_{new}"] = tf.repeat("", traj_len) # padding + else: + new_obs[f"image_{new}"] = old_obs[old] + + for new, old in depth_obs_keys.items(): + if old is None: + new_obs[f"depth_{new}"] = tf.repeat("", traj_len) # padding + else: + new_obs[f"depth_{new}"] = old_obs[old] + + if state_obs_keys: + new_obs["proprio"] = tf.concat( + [ + ( + tf.zeros((traj_len, 1), dtype=tf.float32) # padding + if key is None + else tf.cast(old_obs[key], tf.float32) + ) + for key in state_obs_keys + ], + axis=1, + ) + + # add timestep info + new_obs["timestep"] = tf.range(traj_len) + + # extracts `language_key` into the "task" dict + task = {} + if language_key is not None: + if traj[language_key].dtype != tf.string: + raise ValueError( + f"Language key {language_key} has dtype {traj[language_key].dtype}, " "but it must be tf.string." 
+ ) + task["language_instruction"] = traj.pop(language_key) + + traj = { + "observation": new_obs, + "task": task, + "action": tf.cast(traj["action"], tf.float32), + "dataset_name": tf.repeat(name, traj_len), + } + + if absolute_action_mask is not None: + if len(absolute_action_mask) != traj["action"].shape[-1]: + raise ValueError( + f"Length of absolute_action_mask ({len(absolute_action_mask)}) " + f"does not match action dimension ({traj['action'].shape[-1]})." + ) + traj["absolute_action_mask"] = tf.tile( + tf.convert_to_tensor(absolute_action_mask, dtype=tf.bool)[None], + [traj_len, 1], + ) + + return traj + + builder = tfds.builder(name, data_dir=data_dir) + + # load or compute dataset statistics + if isinstance(dataset_statistics, str): + with tf.io.gfile.GFile(dataset_statistics, "r") as f: + dataset_statistics = json.load(f) + elif dataset_statistics is None: + full_dataset = dl.DLataset.from_rlds( + builder, split="all", shuffle=False, num_parallel_reads=num_parallel_reads + ).traj_map(restructure, num_parallel_calls) + # tries to load from cache, otherwise computes on the fly + dataset_statistics = get_dataset_statistics( + full_dataset, + hash_dependencies=( + str(builder.info), + str(state_obs_keys), + inspect.getsource(standardize_fn) if standardize_fn is not None else "", + ), + save_dir=builder.data_dir, + ) + dataset_statistics = tree_map(np.array, dataset_statistics) + + # skip normalization for certain action dimensions + if action_normalization_mask is not None: + if len(action_normalization_mask) != dataset_statistics["action"]["mean"].shape[-1]: + raise ValueError( + f"Length of action_normalization_mask ({len(action_normalization_mask)}) " + f"does not match action dimension ({dataset_statistics['action']['mean'].shape[-1]})."
+ ) + dataset_statistics["action"]["mask"] = np.array(action_normalization_mask) + + # construct the dataset + split = "train" if train else "val" + + dataset = dl.DLataset.from_rlds(builder, split=split, shuffle=shuffle, num_parallel_reads=num_parallel_reads) + + dataset = dataset.traj_map(restructure, num_parallel_calls) + dataset = dataset.traj_map( + partial( + normalize_action_and_proprio, + metadata=dataset_statistics, + normalization_type=action_proprio_normalization_type, + ), + num_parallel_calls, + ) + + return dataset, dataset_statistics + + +def apply_trajectory_transforms( + dataset: dl.DLataset, + *, + train: bool, + goal_relabeling_strategy: Optional[str] = None, + goal_relabeling_kwargs: dict = {}, + window_size: int = 1, + future_action_window_size: int = 0, + subsample_length: Optional[int] = None, + skip_unlabeled: bool = False, + max_action: Optional[float] = None, + max_proprio: Optional[float] = None, + task_augment_strategy: Optional[str] = None, + task_augment_kwargs: dict = {}, + num_parallel_calls: int = tf.data.AUTOTUNE, +) -> dl.DLataset: + """ + Applies common transforms that happen at a trajectory level. Such transforms are usually some sort of "relabeling" + (e.g., filtering, chunking, adding goals, dropping keys). + + Transforms in this function should have the following properties: + - They require access to an entire trajectory (i.e., they cannot be applied frame-wise). + - They are generally not CPU-intensive, mostly involving moving and copying data. + - They do not require decoded images. + + Args: + dataset (dl.DLataset): The dataset to transform. + train (bool): Whether the dataset is for training (affects subsampling). + goal_relabeling_strategy (str, optional): The goal relabeling strategy to use, or None for + no goal relabeling. See `goal_relabeling.py`. + goal_relabeling_kwargs (dict, optional): Additional keyword arguments to pass to the goal relabeling function. 
+ window_size (int, optional): The length of the snippets that trajectories are chunked into. + future_action_window_size (int, optional): The number of future actions beyond window_size to include + in the chunked actions. + subsample_length (int, optional): If provided, trajectories longer than this will be subsampled to + this length (after goal relabeling and chunking). + skip_unlabeled (bool, optional): Whether to skip trajectories with no language labels. + max_action (float, optional): If provided, trajectories in which *any* action dimension + of *any* transition has an absolute value larger than this will be skipped. + max_proprio (float, optional): If provided, trajectories in which *any* proprio dimension + of *any* transition has an absolute value larger than this will be skipped. + task_augment_strategy (str, optional): The task augmentation strategy to use, or None for no task + augmentation. See `task_augmentation.py`. + task_augment_kwargs (dict, optional): Additional keyword arguments to pass to the task augmentation + function. + num_parallel_calls (int, optional): Number of parallel calls for map operations. Defaults to AUTOTUNE.
+ """ + if skip_unlabeled: + if "language_instruction" not in dataset.element_spec["task"]: + raise ValueError("skip_unlabeled=True but dataset does not have language labels.") + + dataset = dataset.filter(lambda x: tf.math.reduce_any(x["task"]["language_instruction"] != "")) + + if max_action is not None: + dataset = dataset.filter(lambda x: tf.math.reduce_all(tf.math.abs(x["action"]) <= max_action)) + + if max_proprio is not None and "proprio" in dataset.element_spec["observation"]: + dataset = dataset.filter(lambda x: tf.math.reduce_all(tf.math.abs(x["observation"]["proprio"]) <= max_proprio)) + + # marks which entries of the observation and task dicts are padding + dataset = dataset.traj_map(traj_transforms.add_pad_mask_dict, num_parallel_calls) + + # updates the "task" dict + if goal_relabeling_strategy is not None: + dataset = dataset.traj_map( + partial(getattr(goal_relabeling, goal_relabeling_strategy), **goal_relabeling_kwargs), + num_parallel_calls, + ) + + # must run task augmentation before chunking, in case it changes goal timesteps + if train and task_augment_strategy is not None: + # perform task augmentation (e.g., dropping keys) + dataset = dataset.traj_map( + partial( + getattr(task_augmentation, task_augment_strategy), + **task_augment_kwargs, + ), + num_parallel_calls, + ) + + # chunks observations and actions, giving them a new axis at index 1 of size `window_size` and + # `window_size + future_action_window_size`, respectively + dataset = dataset.traj_map( + partial( + traj_transforms.chunk_act_obs, + window_size=window_size, + future_action_window_size=future_action_window_size, + ), + num_parallel_calls, + ) + + if train and subsample_length is not None: + dataset = dataset.traj_map( + partial(traj_transforms.subsample, subsample_length=subsample_length), + num_parallel_calls, + ) + + return dataset + + +def apply_per_dataset_frame_transforms( + dataset: dl.DLataset, + chunk_filter_fn: Optional[Callable] = None, +): + """ + Optionally
applies *per-dataset* transforms that happen at a frame level. + + Args: + chunk_filter_fn (callable, optional): Filter function for chunks. + """ + if chunk_filter_fn: + dataset = dataset.filter(chunk_filter_fn) + return dataset + + +def apply_frame_transforms( + dataset: dl.DLataset, + *, + train: bool, + image_augment_kwargs: Union[Dict, Dict[str, Dict]] = {}, + resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]] = {}, + depth_resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]] = {}, + num_parallel_calls: int = tf.data.AUTOTUNE, +) -> dl.DLataset: + """ + Applies common transforms that happen at a frame level. These transforms are usually more CPU-intensive (e.g., + decoding or resizing images). + + Args: + dataset (dl.DLataset): The dataset to transform. + train (bool): Whether the dataset is for training (affects image augmentation). + image_augment_kwargs (dict|Mapping[str, dict]): Keyword arguments to pass to the image augmentation + function. See `dlimp.transforms.augment_image` for documentation of these kwargs. If a dict of + dicts is provided, then key "k" will be used for "image_{k}" (names determined by `image_obs_keys` + in `make_dataset_from_rlds`). Augmentation will be skipped for missing keys (so pass an empty dict + to skip augmentation for all images). + resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): If provided, images will be resized to + this size. If a dict of tuples is provided, then key "k" will be used for "image_{k}" (names + determined by `image_obs_keys` in `make_dataset_from_rlds`). Resizing will be skipped for missing + keys (so pass an empty dict to skip resizing for all images). + depth_resize_size (Tuple[int, int]|Mapping[str, Tuple[int, int]]): Same as resize_size, but for depth + images. + num_parallel_calls (int): Number of parallel calls for frame_map operations. Defaults to AUTOTUNE.
+ """ + + # Convenience wrapper that takes a function that operates on a non-chunked "observation" dict and applies + # it to the chunked "observation" dict as well as the non-chunked "task" dict + def apply_obs_transform(fn: Callable[[Dict], Dict], frame: Dict) -> Dict: + frame["task"] = fn(frame["task"]) + frame["observation"] = dl.vmap(fn)(frame["observation"]) + return frame + + # Decode + resize images (and depth images) + dataset = dataset.frame_map( + partial( + apply_obs_transform, + partial(obs_transforms.decode_and_resize, resize_size=resize_size, depth_resize_size=depth_resize_size), + ), + num_parallel_calls, + ) + + if train: + # Augment all images with the same seed, skipping padding images + def aug(frame: dict): + seed = tf.random.uniform([2], maxval=tf.dtypes.int32.max, dtype=tf.int32) + aug_fn = partial(obs_transforms.augment, seed=seed, augment_kwargs=image_augment_kwargs) + return apply_obs_transform(aug_fn, frame) + + dataset = dataset.frame_map(aug, num_parallel_calls) + + return dataset + + +def make_single_dataset( + dataset_kwargs: dict, + *, + train: bool, + traj_transform_kwargs: dict = {}, + frame_transform_kwargs: dict = {}, +) -> dl.DLataset: + """Creates a single dataset from kwargs. Returns a dataset of trajectories. + + Args: + dataset_kwargs: kwargs passed to `make_dataset_from_rlds` that are dataset-specific. + train: whether this is a training or validation dataset. + traj_transform_kwargs: kwargs passed to 'apply_trajectory_transforms'. + frame_transform_kwargs: kwargs passed to 'apply_frame_transforms'.
+ """ + dataset, dataset_statistics = make_dataset_from_rlds( + **dataset_kwargs, + train=train, + ) + dataset = apply_trajectory_transforms(dataset, **traj_transform_kwargs, train=train) + dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train) + + # this seems to reduce memory usage without affecting speed + dataset = dataset.with_ram_budget(1) + + # save for later + return dataset, dataset_statistics["num_trajectories"], dataset_statistics + + +# === Core Initializer === +def make_interleaved_dataset( + dataset_kwargs_list: List[Dict], + sample_weights: Optional[List[float]] = None, + *, + train: bool, + shuffle_buffer_size: int, + traj_transform_kwargs: Optional[Dict] = None, + frame_transform_kwargs: Optional[Dict] = None, + batch_size: Optional[int] = None, + balance_weights: bool = False, + traj_transform_threads: Optional[int] = None, + traj_read_threads: Optional[int] = None, +) -> dl.DLataset: + """ + Creates an interleaved dataset from list of dataset configs (kwargs). Returns a dataset of batched frames. + + Args: + dataset_kwargs_list: list of kwargs, each element of which is passed to `make_dataset_from_rlds`. + "num_parallel_calls" and "num_parallel_reads" are overridden using `traj_transform_threads` and + `traj_read_threads`, respectively. + sample_weights: sampling weights for each dataset in list. If None, defaults to uniform. + train: whether this is a training or validation dataset. + shuffle_buffer_size: size of the dataset shuffle buffer (in number of frames). + traj_transform_kwargs: kwargs passed to `apply_trajectory_transforms`. "num_parallel_calls" is + overridden using `traj_transform_threads`. + frame_transform_kwargs: kwargs passed to `apply_frame_transforms`. + batch_size: batch size, if not provided output is not batched. + balance_weights: if True, the sample weights are multiplied by the number of frames in each dataset. 
+ This makes it so that, if all the sample weights are equal, one full iteration through the interleaved + dataset will correspond to one full iteration through each individual dataset (only in expectation, + since in practice the sampling is random). + traj_transform_threads: total number of parallel calls for trajectory transforms, distributed across + datasets according to their sampling weights. If None, defaults to AUTOTUNE for every dataset. + traj_read_threads: total number of parallel read workers for loading trajectories, distributed across + datasets according to their sampling weights. If None, defaults to AUTOTUNE for every dataset. + """ + # Default to uniform sampling (if `sample_weights` is not specified) + if not sample_weights: + sample_weights = [1.0] * len(dataset_kwargs_list) + + if len(sample_weights) != len(dataset_kwargs_list): + raise ValueError(f"sample_weights must be None or have length {len(dataset_kwargs_list)}.") + + # Check valid `traj_transform_kwargs` and `frame_transform_kwargs` + if (traj_transform_kwargs is None) or (frame_transform_kwargs is None): + raise ValueError("Missing `traj_transform_kwargs` or `frame_transform_kwargs`!") + + # Get Dataset Sizes + dataset_sizes, all_dataset_statistics = [], {} + for dataset_kwargs in dataset_kwargs_list: + data_kwargs = copy.deepcopy(dataset_kwargs) + if "dataset_frame_transform_kwargs" in data_kwargs: + data_kwargs.pop("dataset_frame_transform_kwargs") + _, dataset_statistics = make_dataset_from_rlds(**data_kwargs, train=train) + dataset_sizes.append(dataset_statistics["num_transitions"]) + all_dataset_statistics[dataset_kwargs["name"]] = dataset_statistics + + # Get the indices of the "primary" datasets (i.e., datasets with sample_weight == 1.0) + primary_dataset_indices = np.array([idx for idx in range(len(sample_weights)) if sample_weights[idx] == 1.0]) + + # Balance and Normalize Weights + if balance_weights: + sample_weights = np.array(sample_weights) * np.array(dataset_sizes) +
sample_weights = np.array(sample_weights) / np.sum(sample_weights) + pprint_data_mixture(dataset_kwargs_list, sample_weights) + + # Effective Dataset Length = Number of samples until each dataset has completed at least one epoch + # =>> Note :: Only counting the "primary" datasets (i.e., datasets with sample_weight == 1.0) + dataset_len = int((np.array(dataset_sizes) / sample_weights)[primary_dataset_indices].max()) + + # Allocate Threads based on Weights + threads_per_dataset = allocate_threads(traj_transform_threads, sample_weights) + reads_per_dataset = allocate_threads(traj_read_threads, sample_weights) + + overwatch.info("Threads per Dataset: %s", threads_per_dataset) + overwatch.info("Reads per Dataset: %s", reads_per_dataset) + + # Construct Datasets + overwatch.info("Constructing datasets...") + datasets = [] + for dataset_kwargs, threads, reads in zip( + dataset_kwargs_list, + threads_per_dataset, + reads_per_dataset, + ): + dataset_frame_transform_kwargs = ( + dataset_kwargs.pop("dataset_frame_transform_kwargs") + if "dataset_frame_transform_kwargs" in dataset_kwargs + else {} + ) + dataset, _ = make_dataset_from_rlds( + **dataset_kwargs, + train=train, + num_parallel_calls=threads, + num_parallel_reads=reads, + dataset_statistics=all_dataset_statistics[dataset_kwargs["name"]], + ) + dataset = apply_trajectory_transforms( + dataset.repeat(), + **traj_transform_kwargs, + num_parallel_calls=threads, + train=train, + ).flatten(num_parallel_calls=threads) + dataset = apply_per_dataset_frame_transforms(dataset, **dataset_frame_transform_kwargs) + datasets.append(dataset) + + # Interleave at the Frame Level + dataset: dl.DLataset = dl.DLataset.sample_from_datasets(datasets, sample_weights) + + # Validation =>> fix a single shuffle buffer of data and cache it in RAM; prevents gradual memory increase! 
+ if not train: + dataset = dataset.take(shuffle_buffer_size).cache() + + # Shuffle the Dataset + # =>> IMPORTANT :: Shuffle AFTER .cache(), or else memory will still leak! + dataset = dataset.shuffle(shuffle_buffer_size) + + # Apply Frame Transforms + overwatch.info("Applying frame transforms on dataset...") + dataset = apply_frame_transforms(dataset, **frame_transform_kwargs, train=train) + + # [Contract] When training VLA Policies, we let the Collator handle Batching! + if batch_size is not None: + dataset = dataset.batch(batch_size) + + # Note =>> Seems to reduce memory usage without affecting speed? + dataset = dataset.with_ram_budget(1) + + # Save for Later + dataset.sample_weights = sample_weights + + return dataset, dataset_len, all_dataset_statistics diff --git a/prismatic/vla/datasets/rlds/obs_transforms.py b/prismatic/vla/datasets/rlds/obs_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d28b07d241fa8f451c7e149cab32397c7f8bb505 --- /dev/null +++ b/prismatic/vla/datasets/rlds/obs_transforms.py @@ -0,0 +1,99 @@ +""" +obs_transforms.py + +Contains observation-level transforms used in the orca data pipeline. + +These transforms operate on the "observation" dictionary, and are applied at a per-frame level. 
+""" + +from typing import Dict, Tuple, Union + +import dlimp as dl +import tensorflow as tf +from absl import logging + + +# ruff: noqa: B023 +def augment(obs: Dict, seed: tf.Tensor, augment_kwargs: Union[Dict, Dict[str, Dict]]) -> Dict: + """Augments images, skipping padding images.""" + image_names = {key[6:] for key in obs if key.startswith("image_")} + + # "augment_order" is required in augment_kwargs, so if it's there, we can assume that the user has passed + # in a single augmentation dict (otherwise, we assume that the user has passed in a mapping from image + # name to augmentation dict) + if "augment_order" in augment_kwargs: + augment_kwargs = {name: augment_kwargs for name in image_names} + + for i, name in enumerate(image_names): + if name not in augment_kwargs: + continue + kwargs = augment_kwargs[name] + logging.debug(f"Augmenting image_{name} with kwargs {kwargs}") + obs[f"image_{name}"] = tf.cond( + obs["pad_mask_dict"][f"image_{name}"], + lambda: dl.transforms.augment_image( + obs[f"image_{name}"], + **kwargs, + seed=seed + i, # augment each image differently + ), + lambda: obs[f"image_{name}"], # skip padding images + ) + + return obs + + +def decode_and_resize( + obs: Dict, + resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]], + depth_resize_size: Union[Tuple[int, int], Dict[str, Tuple[int, int]]], +) -> Dict: + """Decodes images and depth images, and then optionally resizes them.""" + image_names = {key[6:] for key in obs if key.startswith("image_")} + depth_names = {key[6:] for key in obs if key.startswith("depth_")} + + if isinstance(resize_size, tuple): + resize_size = {name: resize_size for name in image_names} + if isinstance(depth_resize_size, tuple): + depth_resize_size = {name: depth_resize_size for name in depth_names} + + for name in image_names: + if name not in resize_size: + logging.warning( + f"No resize_size was provided for image_{name}. 
This will result in 1x1 " + "padding images, which may cause errors if you mix padding and non-padding images." + ) + image = obs[f"image_{name}"] + if image.dtype == tf.string: + if tf.strings.length(image) == 0: + # this is a padding image + image = tf.zeros((*resize_size.get(name, (1, 1)), 3), dtype=tf.uint8) + else: + image = tf.io.decode_image(image, expand_animations=False, dtype=tf.uint8) + elif image.dtype != tf.uint8: + raise ValueError(f"Unsupported image dtype: found image_{name} with dtype {image.dtype}") + if name in resize_size: + image = dl.transforms.resize_image(image, size=resize_size[name]) + obs[f"image_{name}"] = image + + for name in depth_names: + if name not in depth_resize_size: + logging.warning( + f"No depth_resize_size was provided for depth_{name}. This will result in 1x1 " + "padding depth images, which may cause errors if you mix padding and non-padding images." + ) + depth = obs[f"depth_{name}"] + + if depth.dtype == tf.string: + if tf.strings.length(depth) == 0: + depth = tf.zeros((*depth_resize_size.get(name, (1, 1)), 1), dtype=tf.float32) + else: + depth = tf.io.decode_image(depth, expand_animations=False, dtype=tf.float32)[..., 0] + elif depth.dtype != tf.float32: + raise ValueError(f"Unsupported depth dtype: found depth_{name} with dtype {depth.dtype}") + + if name in depth_resize_size: + depth = dl.transforms.resize_depth_image(depth, size=depth_resize_size[name]) + + obs[f"depth_{name}"] = depth + + return obs diff --git a/prismatic/vla/datasets/rlds/oxe/__init__.py b/prismatic/vla/datasets/rlds/oxe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1502ecb73d70c57c184e0c90e568b02a0fbd11de --- /dev/null +++ b/prismatic/vla/datasets/rlds/oxe/__init__.py @@ -0,0 +1,2 @@ +from .materialize import get_oxe_dataset_kwargs_and_weights +from .mixtures import OXE_NAMED_MIXTURES diff --git a/prismatic/vla/datasets/rlds/oxe/configs.py b/prismatic/vla/datasets/rlds/oxe/configs.py new file mode 100644 index 
0000000000000000000000000000000000000000..95cf6fef3419424b95e501697e029fcbd2a73e05 --- /dev/null +++ b/prismatic/vla/datasets/rlds/oxe/configs.py @@ -0,0 +1,718 @@ +""" +configs.py + +Defines per-dataset configuration (kwargs) for each dataset in Open-X Embodiment. + +Configuration adopts the following structure: + image_obs_keys: + primary: primary external RGB + secondary: secondary external RGB + wrist: wrist RGB + + depth_obs_keys: + primary: primary external depth + secondary: secondary external depth + wrist: wrist depth + + # Always 8-dim =>> changes based on `StateEncoding` + state_obs_keys: + StateEncoding.POS_EULER: EEF XYZ (3) + Roll-Pitch-Yaw (3) + Padding (1) + Gripper Open/Close (1) + StateEncoding.POS_QUAT: EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1) + StateEncoding.JOINT: Joint Angles (7, padded if fewer) + Gripper Open/Close (1) + + state_encoding: Type of `StateEncoding` + action_encoding: Type of action encoding (e.g., EEF Position vs. Joint Position) +""" + +from enum import IntEnum + +from prismatic.vla.datasets.rlds.oxe.utils.droid_utils import zero_action_filter + + +# Defines Proprioceptive State Encoding Schemes +class StateEncoding(IntEnum): + # fmt: off + NONE = -1 # No Proprioceptive State + POS_EULER = 1 # EEF XYZ (3) + Roll-Pitch-Yaw (3) + Padding (1) + Gripper Open/Close (1) + POS_QUAT = 2 # EEF XYZ (3) + Quaternion (4) + Gripper Open/Close (1) + JOINT = 3 # Joint Angles (7, padded if fewer) + Gripper Open/Close (1) + JOINT_BIMANUAL = 4 # Joint Angles (2 x [ Joint Angles (6) + Gripper Open/Close (1) ]) + # fmt: on + + +# Defines Action Encoding Schemes +class ActionEncoding(IntEnum): + # fmt: off + EEF_POS = 1 # EEF Delta XYZ (3) + Roll-Pitch-Yaw (3) + Gripper Open/Close (1) + JOINT_POS = 2 # Joint Delta Position (7) + Gripper Open/Close (1) + JOINT_POS_BIMANUAL = 3 # Joint Delta Position (2 x [ Joint Delta Position (6) + Gripper Open/Close (1) ]) + EEF_R6 = 4 # EEF Delta XYZ (3) + R6 (6) + Gripper Open/Close (1) + JOINT_POS_ABS = 5 # Absolute Joint
Position (7) + Gripper Open/Close (1) -- 8-dim total + # fmt: on + + +# === Individual Dataset Configs === +OXE_DATASET_CONFIGS = { + "fractal20220817_data": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["base_pose_tool_reached", "gripper_closed"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "kuka": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [ + "clip_function_input/base_pose_tool_reached", + "gripper_closed", + ], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "bridge_oxe": { # Version of Bridge V2 in Open X-Embodiment mixture + "image_obs_keys": {"primary": "image", "secondary": "image_1", "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "bridge_orig": { # Original version of Bridge V2 from project website + "image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "bridge_dataset": { # Original version of Bridge V2 from project website + "image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "taco_play": { + "image_obs_keys": { + "primary": 
"rgb_static", + "secondary": None, + "wrist": "rgb_gripper", + }, + "depth_obs_keys": { + "primary": "depth_static", + "secondary": None, + "wrist": "depth_gripper", + }, + "state_obs_keys": ["state_eef", None, "state_gripper"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "jaco_play": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "image_wrist", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state_eef", None, "state_gripper"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_cable_routing": { + "image_obs_keys": { + "primary": "image", + "secondary": "top_image", + "wrist": "wrist45_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["robot_state", None], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "roboturk": { + "image_obs_keys": {"primary": "front_rgb", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + }, + "nyu_door_opening_surprising_effectiveness": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + }, + "viola": { + "image_obs_keys": { + "primary": "agentview_rgb", + "secondary": None, + "wrist": "eye_in_hand_rgb", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_states", "gripper_states"], + "state_encoding": StateEncoding.JOINT, + 
"action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_autolab_ur5": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "hand_image", + }, + "depth_obs_keys": {"primary": "depth", "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "toto": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "language_table": { + "image_obs_keys": {"primary": "rgb", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["effector_translation", None, None, None, None, None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "columbia_cairlab_pusht_real": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["robot_state", None, None, None, None, None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["ee_position", "ee_orientation", None], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "nyu_rot_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", 
"gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "stanford_hydra_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "austin_buds_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "nyu_franka_play_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": "image_additional_view", + "wrist": None, + }, + "depth_obs_keys": { + "primary": "depth", + "secondary": "depth_additional_view", + "wrist": None, + }, + "state_obs_keys": ["eef_state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "maniskill_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": { + "primary": "depth", + "secondary": None, + "wrist": "wrist_depth", + }, + "state_obs_keys": ["tcp_pose", "gripper_state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "furniture_bench_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + 
"cmu_franka_exploration_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "highres_image", + "secondary": None, + "wrist": None, + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + }, + "ucsd_kitchen_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_state", None], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "ucsd_pick_and_place_dataset_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "austin_sailor_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "austin_sirius_dataset_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "bc_z": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [ + "present/xyz", + 
"present/axis_angle", + None, + "present/sensed_close", + ], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "utokyo_pr2_opening_fridge_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "utokyo_xarm_pick_and_place_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": "image2", + "wrist": "hand_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["end_effector_pose", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "utokyo_xarm_bimanual_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["pose_r", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "robo_net": { + "image_obs_keys": {"primary": "image", "secondary": "image1", "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_mvp_converted_externally_to_rlds": { + "image_obs_keys": 
{"primary": None, "secondary": None, "wrist": "hand_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["pose", "gripper"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.JOINT_POS, + }, + "berkeley_rpt_converted_externally_to_rlds": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "hand_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_pos", "gripper"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.JOINT_POS, + }, + "kaist_nonprehensile_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "stanford_mask_vit_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tokyo_u_lsmo_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "dlr_sara_pour_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + 
"dlr_sara_grid_clamp_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "dlr_edan_shared_control_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "asu_table_top_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "stanford_robocook_converted_externally_to_rlds": { + "image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None}, + "depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "imperialcollege_sawyer_wrist_cam": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, "state"], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + }, + "iamlab_cmu_pickup_insert_converted_externally_to_rlds": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + 
"state_obs_keys": ["joint_state", "gripper_state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "uiuc_d3field": { + "image_obs_keys": {"primary": "image_1", "secondary": "image_2", "wrist": None}, + "depth_obs_keys": {"primary": "depth_1", "secondary": "depth_2", "wrist": None}, + "state_obs_keys": [None, None, None, None, None, None, None, None], + "state_encoding": StateEncoding.NONE, + "action_encoding": ActionEncoding.EEF_POS, + }, + "utaustin_mutex": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_fanuc_manipulation": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "wrist_image", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_state", None, "gripper_state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "cmu_playing_with_food": { + "image_obs_keys": { + "primary": "image", + "secondary": None, + "wrist": "finger_vision_1", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "cmu_play_fusion": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.EEF_POS, + }, + "cmu_stretch": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", 
"gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_gnm_recon": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_gnm_cory_hall": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "berkeley_gnm_sac_son": { + "image_obs_keys": {"primary": None, "secondary": None, "wrist": "image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state", None, None], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "droid": { + "image_obs_keys": { + "primary": "exterior_image_1_left", + "secondary": "exterior_image_2_left", + "wrist": "wrist_image_left", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_QUAT, + "action_encoding": ActionEncoding.EEF_POS, + "aux_kwargs": { + "dataset_frame_transform_kwargs": { + "chunk_filter_fn": zero_action_filter, + }, + }, + }, + "fmb_dataset": { + "image_obs_keys": { + "primary": "image_side_1", + "secondary": "image_side_2", + "wrist": "image_wrist_1", + }, + "depth_obs_keys": { + "primary": "image_side_1_depth", + "secondary": "image_side_2_depth", + "wrist": "image_wrist_1_depth", + }, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "dobbe": { + "image_obs_keys": {"primary": "wrist_image", "secondary": None, "wrist": 
None}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "roboset": { + "image_obs_keys": { + "primary": "image_left", + "secondary": "image_right", + "wrist": "image_wrist", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.JOINT_POS, + }, + "rh20t": { + "image_obs_keys": { + "primary": "image_front", + "secondary": "image_side_right", + "wrist": "image_wrist", + }, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + ### T-DROID datasets + "tdroid_carrot_in_bowl": { # "put carrot in bowl" task, 50 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tdroid_pour_corn_in_pot": { # "pour corn from red bowl into steel pot" task, 50 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tdroid_flip_pot_upright": { # "flip pot upright" task, 10 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", 
"gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tdroid_move_object_onto_plate": { # "move onto plate" task, 150 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tdroid_knock_object_over": { # "knock over" task, 70 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "tdroid_cover_object_with_towel": { # "cover with towel" task, 45 demos @ 5 Hz control + "image_obs_keys": {"primary": "static_image", "secondary": None, "wrist": None}, + "depth_obs_keys": {"primary": "static_depth_image", "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + ### DROID Finetuning datasets + "droid_wipe": { + "image_obs_keys": {"primary": "exterior_image_2_left", "secondary": None, "wrist": "wrist_image_left"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["proprio"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + ### LIBERO datasets (modified versions) + "libero_spatial_no_noops": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": 
StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "libero_object_no_noops": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "libero_goal_no_noops": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "libero_10_no_noops": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + "libero_4_task_suites_no_noops": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["EEF_state", "gripper_state"], + "state_encoding": StateEncoding.POS_EULER, + "action_encoding": ActionEncoding.EEF_POS, + }, + ### ALOHA fine-tuning datasets + "aloha1_fold_shorts_20_demos": { + "image_obs_keys": {"primary": "image", "secondary": None, "left_wrist": "left_wrist_image", "right_wrist": "right_wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT_BIMANUAL, + "action_encoding": ActionEncoding.JOINT_POS_BIMANUAL, + }, + "aloha1_fold_shirt_30_demos": { + "image_obs_keys": {"primary": "image", "secondary": None, "left_wrist": "left_wrist_image", "right_wrist": "right_wrist_image"}, + 
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT_BIMANUAL, + "action_encoding": ActionEncoding.JOINT_POS_BIMANUAL, + }, + "aloha1_scoop_X_into_bowl_45_demos": { + "image_obs_keys": {"primary": "image", "secondary": None, "left_wrist": "left_wrist_image", "right_wrist": "right_wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT_BIMANUAL, + "action_encoding": ActionEncoding.JOINT_POS_BIMANUAL, + }, + "aloha1_put_X_into_pot_300_demos": { + "image_obs_keys": {"primary": "image", "secondary": None, "left_wrist": "left_wrist_image", "right_wrist": "right_wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["state"], + "state_encoding": StateEncoding.JOINT_BIMANUAL, + "action_encoding": ActionEncoding.JOINT_POS_BIMANUAL, + }, + ### ManiSkill conflict fine-tuning datasets (Panda, absolute joint + gripper) + "conflict_maniskill": { + "image_obs_keys": {"primary": "image", "secondary": None, "wrist": "wrist_image"}, + "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None}, + "state_obs_keys": ["joint_state", "gripper_state"], + "state_encoding": StateEncoding.JOINT, + "action_encoding": ActionEncoding.JOINT_POS_ABS, + }, +} diff --git a/prismatic/vla/datasets/rlds/oxe/materialize.py b/prismatic/vla/datasets/rlds/oxe/materialize.py new file mode 100644 index 0000000000000000000000000000000000000000..cf263dc21f0da701c9020ae3df91648048190854 --- /dev/null +++ b/prismatic/vla/datasets/rlds/oxe/materialize.py @@ -0,0 +1,139 @@ +""" +materialize.py + +Factory class for initializing Open-X Embodiment dataset kwargs and other parameters; provides and exports functions for +clear control flow. 
+""" + +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, List, Tuple + +from prismatic.overwatch import initialize_overwatch +from prismatic.vla.constants import ACTION_DIM, ACTION_PROPRIO_NORMALIZATION_TYPE, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX +from prismatic.vla.datasets.rlds.oxe.configs import OXE_DATASET_CONFIGS, ActionEncoding +from prismatic.vla.datasets.rlds.oxe.transforms import OXE_STANDARDIZATION_TRANSFORMS + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +def make_oxe_dataset_kwargs( + dataset_name: str, + data_root_dir: Path, + load_camera_views: Tuple[str] = ("primary",), + load_depth: bool = False, + load_proprio: bool = True, + load_language: bool = True, + action_proprio_normalization_type = ACTION_PROPRIO_NORMALIZATION_TYPE, +) -> Dict[str, Any]: + """Generates config (kwargs) for given dataset from Open-X Embodiment.""" + dataset_kwargs = deepcopy(OXE_DATASET_CONFIGS[dataset_name]) + if dataset_kwargs["action_encoding"] not in [ActionEncoding.EEF_POS, ActionEncoding.EEF_R6, ActionEncoding.JOINT_POS_BIMANUAL, ActionEncoding.JOINT_POS_ABS]: + raise ValueError(f"Cannot load `{dataset_name}`; only EEF_POS, EEF_R6, JOINT_POS_BIMANUAL, JOINT_POS_ABS actions supported!") + + # [Contract] For EEF_POS & EEF_R6 actions, only the last action dimension (gripper) is absolute! 
+ # Normalize all action dimensions *except* the gripper + if dataset_kwargs["action_encoding"] is ActionEncoding.EEF_POS: + dataset_kwargs["absolute_action_mask"] = [False] * 6 + [True] + dataset_kwargs["action_normalization_mask"] = [True] * 6 + [False] + elif dataset_kwargs["action_encoding"] is ActionEncoding.EEF_R6: + dataset_kwargs["absolute_action_mask"] = [False] * 9 + [True] + dataset_kwargs["action_normalization_mask"] = [True] * 9 + [False] + elif dataset_kwargs["action_encoding"] is ActionEncoding.JOINT_POS_BIMANUAL: + dataset_kwargs["absolute_action_mask"] = [True] * 14 + dataset_kwargs["action_normalization_mask"] = [True] * 14 + elif dataset_kwargs["action_encoding"] is ActionEncoding.JOINT_POS_ABS: + # All 8 dims are absolute (joint0..6 absolute positions + gripper open 0/1). + # Normalize joints (0..6); skip gripper (dim 7) since it is already in {0, 1}. + dataset_kwargs["absolute_action_mask"] = [True] * 8 + dataset_kwargs["action_normalization_mask"] = [True] * 7 + [False] + dataset_kwargs["action_proprio_normalization_type"] = action_proprio_normalization_type + + # Adjust Loaded Camera Views + if len(missing_keys := (set(load_camera_views) - set(dataset_kwargs["image_obs_keys"]))) > 0: + raise ValueError(f"Cannot load `{dataset_name}`; missing camera views `{missing_keys}`") + + # Filter + dataset_kwargs["image_obs_keys"] = { + k: v for k, v in dataset_kwargs["image_obs_keys"].items() if k in load_camera_views + } + dataset_kwargs["depth_obs_keys"] = { + k: v for k, v in dataset_kwargs["depth_obs_keys"].items() if k in load_camera_views + } + + # Eliminate Unnecessary Keys + dataset_kwargs.pop("state_encoding") + dataset_kwargs.pop("action_encoding") + if not load_depth: + dataset_kwargs.pop("depth_obs_keys") + if not load_proprio: + dataset_kwargs.pop("state_obs_keys") + + # Load Language + if load_language: + dataset_kwargs["language_key"] = "language_instruction" + + # Specify Standardization Transform + dataset_kwargs["standardize_fn"] = 
OXE_STANDARDIZATION_TRANSFORMS[dataset_name] + + # Add any aux arguments + if "aux_kwargs" in dataset_kwargs: + dataset_kwargs.update(dataset_kwargs.pop("aux_kwargs")) + + return {"name": dataset_name, "data_dir": str(data_root_dir), **dataset_kwargs} + + +def get_oxe_dataset_kwargs_and_weights( + data_root_dir: Path, + mixture_spec: List[Tuple[str, float]], + load_camera_views: Tuple[str] = ("primary",), + load_depth: bool = False, + load_proprio: bool = True, + load_language: bool = True, + action_proprio_normalization_type = ACTION_PROPRIO_NORMALIZATION_TYPE, +) -> Tuple[List[Dict[str, Any]], List[float]]: + """ + Generates dataset kwargs for a given dataset mix from the Open X-Embodiment dataset. The returned kwargs + (per-dataset configs) and weights can be passed directly to `make_interleaved_dataset`. + + :param data_root_dir: Base directory containing RLDS/TFDS-formatted datasets (from Open-X) + :param mixture_spec: List of (dataset_name, sampling_weight) from `oxe.mixtures.OXE_NAMED_MIXTURES` + :param load_camera_views: Camera views to load; see `oxe.configs.OXE_DATASET_CONFIGS` for available views. + :param load_depth: Load depth information in addition to camera RGB. + :param load_proprio: Load proprioceptive state. + :param load_language: Load language instructions. + :param action_proprio_normalization_type: Normalization scheme to use for actions and proprioceptive state.
+ + :return: Tuple of (per_dataset_kwargs, sampling_weights) + """ + included_datasets, filtered_mixture_spec = set(), [] + for d_name, d_weight in mixture_spec: + if d_name in included_datasets: + overwatch.warning(f"Skipping Duplicate Dataset: `{(d_name, d_weight)}`") + continue + + included_datasets.add(d_name) + filtered_mixture_spec.append((d_name, d_weight)) + + # Assemble Dataset Config (kwargs) and Weights + per_dataset_kwargs, sampling_weights = [], [] + for d_name, d_weight in filtered_mixture_spec: + try: + per_dataset_kwargs.append( + make_oxe_dataset_kwargs( + d_name, + data_root_dir, + load_camera_views, + load_depth, + load_proprio, + load_language, + action_proprio_normalization_type, + ) + ) + sampling_weights.append(d_weight) + + except ValueError as e: + overwatch.warning(f"Skipping `{d_name}` due to Error: {e}") + + return per_dataset_kwargs, sampling_weights diff --git a/prismatic/vla/datasets/rlds/oxe/mixtures.py b/prismatic/vla/datasets/rlds/oxe/mixtures.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a2862fdf6bd2cd69232239f06b9f97caf0cad1 --- /dev/null +++ b/prismatic/vla/datasets/rlds/oxe/mixtures.py @@ -0,0 +1,230 @@ +""" +mixtures.py + +Defines a registry of dataset mixtures and weights for the Open-X Embodiment Datasets.
Each dataset is associated with +a float "sampling weight" +""" + +from typing import Dict, List, Tuple + +# fmt: off +OXE_NAMED_MIXTURES: Dict[str, List[Tuple[str, float]]] = { + # === Bridge V2 Dataset === + "bridge": [ + # ("bridge_oxe", 1.0), # Version of Bridge V2 in Open-X GCP Bucket + ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website + ], + + + # === [Moderate-Scale] Bridge++ Mixtures === + "bridge_rt_1": [ + # ("bridge_oxe", 1.0) # Version of Bridge V2 in Open-X GCP Bucket + ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website + + ("fractal20220817_data", 1.0), # Google RT-1 Robot Data (Large-Scale) + ], + + # === RT-X Mixtures === + "rtx": [ + ("fractal20220817_data", 0.54087122203), # Google RT-1 Robot Data (Large-Scale) + ("kuka", 0.8341046294), + # ("bridge_oxe", 1.0) # Version of Bridge V2 in Open-X GCP Bucket + ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website + ("taco_play", 2.0), + ("jaco_play", 2.0), + ("berkeley_cable_routing", 3.0), + ("roboturk", 1.0), + # ("nyu_door_opening_surprising_effectiveness", 5.0), # Note --> only contains wrist camera images (skip?) + ("viola", 2.0), + ("berkeley_autolab_ur5", 1.0), + ("toto", 1.0), + ], + + "rtx_franka": [ + ("fractal20220817_data", 0.54087122203), # Google RT-1 Robot Data (Large-Scale) + ("kuka", 0.8341046294), + # ("bridge_oxe", 1.0) # Version of Bridge V2 in Open-X GCP Bucket + ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website + ("taco_play", 2.0), + ("jaco_play", 2.0), + ("berkeley_cable_routing", 3.0), + ("roboturk", 1.0), + # ("nyu_door_opening_surprising_effectiveness", 5.0), # Note --> only contains wrist camera images (skip?) 
+ ("viola", 2.0), + ("berkeley_autolab_ur5", 1.0), + ("toto", 1.0), + + ("taco_play", 1.0), + ("berkeley_cable_routing", 1.0), + ("viola", 1.0), + ("toto", 1.0), + ("stanford_hydra_dataset_converted_externally_to_rlds", 1.0), + ("austin_buds_dataset_converted_externally_to_rlds", 3.0), + ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), + ("maniskill_dataset_converted_externally_to_rlds", 0.1), + ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), + ("cmu_franka_exploration_dataset_converted_externally_to_rlds", 5.0), + ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), + ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), + ("berkeley_rpt_converted_externally_to_rlds", 1.0), + ("kaist_nonprehensile_converted_externally_to_rlds", 3.0), + ("stanford_robocook_converted_externally_to_rlds", 1.0), + ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), + ("utaustin_mutex", 1.0), + ("cmu_play_fusion", 1.0), + ], + + # === Open-X Magic Soup === + "oxe_magic_soup": [ + ("fractal20220817_data", 0.54087122203), # Google RT-1 Robot Data (Large-Scale) + ("kuka", 0.8341046294), + # ("bridge_oxe", 1.0) # Version of Bridge V2 in Open-X GCP Bucket + ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website + ("taco_play", 2.0), + ("jaco_play", 1.0), + ("berkeley_cable_routing", 1.0), + ("roboturk", 2.0), + # ("nyu_door_opening_surprising_effectiveness", 1.0), # Note --> only contains wrist camera images (skip?) 
+ ("viola", 2.0), + ("berkeley_autolab_ur5", 2.0), + ("toto", 1.0), + ("language_table", 0.1), + ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), + ("austin_buds_dataset_converted_externally_to_rlds", 1.0), + ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), + ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), + ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), + ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), + ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), + # ("bc_z", 0.2), # Note --> raw data is broken! + ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), + ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), + # ("uiuc_d3field", 1.0), # Note --> raw data is broken! + ("utaustin_mutex", 1.0), + ("berkeley_fanuc_manipulation", 2.0), + ("cmu_stretch", 1.0), + ], + + # === Open-X Magic Soup++ === + "oxe_magic_soup_plus": [ + ("fractal20220817_data", 0.54087122203), # Google RT-1 Robot Data (Large-Scale) + ("kuka", 0.8341046294), + ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website + ("taco_play", 2.0), + ("jaco_play", 1.0), + ("berkeley_cable_routing", 1.0), + ("roboturk", 2.0), + ("viola", 2.0), + ("berkeley_autolab_ur5", 2.0), + ("toto", 1.0), + ("language_table", 0.1), + ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), + ("austin_buds_dataset_converted_externally_to_rlds", 1.0), + ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), + ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), + ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), + ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), + ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), + ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), + ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), + ("utaustin_mutex", 1.0), + ("berkeley_fanuc_manipulation", 2.0), + ("cmu_stretch", 1.0), + ## New Datasets in 
MagicSoup++ + ("bc_z", 0.2), # Note: use v0.1.0 --> later versions broken + ("fmb_dataset", 1.0), + ("dobbe", 0.2), + ("droid", 0.06), + ], + + "oxe_magic_soup_plus_minus": [ + ("fractal20220817_data", 1.0), # Google RT-1 Robot Data (Large-Scale) + ("kuka", 0.8341046294), + ("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website + ("taco_play", 2.0), + ("jaco_play", 1.0), + ("berkeley_cable_routing", 1.0), + ("roboturk", 2.0), + ("viola", 2.0), + ("berkeley_autolab_ur5", 2.0), + ("toto", 1.0), + # ("language_table", 0.1), + ("stanford_hydra_dataset_converted_externally_to_rlds", 2.0), + ("austin_buds_dataset_converted_externally_to_rlds", 1.0), + ("nyu_franka_play_dataset_converted_externally_to_rlds", 3.0), + ("furniture_bench_dataset_converted_externally_to_rlds", 0.1), + ("ucsd_kitchen_dataset_converted_externally_to_rlds", 2.0), + ("austin_sailor_dataset_converted_externally_to_rlds", 1.0), + ("austin_sirius_dataset_converted_externally_to_rlds", 1.0), + ("dlr_edan_shared_control_converted_externally_to_rlds", 1.0), + ("iamlab_cmu_pickup_insert_converted_externally_to_rlds", 1.0), + ("utaustin_mutex", 1.0), + ("berkeley_fanuc_manipulation", 2.0), + ("cmu_stretch", 1.0), + ## New Datasets in MagicSoup++ + ("bc_z", 0.2), # Note: use v0.1.0 --> later versions broken + ("fmb_dataset", 1.0), + ("dobbe", 0.2), + # ("droid", 0.06), + ], + + # === T-DROID Dataset === + "tdroid_carrot_in_bowl": [ + ("tdroid_carrot_in_bowl", 1.0), + ], + "tdroid_pour_corn_in_pot": [ + ("tdroid_pour_corn_in_pot", 1.0), + ], + "tdroid_flip_pot_upright": [ + ("tdroid_flip_pot_upright", 1.0), + ], + "tdroid_move_object_onto_plate": [ + ("tdroid_move_object_onto_plate", 1.0), + ], + "tdroid_knock_object_over": [ + ("tdroid_knock_object_over", 1.0), + ], + "tdroid_cover_object_with_towel": [ + ("tdroid_cover_object_with_towel", 1.0), + ], + + # === DROID Finetuning Datasets === + "droid_wipe": [ + ("droid_wipe", 1.0), + ], + + # === LIBERO Datasets (Modified Versions) === + 
"libero_spatial_no_noops": [ + ("libero_spatial_no_noops", 1.0), + ], + "libero_object_no_noops": [ + ("libero_object_no_noops", 1.0), + ], + "libero_goal_no_noops": [ + ("libero_goal_no_noops", 1.0), + ], + "libero_10_no_noops": [ + ("libero_10_no_noops", 1.0), + ], + "libero_4_task_suites_no_noops": [ + ("libero_spatial_no_noops", 1.0), + ("libero_object_no_noops", 1.0), + ("libero_goal_no_noops", 1.0), + ("libero_10_no_noops", 1.0), + ], + + # === ALOHA Fine-Tuning Datasets === + "aloha1_fold_shorts_20_demos": [ + ("aloha1_fold_shorts_20_demos", 1.0), + ], + "aloha1_fold_shirt_30_demos": [ + ("aloha1_fold_shirt_30_demos", 1.0), + ], + "aloha1_scoop_X_into_bowl_45_demos": [ + ("aloha1_scoop_X_into_bowl_45_demos", 1.0), + ], + "aloha1_put_X_into_pot_300_demos": [ + ("aloha1_put_X_into_pot_300_demos", 1.0), + ], +# fmt: on +} diff --git a/prismatic/vla/datasets/rlds/oxe/transforms.py b/prismatic/vla/datasets/rlds/oxe/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..35efc664c0ef3dca90d28ab72975df17854c2cd0 --- /dev/null +++ b/prismatic/vla/datasets/rlds/oxe/transforms.py @@ -0,0 +1,949 @@ +""" +transforms.py + +Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment. 
+ +Transforms adopt the following structure: + Input: Dictionary of *batched* features (i.e., has leading time dimension) + Output: Dictionary `step` =>> { + "observation": { + + State (in chosen state representation) + }, + "action": Action (in chosen action representation), + "language_instruction": str + } +""" + +from typing import Any, Dict + +import tensorflow as tf + +from prismatic.vla.datasets.rlds.oxe.utils.droid_utils import droid_baseact_transform, droid_finetuning_transform +from prismatic.vla.datasets.rlds.utils.data_utils import ( + binarize_gripper_actions, + invert_gripper_actions, + rel2abs_gripper_actions, + relabel_bridge_actions, +) + + +def bridge_oxe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + Applies to version of Bridge V2 in Open X-Embodiment mixture. + + Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it! + """ + for key in trajectory.keys(): + if key == "traj_metadata": + continue + elif key in ["observation", "action"]: + for key2 in trajectory[key]: + trajectory[key][key2] = trajectory[key][key2][1:] + else: + trajectory[key] = trajectory[key][1:] + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32), + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + trajectory = relabel_bridge_actions(trajectory) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + Applies to original version of Bridge V2 from the official project website. 
+ + Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it! + """ + for key in trajectory.keys(): + if key == "traj_metadata": + continue + elif key == "observation": + for key2 in trajectory[key]: + trajectory[key][key2] = trajectory[key][key2][1:] + else: + trajectory[key] = trajectory[key][1:] + + trajectory["action"] = tf.concat( + [ + trajectory["action"][:, :6], + binarize_gripper_actions(trajectory["action"][:, -1])[:, None], + ], + axis=1, + ) + trajectory = relabel_bridge_actions(trajectory) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + [ + trajectory["action"][:, :6], + binarize_gripper_actions(trajectory["action"][:, -1])[:, None], + ], + axis=1, + ) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:] + return trajectory + + +def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action[:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = 
rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action[:, None], + ), + axis=-1, + ) + # decode compressed state + eef_value = tf.io.decode_compressed( + trajectory["observation"]["clip_function_input/base_pose_tool_reached"], + compression_type="ZLIB", + ) + eef_value = tf.io.decode_raw(eef_value, tf.float32) + trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7)) + gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB") + gripper_value = tf.io.decode_raw(gripper_value, tf.float32) + trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1)) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6] + trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8] + trajectory["action"] = trajectory["action"]["rel_actions_world"] + + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + tf.clip_by_value(trajectory["action"][:, -1:], 0, 1), + ), + axis=-1, + ) + + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6] + trajectory["observation"]["state_gripper"] = 
trajectory["observation"]["end_effector_cartesian_pos"][:, -1:] + + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + tf.zeros_like(trajectory["action"]["world_vector"]), + gripper_action[:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + tf.zeros_like(trajectory["action"]["world_vector"][:, :1]), + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert absolute gripper action, +1 = open, 0 = close + gripper_action = invert_gripper_actions(tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action, + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # make gripper action absolute 
action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action[:, None], + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # make gripper action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"][:, None] + gripper_action = tf.clip_by_value(gripper_action, 0, 1) + gripper_action = invert_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + gripper_action, + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14] + trajectory["observation"]["depth"] = trajectory["observation"].pop("image_with_depth") + + # make gripper action absolute action, +1 = open, 0 = close + gripper_action = trajectory["action"]["gripper_closedness_action"] + gripper_action = rel2abs_gripper_actions(gripper_action) + + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + 
gripper_action[:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32), + ), + axis=-1, + ) + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["observation"]["natural_language_instruction"]), "" + # ) # delete uninformative language instruction + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # default to "open" gripper + trajectory["action"] = tf.concat( + ( + trajectory["action"], + tf.zeros_like(trajectory["action"]), + tf.zeros_like(trajectory["action"]), + tf.ones_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + + # decode language instruction + instruction_bytes = trajectory["observation"]["instruction"] + instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8") + # Remove trailing padding --> convert RaggedTensor to regular Tensor. 
+ trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[:, 0] + return trajectory + + +def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["world_vector"], + trajectory["action"]["rotation_delta"], + trajectory["action"]["gripper_closedness_action"][:, None], + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tf.zeros_like(trajectory["action"][:, :3]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:] + trajectory["action"] = trajectory["action"][..., :7] + return trajectory + + +def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert gripper action, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(trajectory["action"][:, -1:]), + ), + axis=-1, + ) + + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :3], + trajectory["observation"]["state"][:, 7:10], + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2] + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def 
austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8] + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32) + trajectory["observation"]["depth_additional_view"] = tf.cast( + trajectory["observation"]["depth_additional_view"][..., 0], tf.float32 + ) + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:] + + # clip gripper action, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, -8:-2], + tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8] + return trajectory + + +def furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + import tensorflow_graphics.geometry.transformation as tft + + trajectory["observation"]["state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :7], + trajectory["observation"]["state"][:, -1:], + ), + axis=-1, + ) + + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + 
tft.euler.from_quaternion(trajectory["action"][:, 3:7]), + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + return trajectory + + +def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tf.zeros_like(trajectory["action"][:, :3]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language 
instruction + return trajectory + + +def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["future/xyz_residual"][:, :3], + trajectory["action"]["future/axis_angle_residual"][:, :3], + invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)), + ), + axis=-1, + ) + trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"] + return trajectory + + +def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def utokyo_xarm_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + return trajectory + + +def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = trajectory["action"][..., -7:] + return trajectory + + +def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :4], + tf.zeros_like(trajectory["observation"]["state"][:, :2]), + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :4], + tf.zeros_like(trajectory["action"][:, :2]), + trajectory["action"][:, 
-1:], + ), + axis=-1, + ) + return trajectory + + +def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + return trajectory + + +def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + return trajectory + + +def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + tf.zeros_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + return trajectory + + +def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["end_effector_pose"][:, :4], + tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]), + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :4], + tf.zeros_like(trajectory["action"][:, :2]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def dlr_sara_pour_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + return trajectory + + +def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6] + return trajectory + + +def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # invert gripper action, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + 
invert_gripper_actions(trajectory["action"][:, -1:]), + ), + axis=-1, + ) + return trajectory + + +def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + return trajectory + + +def imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + import tensorflow_graphics.geometry.transformation as tft + + trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8] + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tft.euler.from_quaternion(trajectory["action"][:, 3:7]), + trajectory["action"][:, 7:8], + ), + axis=-1, + ) + return trajectory + + +def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"], + tf.zeros_like(trajectory["action"]), + tf.zeros_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + return trajectory + + +def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8] + + # invert gripper action + clip, +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :6], + invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, 
-1:], 0, 1)), + ), + axis=-1, + ) + + # trajectory["language_instruction"] = tf.fill( + # tf.shape(trajectory["language_instruction"]), "" + # ) # delete uninformative language instruction + return trajectory + + +def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7] + + # dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close + trajectory["action"] = tf.concat( + ( + trajectory["action"], + invert_gripper_actions(trajectory["observation"]["gripper_state"]), + ), + axis=-1, + ) + return trajectory + + +def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + import tensorflow_graphics.geometry.transformation as tft + + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + tft.euler.from_quaternion(trajectory["action"][:, 3:7]), + trajectory["action"][:, -1:], + ), + axis=-1, + ) + return trajectory + + +def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :3], + trajectory["action"][:, -4:], + ), + axis=-1, + ) + return trajectory + + +def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["eef_state"] = tf.concat( + ( + trajectory["observation"]["state"][:, :3], + tf.zeros_like(trajectory["observation"]["state"][:, :3]), + ), + axis=-1, + ) + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:] + trajectory["action"] = trajectory["action"][..., :-1] + return trajectory + + +def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["state"] = tf.concat( + ( + trajectory["observation"]["position"], + 
tf.zeros_like(trajectory["observation"]["state"][:, :3]), + trajectory["observation"]["yaw"], + ), + axis=-1, + ) + trajectory["action"] = tf.concat( + ( + trajectory["action"], + tf.zeros_like(trajectory["action"]), + tf.zeros_like(trajectory["action"]), + tf.zeros_like(trajectory["action"][:, :1]), + ), + axis=-1, + ) + return trajectory + + +def fmb_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # every input feature is batched, i.e., has a leading batch dimension + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["eef_pose"], + trajectory["observation"]["state_gripper_pose"][..., None], + ), + axis=-1, + ) + return trajectory + + +def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # every input feature is batched, i.e., has a leading batch dimension + trajectory["observation"]["proprio"] = trajectory["observation"]["state"] + return trajectory + + +def roboset_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # every input feature is batched, i.e., has a leading batch dimension + trajectory["observation"]["proprio"] = trajectory["observation"]["state"] + + # gripper action is in -1...1 --> clip to 0...1, flip + gripper_action = trajectory["action"][:, -1:] + gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1)) + + trajectory["action"] = tf.concat( + ( + trajectory["action"][:, :7], + gripper_action, + ), + axis=-1, + ) + return trajectory + + +def rh20t_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["action"] = tf.concat( + ( + trajectory["action"]["tcp_base"], + tf.cast(trajectory["action"]["gripper"][:, None], tf.float32), + ), + axis=-1, + ) + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["tcp_base"], + trajectory["observation"]["gripper_width"][..., None], + ), + axis=-1, + ) + return trajectory + + +def tdroid_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: +
trajectory["action"] = tf.concat( + [ + trajectory["action"][:, :6], + binarize_gripper_actions(trajectory["action"][:, -1])[:, None], + ], + axis=1, + ) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:] + return trajectory + + +def libero_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # gripper action is in -1 (open)...1 (close) --> clip to 0...1, flip --> +1 = open, 0 = close + gripper_action = trajectory["action"][:, -1:] + gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1)) + + trajectory["action"] = tf.concat( + [ + trajectory["action"][:, :6], + gripper_action, + ], + axis=1, + ) + trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -2:] # 2D gripper state + return trajectory + + +def aloha_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + # Don't need to do anything because dataset is already in the correct format + return trajectory + + +def conflict_maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + ManiSkill conflict dataset (Panda, LeRobot v2.1 parquet → RLDS). + - action: 8-dim absolute [joint0..6 (rad), gripper_open (0=closed, 1=open)] + - state: 8-dim [joint0..6 (rad), gripper_width (m)] + Splits state into joint_state (7-dim) and gripper_state (1-dim) for the JOINT StateEncoding path. + No gripper binarization needed -- already 0/1. 
+ """ + trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7] + trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:] + # actions are already [joint0..6, gripper_open]; pass through as-is + return trajectory + + +# === Registry === +OXE_STANDARDIZATION_TRANSFORMS = { + "bridge_oxe": bridge_oxe_dataset_transform, + "bridge_orig": bridge_orig_dataset_transform, + "bridge_dataset": bridge_orig_dataset_transform, + "ppgm": ppgm_dataset_transform, + "ppgm_static": ppgm_dataset_transform, + "ppgm_wrist": ppgm_dataset_transform, + "fractal20220817_data": rt1_dataset_transform, + "kuka": kuka_dataset_transform, + "taco_play": taco_play_dataset_transform, + "jaco_play": jaco_play_dataset_transform, + "berkeley_cable_routing": berkeley_cable_routing_dataset_transform, + "roboturk": roboturk_dataset_transform, + "nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform, + "viola": viola_dataset_transform, + "berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform, + "toto": toto_dataset_transform, + "language_table": language_table_dataset_transform, + "columbia_cairlab_pusht_real": pusht_dataset_transform, + "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform, + "nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform, + "stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform, + "austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform, + "nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform, + "maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform, + "furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform, + "cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform, + "ucsd_kitchen_dataset_converted_externally_to_rlds": 
ucsd_kitchen_dataset_transform, + "ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform, + "austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform, + "austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform, + "bc_z": bc_z_dataset_transform, + "utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform, + "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform, + "utokyo_xarm_pick_and_place_converted_externally_to_rlds": utokyo_xarm_pick_place_dataset_transform, + "utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform, + "robo_net": robo_net_dataset_transform, + "berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform, + "berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform, + "kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform, + "stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform, + "tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform, + "dlr_sara_pour_converted_externally_to_rlds": dlr_sara_pour_dataset_transform, + "dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform, + "dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform, + "asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform, + "stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform, + "imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform, + "iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform, + "uiuc_d3field": uiuc_d3field_dataset_transform, + "utaustin_mutex": utaustin_mutex_dataset_transform, + "berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform, 
+ "cmu_playing_with_food": cmu_playing_with_food_dataset_transform, + "cmu_play_fusion": playfusion_dataset_transform, + "cmu_stretch": cmu_stretch_dataset_transform, + "berkeley_gnm_recon": gnm_dataset_transform, + "berkeley_gnm_cory_hall": gnm_dataset_transform, + "berkeley_gnm_sac_son": gnm_dataset_transform, + "droid": droid_baseact_transform, + "fmb_dataset": fmb_dataset_transform, + "dobbe": dobbe_dataset_transform, + "roboset": roboset_dataset_transform, + "rh20t": rh20t_dataset_transform, + ### T-DROID datasets + "tdroid_carrot_in_bowl": tdroid_dataset_transform, + "tdroid_pour_corn_in_pot": tdroid_dataset_transform, + "tdroid_flip_pot_upright": tdroid_dataset_transform, + "tdroid_move_object_onto_plate": tdroid_dataset_transform, + "tdroid_knock_object_over": tdroid_dataset_transform, + "tdroid_cover_object_with_towel": tdroid_dataset_transform, + ### DROID Finetuning datasets + "droid_wipe": droid_finetuning_transform, + ### LIBERO datasets (modified versions) + "libero_spatial_no_noops": libero_dataset_transform, + "libero_object_no_noops": libero_dataset_transform, + "libero_goal_no_noops": libero_dataset_transform, + "libero_10_no_noops": libero_dataset_transform, + "libero_4_task_suites_no_noops": libero_dataset_transform, + ### ALOHA fine-tuning datasets + "aloha1_fold_shorts_20_demos": aloha_dataset_transform, + "aloha1_fold_shirt_30_demos": aloha_dataset_transform, + "aloha1_scoop_X_into_bowl_45_demos": aloha_dataset_transform, + "aloha1_put_X_into_pot_300_demos": aloha_dataset_transform, + ### ManiSkill conflict fine-tuning datasets + "conflict_maniskill": conflict_maniskill_dataset_transform, +} diff --git a/prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py b/prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..44175a21cff6c3ae45f7596024852462ea40c68e --- /dev/null +++ b/prismatic/vla/datasets/rlds/oxe/utils/droid_utils.py @@ -0,0 +1,178 @@ +"""Episode transforms for DROID 
dataset.""" + +from typing import Any, Dict + +import tensorflow as tf +import tensorflow_graphics.geometry.transformation as tfg + + +def rmat_to_euler(rot_mat): + return tfg.euler.from_rotation_matrix(rot_mat) + + +def euler_to_rmat(euler): + return tfg.rotation_matrix_3d.from_euler(euler) + + +def invert_rmat(rot_mat): + return tfg.rotation_matrix_3d.inverse(rot_mat) + + +def rotmat_to_rot6d(mat): + """ + Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix). + Args: + mat: rotation matrix + + Returns: 6d vector (first two rows of rotation matrix) + + """ + r6 = mat[..., :2, :] + r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :] + r6_flat = tf.concat([r6_0, r6_1], axis=-1) + return r6_flat + + +def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame): + """ + Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame. + Args: + velocity: 6d velocity action (3 x translation, 3 x rotation) + wrist_in_robot_frame: 6d pose of the end-effector in robot base frame + + Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6) + + """ + R_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6]) + R_frame_inv = invert_rmat(R_frame) + + # world to wrist: dT_pi = R^-1 dT_rbt + vel_t = (R_frame_inv @ velocity[:, :3][..., None])[..., 0] + + # world to wrist: dR_pi = R^-1 dR_rbt R + dR = euler_to_rmat(velocity[:, 3:6]) + dR = R_frame_inv @ (dR @ R_frame) + dR_r6 = rotmat_to_rot6d(dR) + return tf.concat([vel_t, dR_r6], axis=-1) + + +def rand_swap_exterior_images(img1, img2): + """ + Randomly swaps the two exterior images (for training with single exterior input). + """ + return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1)) + + +def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + DROID dataset transformation for actions expressed in *base* frame of the robot. 
+ """ + dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] + dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] + + trajectory["action"] = tf.concat( + ( + dt, + dR, + 1 - trajectory["action_dict"]["gripper_position"], + ), + axis=-1, + ) + trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( + rand_swap_exterior_images( + trajectory["observation"]["exterior_image_1_left"], + trajectory["observation"]["exterior_image_2_left"], + ) + ) + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["cartesian_position"], + trajectory["observation"]["gripper_position"], + ), + axis=-1, + ) + return trajectory + + +def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + DROID dataset transformation for actions expressed in *wrist* frame of the robot. + """ + wrist_act = velocity_act_to_wrist_frame( + trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"] + ) + trajectory["action"] = tf.concat( + ( + wrist_act, + trajectory["action_dict"]["gripper_position"], + ), + axis=-1, + ) + trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = ( + rand_swap_exterior_images( + trajectory["observation"]["exterior_image_1_left"], + trajectory["observation"]["exterior_image_2_left"], + ) + ) + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["cartesian_position"], + trajectory["observation"]["gripper_position"], + ), + axis=-1, + ) + return trajectory + + +def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + """ + DROID dataset transformation for actions expressed in *base* frame of the robot. 
+ """ + dt = trajectory["action_dict"]["cartesian_velocity"][:, :3] + dR = trajectory["action_dict"]["cartesian_velocity"][:, 3:6] + trajectory["action"] = tf.concat( + ( + dt, + dR, + 1 - trajectory["action_dict"]["gripper_position"], + ), + axis=-1, + ) + trajectory["observation"]["proprio"] = tf.concat( + ( + trajectory["observation"]["cartesian_position"], + trajectory["observation"]["gripper_position"], + ), + axis=-1, + ) + return trajectory + + +def zero_action_filter(traj: Dict) -> bool: + """ + Filters transitions whose actions are all-0 (only relative actions, no gripper action). + Note: this filter is applied *after* action normalization, so need to compare to "normalized 0". + """ + DROID_Q01 = tf.convert_to_tensor( + [ + -0.7776297926902771, + -0.5803514122962952, + -0.5795090794563293, + -0.6464047729969025, + -0.7041108310222626, + -0.8895104378461838, + ] + ) + DROID_Q99 = tf.convert_to_tensor( + [ + 0.7597932070493698, + 0.5726242214441299, + 0.7351000607013702, + 0.6705610305070877, + 0.6464948207139969, + 0.8897542208433151, + ] + ) + DROID_NORM_0_ACT = 2 * (tf.zeros_like(traj["action"][:, :6]) - DROID_Q01) / (DROID_Q99 - DROID_Q01 + 1e-8) - 1 + + return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - DROID_NORM_0_ACT) > 1e-5) diff --git a/prismatic/vla/datasets/rlds/traj_transforms.py b/prismatic/vla/datasets/rlds/traj_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d2ae695abaa813e1687b1080f4b3d9b4b828fa60 --- /dev/null +++ b/prismatic/vla/datasets/rlds/traj_transforms.py @@ -0,0 +1,90 @@ +""" +traj_transforms.py + +Contains trajectory transforms used in the orca data pipeline. Trajectory transforms operate on a dictionary +that represents a single trajectory, meaning each tensor has the same leading dimension (the trajectory length). 
+""" + +import logging +from typing import Dict + +import tensorflow as tf + + +def chunk_act_obs(traj: Dict, window_size: int, future_action_window_size: int = 0) -> Dict: + """ + Chunks actions and observations into the given window_size. + + "observation" keys are given a new axis (at index 1) of size `window_size` containing `window_size - 1` + observations from the past and the current observation. "action" is given a new axis (at index 1) of size + `window_size + future_action_window_size` containing `window_size - 1` actions from the past, the current + action, and `future_action_window_size` actions from the future. "pad_mask" is added to "observation" and + indicates whether an observation should be considered padding (i.e. if it had come from a timestep + before the start of the trajectory). + """ + traj_len = tf.shape(traj["action"])[0] + action_dim = traj["action"].shape[-1] + effective_traj_len = traj_len - future_action_window_size + chunk_indices = tf.broadcast_to(tf.range(-window_size + 1, 1), [effective_traj_len, window_size]) + tf.broadcast_to( + tf.range(effective_traj_len)[:, None], [effective_traj_len, window_size] + ) + + action_chunk_indices = tf.broadcast_to( + tf.range(-window_size + 1, 1 + future_action_window_size), + [effective_traj_len, window_size + future_action_window_size], + ) + tf.broadcast_to( + tf.range(effective_traj_len)[:, None], + [effective_traj_len, window_size + future_action_window_size], + ) + + floored_chunk_indices = tf.maximum(chunk_indices, 0) + + goal_timestep = tf.fill([effective_traj_len], traj_len - 1) + + floored_action_chunk_indices = tf.minimum(tf.maximum(action_chunk_indices, 0), goal_timestep[:, None]) + + traj["observation"] = tf.nest.map_structure(lambda x: tf.gather(x, floored_chunk_indices), traj["observation"]) + traj["action"] = tf.gather(traj["action"], floored_action_chunk_indices) + + # indicates whether an entire observation is padding + traj["observation"]["pad_mask"] = chunk_indices >= 0 + + # 
Truncate other elements of the trajectory dict + traj["task"] = tf.nest.map_structure(lambda x: tf.gather(x, tf.range(effective_traj_len)), traj["task"]) + traj["dataset_name"] = tf.gather(traj["dataset_name"], tf.range(effective_traj_len)) + traj["absolute_action_mask"] = tf.gather(traj["absolute_action_mask"], tf.range(effective_traj_len)) + + return traj + + +def subsample(traj: Dict, subsample_length: int) -> Dict: + """Subsamples trajectories to the given length.""" + traj_len = tf.shape(traj["action"])[0] + if traj_len > subsample_length: + indices = tf.random.shuffle(tf.range(traj_len))[:subsample_length] + traj = tf.nest.map_structure(lambda x: tf.gather(x, indices), traj) + + return traj + + +def add_pad_mask_dict(traj: Dict) -> Dict: + """ + Adds a dictionary indicating which elements of the observation/task should be treated as padding. + =>> traj["observation"|"task"]["pad_mask_dict"] = {k: traj["observation"|"task"][k] is not padding} + """ + traj_len = tf.shape(traj["action"])[0] + + for key in ["observation", "task"]: + pad_mask_dict = {} + for subkey in traj[key]: + # Handles "language_instruction", "image_*", and "depth_*" + if traj[key][subkey].dtype == tf.string: + pad_mask_dict[subkey] = tf.strings.length(traj[key][subkey]) != 0 + + # All other keys should not be treated as padding + else: + pad_mask_dict[subkey] = tf.ones([traj_len], dtype=tf.bool) + + traj[key]["pad_mask_dict"] = pad_mask_dict + + return traj diff --git a/prismatic/vla/datasets/rlds/utils/__init__.py b/prismatic/vla/datasets/rlds/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/prismatic/vla/datasets/rlds/utils/data_utils.py b/prismatic/vla/datasets/rlds/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..41b61bd12c9a325e549d21d19f45ecd43ffb0570 --- /dev/null +++ b/prismatic/vla/datasets/rlds/utils/data_utils.py @@ -0,0 +1,321 @@ +""" +data_utils.py 
+ +Additional RLDS-specific data utilities. +""" + +import hashlib +import json +import os +from typing import Any, Callable, Dict, List, Optional, Tuple + +import dlimp as dl +import numpy as np +import tensorflow as tf +from tqdm import tqdm + +from prismatic.overwatch import initialize_overwatch +from prismatic.vla.constants import NormalizationType + +# Initialize Overwatch =>> Wraps `logging.Logger` +overwatch = initialize_overwatch(__name__) + + +def tree_map(fn: Callable, tree: Dict) -> Dict: + return {k: tree_map(fn, v) if isinstance(v, dict) else fn(v) for k, v in tree.items()} + + +def tree_merge(*trees: Dict) -> Dict: + merged = {} + for tree in trees: + for k, v in tree.items(): + if isinstance(v, dict): + merged[k] = tree_merge(merged.get(k, {}), v) + else: + merged[k] = v + return merged + + +def to_padding(tensor: tf.Tensor) -> tf.Tensor: + if tf.debugging.is_numeric_tensor(tensor): + return tf.zeros_like(tensor) + elif tensor.dtype == tf.string: + return tf.fill(tf.shape(tensor), "") + else: + raise ValueError(f"Cannot generate padding for tensor of type {tensor.dtype}.") + + +# === State / Action Processing Primitives === + + +# ruff: noqa: B023 +def normalize_action_and_proprio(traj: Dict, metadata: Dict, normalization_type: NormalizationType): + """Normalizes the action and proprio fields of a trajectory using the given metadata.""" + keys_to_normalize = {"action": "action", "proprio": "observation/proprio"} + + if normalization_type == NormalizationType.NORMAL: + for key, traj_key in keys_to_normalize.items(): + mask = metadata[key].get("mask", tf.ones_like(metadata[key]["mean"], dtype=tf.bool)) + traj = dl.transforms.selective_tree_map( + traj, + match=lambda k, _: k == traj_key, + map_fn=lambda x: tf.where(mask, (x - metadata[key]["mean"]) / (metadata[key]["std"] + 1e-8), x), + ) + + return traj + + elif normalization_type in [NormalizationType.BOUNDS, NormalizationType.BOUNDS_Q99]: + for key, traj_key in keys_to_normalize.items(): + if 
normalization_type == NormalizationType.BOUNDS: + low = metadata[key]["min"] + high = metadata[key]["max"] + elif normalization_type == NormalizationType.BOUNDS_Q99: + low = metadata[key]["q01"] + high = metadata[key]["q99"] + mask = metadata[key].get("mask", tf.ones_like(metadata[key]["min"], dtype=tf.bool)) + traj = dl.transforms.selective_tree_map( + traj, + match=lambda k, _: k == traj_key, + map_fn=lambda x: tf.where( + mask, + tf.clip_by_value(2 * (x - low) / (high - low + 1e-8) - 1, -1, 1), + x, + ), + ) + + # Note (Moo Jin): Map unused action dimensions (i.e., dimensions where min == max) to all 0s. + zeros_mask = metadata[key]["min"] == metadata[key]["max"] + traj = dl.transforms.selective_tree_map( + traj, match=lambda k, _: k == traj_key, map_fn=lambda x: tf.where(zeros_mask, 0.0, x) + ) + + return traj + + raise ValueError(f"Unknown Normalization Type {normalization_type}") + + +def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor: + """ + Converts gripper actions from continuous to binary values (0 and 1). + + We exploit the fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it + transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate + values based on the state that is reached _after_ those intermediate values. + + In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that + chunk of intermediate values as the last action in the trajectory. 
+ + The `scan_fn` implements the following logic: + new_actions = np.empty_like(actions) + carry = actions[-1] + for i in reversed(range(actions.shape[0])): + if in_between_mask[i]: + carry = carry + else: + carry = float(open_mask[i]) + new_actions[i] = carry + """ + open_mask, closed_mask = actions > 0.95, actions < 0.05 + in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask)) + is_open_float = tf.cast(open_mask, tf.float32) + + def scan_fn(carry, i): + return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i]) + + return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True) + + +def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor: + return 1 - actions + + +def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor: + """ + Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open). + + Assumes that the first relative gripper action is not redundant (e.g., closing when already closed)! 
+ """ + # Note =>> -1 for closing, 1 for opening, 0 for no change + opening_mask, closing_mask = actions < -0.1, actions > 0.1 + thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0)) + + def scan_fn(carry, i): + return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i]) + + # If no relative grasp, assumes open for whole trajectory + start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)] + start = tf.cond(start == 0, lambda: 1, lambda: start) + + # Note =>> -1 for closed, 1 for open + new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start) + new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5 + + return new_actions + + +# === Bridge-V2 =>> Dataset-Specific Transform === +def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]: + """Relabels actions to use reached proprioceptive state; discards last timestep (no-action).""" + movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6] + traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj) + traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1) + + return traj_truncated + + +# === RLDS Dataset Initialization Utilities === +def pprint_data_mixture(dataset_kwargs_list: List[Dict[str, Any]], dataset_weights: List[int]) -> None: + print("\n######################################################################################") + print(f"# Loading the following {len(dataset_kwargs_list)} datasets (incl. 
sampling weight):{'': >24} #") + for dataset_kwargs, weight in zip(dataset_kwargs_list, dataset_weights): + pad = 80 - len(dataset_kwargs["name"]) + print(f"# {dataset_kwargs['name']}: {weight:=>{pad}f} #") + print("######################################################################################\n") + + +def get_dataset_statistics( + dataset: dl.DLataset, + hash_dependencies: Tuple[str, ...], + save_dir: Optional[str] = None, +) -> Dict: + """ + Either computes the statistics of a dataset or loads them from a cache file if this function has been called before + with the same `hash_dependencies`. + + Currently, the statistics include the min/max/mean/std and q01/q99 quantiles of the actions and proprio, as well as + the number of transitions and trajectories in the dataset. + """ + unique_hash = hashlib.sha256("".join(hash_dependencies).encode("utf-8"), usedforsecurity=False).hexdigest() + + # Fallback local path for when save_dir is not writable or not provided + local_path = os.path.expanduser(os.path.join("~", ".cache", "orca", f"dataset_statistics_{unique_hash}.json")) + if save_dir is not None: + path = tf.io.gfile.join(save_dir, f"dataset_statistics_{unique_hash}.json") + else: + path = local_path + + # check if cache file exists and load + if tf.io.gfile.exists(path): + overwatch.info(f"Loading existing dataset statistics from {path}.") + with tf.io.gfile.GFile(path, "r") as f: + metadata = json.load(f) + return metadata + + if os.path.exists(local_path): + overwatch.info(f"Loading existing dataset statistics from {local_path}.") + with open(local_path, "r") as f: + metadata = json.load(f) + return metadata + + dataset = dataset.traj_map( + lambda traj: { + "action": traj["action"], + "proprio": ( + traj["observation"]["proprio"] if "proprio" in traj["observation"] else tf.zeros_like(traj["action"]) + ), + } + ) + + cardinality = dataset.cardinality().numpy() + if cardinality == tf.data.INFINITE_CARDINALITY: + raise ValueError("Cannot compute dataset statistics for infinite 
datasets.") + + overwatch.info("Computing dataset statistics. This may take a bit, but should only need to happen once.") + actions, proprios, num_transitions, num_trajectories = [], [], 0, 0 + for traj in tqdm(dataset.iterator(), total=cardinality if cardinality != tf.data.UNKNOWN_CARDINALITY else None): + actions.append(traj["action"]) + proprios.append(traj["proprio"]) + num_transitions += traj["action"].shape[0] + num_trajectories += 1 + + actions, proprios = np.concatenate(actions), np.concatenate(proprios) + metadata = { + "action": { + "mean": actions.mean(0).tolist(), + "std": actions.std(0).tolist(), + "max": actions.max(0).tolist(), + "min": actions.min(0).tolist(), + "q01": np.quantile(actions, 0.01, axis=0).tolist(), + "q99": np.quantile(actions, 0.99, axis=0).tolist(), + }, + "proprio": { + "mean": proprios.mean(0).tolist(), + "std": proprios.std(0).tolist(), + "max": proprios.max(0).tolist(), + "min": proprios.min(0).tolist(), + "q01": np.quantile(proprios, 0.01, axis=0).tolist(), + "q99": np.quantile(proprios, 0.99, axis=0).tolist(), + }, + "num_transitions": num_transitions, + "num_trajectories": num_trajectories, + } + + try: + with tf.io.gfile.GFile(path, "w") as f: + json.dump(metadata, f) + except tf.errors.PermissionDeniedError: + overwatch.warning(f"Could not write dataset statistics to {path}. 
Writing to {local_path} instead.") + os.makedirs(os.path.dirname(local_path), exist_ok=True) + with open(local_path, "w") as f: + json.dump(metadata, f) + + return metadata + + +def save_dataset_statistics(dataset_statistics, run_dir): + """Saves a `dataset_statistics.json` file.""" + out_path = run_dir / "dataset_statistics.json" + with open(out_path, "w") as f_json: + for _, stats in dataset_statistics.items(): + for k in stats["action"].keys(): + if isinstance(stats["action"][k], np.ndarray): + stats["action"][k] = stats["action"][k].tolist() + if "proprio" in stats: + for k in stats["proprio"].keys(): + if isinstance(stats["proprio"][k], np.ndarray): + stats["proprio"][k] = stats["proprio"][k].tolist() + if "num_trajectories" in stats: + if isinstance(stats["num_trajectories"], np.ndarray): + stats["num_trajectories"] = stats["num_trajectories"].item() + if "num_transitions" in stats: + if isinstance(stats["num_transitions"], np.ndarray): + stats["num_transitions"] = stats["num_transitions"].item() + json.dump(dataset_statistics, f_json, indent=2) + overwatch.info(f"Saved dataset statistics file at path {out_path}") + + +def allocate_threads(n: Optional[int], weights: np.ndarray): + """ + Allocates an integer number of threads across datasets based on weights. + + The final array sums to `n`, but each element is no less than 1. If `n` is None, then every dataset is assigned a + value of AUTOTUNE. 
+ """ + if n is None: + return np.array([tf.data.AUTOTUNE] * len(weights)) + + assert np.all(weights >= 0), "Weights must be non-negative" + assert len(weights) <= n, "Number of threads must be at least as large as length of weights" + weights = np.array(weights) / np.sum(weights) + + allocation = np.zeros_like(weights, dtype=int) + while True: + # Give the remaining elements that would get less than 1 a 1 + mask = (weights * n < 1) & (weights > 0) + if not mask.any(): + break + n -= mask.sum() + allocation += mask.astype(int) + + # Recompute the distribution over the remaining elements + weights[mask] = 0 + weights = weights / weights.sum() + + # Allocate the remaining elements + fractional, integral = np.modf(weights * n) + allocation += integral.astype(int) + n -= integral.sum() + for i in np.argsort(fractional)[::-1][: int(n)]: + allocation[i] += 1 + + return allocation diff --git a/prismatic/vla/datasets/rlds/utils/goal_relabeling.py b/prismatic/vla/datasets/rlds/utils/goal_relabeling.py new file mode 100644 index 0000000000000000000000000000000000000000..4864d2b772e53ca75cb03b50efb5921d2deae50c --- /dev/null +++ b/prismatic/vla/datasets/rlds/utils/goal_relabeling.py @@ -0,0 +1,32 @@ +""" +goal_relabeling.py + +Contains simple goal relabeling logic for BC use-cases where rewards and next_observations are not required. +Each function should add entries to the "task" dict. 
+""" + +from typing import Dict + +import tensorflow as tf + +from prismatic.vla.datasets.rlds.utils.data_utils import tree_merge + + +def uniform(traj: Dict) -> Dict: + """Relabels with a true uniform distribution over future states.""" + traj_len = tf.shape(tf.nest.flatten(traj["observation"])[0])[0] + + # Select a random future index for each transition i in the range [i + 1, traj_len) + rand = tf.random.uniform([traj_len]) + low = tf.cast(tf.range(traj_len) + 1, tf.float32) + high = tf.cast(traj_len, tf.float32) + goal_idxs = tf.cast(rand * (high - low) + low, tf.int32) + + # Sometimes there are floating-point errors that cause an out-of-bounds + goal_idxs = tf.minimum(goal_idxs, traj_len - 1) + + # Adds keys to "task" mirroring "observation" keys (`tree_merge` to combine "pad_mask_dict" properly) + goal = tf.nest.map_structure(lambda x: tf.gather(x, goal_idxs), traj["observation"]) + traj["task"] = tree_merge(traj["task"], goal) + + return traj diff --git a/prismatic/vla/datasets/rlds/utils/task_augmentation.py b/prismatic/vla/datasets/rlds/utils/task_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..425b57303a4d06dd60ccdc05b7ef51f328e68b18 --- /dev/null +++ b/prismatic/vla/datasets/rlds/utils/task_augmentation.py @@ -0,0 +1,57 @@ +""" +task_augmentation.py + +Contains basic logic for randomly zeroing out keys in the task specification. +""" + +from typing import Dict + +import tensorflow as tf + +from prismatic.vla.datasets.rlds.utils.data_utils import to_padding + + +def delete_task_conditioning(traj: Dict, keep_image_prob: float) -> Dict: + """ + Randomly drops out either the goal images or the language instruction. Only does something if both of + these are present. + + Args: + traj: A dictionary containing trajectory data. Should have a "task" key. + keep_image_prob: The probability of keeping the goal images. The probability of keeping the language + instruction is 1 - keep_image_prob. 
+ """ + if "language_instruction" not in traj["task"]: + return traj + + image_keys = {key for key in traj["task"].keys() if key.startswith("image_") or key.startswith("depth_")} + if not image_keys: + return traj + + traj_len = tf.shape(traj["action"])[0] + should_keep_images = tf.random.uniform([traj_len]) < keep_image_prob + should_keep_images |= ~traj["task"]["pad_mask_dict"]["language_instruction"] + + for key in image_keys | {"language_instruction"}: + should_keep = should_keep_images if key in image_keys else ~should_keep_images + # pad out the key + traj["task"][key] = tf.where( + should_keep, + traj["task"][key], + to_padding(traj["task"][key]), + ) + # zero out the pad mask dict for the key + traj["task"]["pad_mask_dict"][key] = tf.where( + should_keep, + traj["task"]["pad_mask_dict"][key], + tf.zeros_like(traj["task"]["pad_mask_dict"][key]), + ) + + # when no goal images are present, the goal timestep becomes the final timestep + traj["task"]["timestep"] = tf.where( + should_keep_images, + traj["task"]["timestep"], + traj_len - 1, + ) + + return traj diff --git a/prismatic/vla/materialize.py b/prismatic/vla/materialize.py new file mode 100644 index 0000000000000000000000000000000000000000..1685286da18f57329ba3a9ad052530df7f3b2238 --- /dev/null +++ b/prismatic/vla/materialize.py @@ -0,0 +1,56 @@ +""" +materialize.py + +Factory class for initializing Open-X RLDS-backed datasets, given specified data mixture parameters; provides and +exports individual functions for clear control flow. 
+""" + +from pathlib import Path +from typing import Tuple, Type + +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizerBase + +from prismatic.models.backbones.llm.prompting import PromptBuilder +from prismatic.models.backbones.vision import ImageTransform +from prismatic.util.data_utils import PaddedCollatorForActionPrediction +from prismatic.vla.action_tokenizer import ActionTokenizer +from prismatic.vla.datasets import EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset + + +def get_vla_dataset_and_collator( + data_root_dir: Path, + data_mix: str, + image_transform: ImageTransform, + tokenizer: PreTrainedTokenizerBase, + prompt_builder_fn: Type[PromptBuilder], + default_image_resolution: Tuple[int, int, int], + padding_side: str = "right", + predict_stop_token: bool = True, + shuffle_buffer_size: int = 100_000, + train: bool = True, + episodic: bool = False, + image_aug: bool = False, +) -> Tuple[Dataset, ActionTokenizer, PaddedCollatorForActionPrediction]: + """Initialize RLDS Dataset (wraps TFDS), ActionTokenizer, and initialize transform/collation functions.""" + action_tokenizer = ActionTokenizer(tokenizer) + batch_transform = RLDSBatchTransform( + action_tokenizer, tokenizer, image_transform, prompt_builder_fn, predict_stop_token=predict_stop_token + ) + collator = PaddedCollatorForActionPrediction( + tokenizer.model_max_length, tokenizer.pad_token_id, padding_side=padding_side + ) + + # Build RLDS Iterable Dataset + cls = RLDSDataset if not episodic else EpisodicRLDSDataset + dataset = cls( + data_root_dir, + data_mix, + batch_transform, + resize_resolution=default_image_resolution[1:], + shuffle_buffer_size=shuffle_buffer_size, + train=train, + image_aug=image_aug, + ) + + return dataset, action_tokenizer, collator diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..6f027fe3d685aff101f8cf345156bcc177bbc33d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,102 @@ 
+[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "openvla-oft" +authors = [ + {name = "Moo Jin Kim", email="moojink@stanford.edu"}, + {name = "Chelsea Finn", email="cbfinn@cs.stanford.edu"}, + {name = "Percy Liang", email="pliang@cs.stanford.edu"}, +] +description = "Fine-Tuning Vision-Language-Action Models: Optimizing Speed and Success" +version = "0.0.1" +readme = "README.md" +requires-python = ">=3.8" +keywords = ["vision-language-actions models", "fine-tuning", "robot learning"] +license = {file = "LICENSE"} +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3 :: Only", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +dependencies = [ + "accelerate>=0.25.0", + "draccus==0.8.0", + "einops", + # "flash_attn==2.5.5", # Here for documentation -- install *AFTER* editable install (follow README) + "huggingface_hub", + "json-numpy", + "jsonlines", + "matplotlib", + "peft==0.11.1", + "protobuf", + "rich", + "sentencepiece==0.1.99", + "timm==0.9.10", + "tokenizers==0.19.1", + "torch==2.2.0", + "torchvision==0.17.0", + "torchaudio==2.2.0", + "transformers @ git+https://github.com/moojink/transformers-openvla-oft.git", # IMPORTANT: Use this fork for bidirectional attn (for parallel decoding) + "wandb", + "tensorflow==2.15.0", + "tensorflow_datasets==4.9.3", + "tensorflow_graphics==2021.12.3", + "dlimp @ git+https://github.com/moojink/dlimp_openvla", + "diffusers==0.30.3", + "imageio", + "uvicorn", + "fastapi", + "json-numpy", +] + +[project.optional-dependencies] +dev = [ + "black>=24.2.0", + 
"gpustat", + "ipython", + "pre-commit", + "ruff>=0.2.2", +] +sagemaker = [ + "boto3", + "sagemaker" +] + +[project.urls] +homepage = "https://github.com/moojink/openvla-oft" +repository = "https://github.com/moojink/openvla-oft" +documentation = "https://github.com/moojink/openvla-oft" + +[tool.setuptools.packages.find] +where = ["."] +exclude = ["cache"] + +[tool.setuptools.package-data] +"prismatic" = ["py.typed"] + +[tool.black] +line-length = 121 +target-version = ["py38", "py39", "py310"] +preview = true + +[tool.ruff] +line-length = 121 +target-version = "py38" + +[tool.ruff.lint] +select = ["A", "B", "E", "F", "I", "RUF", "W"] +ignore = ["F722"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["E402", "F401"] diff --git a/scripts/extern/convert_prismatic_weights_to_hf.py b/scripts/extern/convert_prismatic_weights_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..f7e5b2cf4aa0e13abcb6b2397a8f2f9eed8c183f --- /dev/null +++ b/scripts/extern/convert_prismatic_weights_to_hf.py @@ -0,0 +1,237 @@ +""" +convert_prismatic_weights_to_hf.py + +Utility script for converting full Prismatic VLM weights (from this repository, in the default "Prismatic" format) to +the HuggingFace "AutoClasses" (e.g., those defined in `prismatic.extern.hf_*`) for "native" use in `transformers` +via `trust_remote_code = True`. + +Theoretically, these changes should be fully compatible with directly merging the models into `transformers` down the +line, with first-class support. 
+""" + +import json +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Union + +import draccus +import timm +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from timm.models.vision_transformer import LayerScale +from transformers import AutoTokenizer + +from prismatic.extern.hf.configuration_prismatic import PrismaticConfig +from prismatic.extern.hf.modeling_prismatic import PrismaticForConditionalGeneration +from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor + + +@dataclass +class HFConvertConfig: + # fmt: off + prismatic_model_path_or_id: Union[str, Path] = ( # Path to Pretrained VLM (on disk or HF Hub) + "siglip-224px+7b" + # "prism-dinosiglip-224px+7b" + ) + output_hf_model_local_path: Path = Path( # Path to Local Path to save HF model + "hf-convert/prismatic-siglip-224px-7b" + ) + output_hf_model_hub_path: str = ( # Path to HF Hub Path for "final" HF model + "TRI-ML/prismatic-siglip-224px-7b" # => huggingface.co/TRI-ML/prismatic-{...} + ) + + # HF Hub Credentials (required for Gated Models like LLaMa-2) + hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token + + def __post_init__(self) -> None: + self.hf_token = self.hf_token.read_text().strip() if isinstance(self.hf_token, Path) else self.hf_token + + # fmt: on + + +# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale. 
+# =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109 +# =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960 +def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor + + +def ls_apply_patch(ls_module: LayerScale): + ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone()) + ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale) + del ls_module.gamma + + +# === Conversion Constants === +PROJECTOR_KEY_MAPPING = { + "projector.0.weight": "projector.fc1.weight", + "projector.0.bias": "projector.fc1.bias", + "projector.2.weight": "projector.fc2.weight", + "projector.2.bias": "projector.fc2.bias", + "projector.4.weight": "projector.fc3.weight", + "projector.4.bias": "projector.fc3.bias", +} + + +def remap_state_dicts_for_hf( + projector_state_dict: Dict[str, torch.Tensor], + llm_backbone_state_dict: Dict[str, torch.Tensor], + vision_backbone_state_dicts: List[Dict[str, torch.Tensor]], +) -> Dict[str, torch.Tensor]: + """Iterate through Prismatic component state dictionaries and unify / fix key mapping for HF conversion.""" + hf_state_dict = {} + + # Iterate through Projector =>> use `PROJECTOR_KEY_MAPPING` + for key, value in projector_state_dict.items(): + hf_state_dict[PROJECTOR_KEY_MAPPING[key]] = value + + # Iterate through LLM Backbone =>> replace `llm.` with `language_model.` + for key, value in llm_backbone_state_dict.items(): + hf_state_dict[key.replace("llm.", "language_model.")] = value + + # Iterate through Vision Backbone =>> add "vision_backbone." prefix + assert len(vision_backbone_state_dicts) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!" 
+ for idx, vision_backbone_state_dict in enumerate(vision_backbone_state_dicts): + prefix = "vision_backbone.featurizer" if idx == 0 else "vision_backbone.fused_featurizer" + for key, value in vision_backbone_state_dict.items(): + hf_state_dict[f"{prefix}.{key}"] = value + + return hf_state_dict + + +@draccus.wrap() +def convert_prismatic_weights_to_hf(cfg: HFConvertConfig) -> None: + print(f"[*] Converting Prismatic Model `{cfg.prismatic_model_path_or_id}` to HF Transformers Format") + torch.set_default_dtype(torch.bfloat16) + + # Get `config.json` and `checkpoint_pt` -- mirrors logic in `prismatic.models.load.py` + if os.path.isdir(cfg.prismatic_model_path_or_id): + print(f"[*] Loading from Local Path `{(run_dir := Path(cfg.prismatic_model_path_or_id))}`") + config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt" + + assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`" + assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`" + else: + print(f"[*] Downloading Prismatic Checkpoint from HF Hub :: `TRI-ML/{cfg.prismatic_model_path_or_id}`") + config_json = hf_hub_download("TRI-ML/prismatic-vlms", f"{cfg.prismatic_model_path_or_id}/config.json") + checkpoint_pt = hf_hub_download( + "TRI-ML/prismatic-vlms", f"{cfg.prismatic_model_path_or_id}/checkpoints/latest-checkpoint.pt" + ) + + # Load "Native" Config JSON =>> Create LLM Config & Instantiate Tokenizer + with open(config_json, "r") as f: + prismatic_config = json.load(f)["model"] + + # Create HF PrismaticConfig (`transformers.PretrainedConfig`) + hf_config = PrismaticConfig( + vision_backbone_id=prismatic_config["vision_backbone_id"], + llm_backbone_id=prismatic_config["llm_backbone_id"], + arch_specifier=prismatic_config["arch_specifier"], + image_resize_strategy=prismatic_config["image_resize_strategy"], + llm_max_length=prismatic_config["llm_max_length"], + torch_dtype=torch.bfloat16, + ) + + # Instantiate & Add Pad to Tokenizer 
=>> following `prismatic.models.materialize.get_llm_backbone_and_tokenizer` + # TODO (siddk) :: Implement batched generation -- in which case this should set `padding_side = "left"`! + print("[*] Instantiating and Patching Tokenizer, LLM Config") + tokenizer = AutoTokenizer.from_pretrained( + hf_config.hf_llm_id, model_max_length=hf_config.llm_max_length, token=cfg.hf_token, padding_side="right" + ) + tokenizer.add_special_tokens({"pad_token": "<PAD>"}) + tokenizer.init_kwargs.pop("add_prefix_space", None) # Pop to prevent unnecessary warning on reload... + assert tokenizer.pad_token_id == hf_config.pad_token_id, "Incorrect Pad Token ID!" + assert len(tokenizer) > hf_config.text_config.vocab_size, "Tokenizer vocabulary must be larger than LLM vocabulary!" + + # Patch LLM Config in `hf_config` with vocab_size (+ `hf_config.pad_to_multiple_of`), pad_token_id + validate + hf_config.text_config.vocab_size += hf_config.pad_to_multiple_of + hf_config.text_config.pad_token_id = hf_config.pad_token_id + hf_config.text_config.torch_dtype = torch.bfloat16 + assert hf_config.text_config.use_cache, "LLM config `use_cache` should be True for inference (set default)!" 
+ + # Create Vision Backbone & Transform =>> following `prismatic.models.materialize.get_vision_backbone_and_transform` + # =>> Deviates a bit from existing code; as such, explicitly tested in `tests/test_image_transforms.py` + print("[*] Loading TIMM Vision Backbone(s) and Image Transform(s) =>> Initializing PrismaticImageProcessor") + timm_vision_backbones, input_sizes, interpolations, means, stds = [], [], [], [], [] + for idx, timm_model_id in enumerate(hf_config.timm_model_ids): + timm_vision_backbone = timm.create_model( + timm_model_id, + pretrained=True, + num_classes=0, + img_size=hf_config.image_sizes[idx], + act_layer=hf_config.timm_override_act_layers[idx], + ) + timm_vision_backbones.append(timm_vision_backbone) + + # Get Per-Backbone Image Processing + data_cfg = timm.data.resolve_model_data_config(timm_vision_backbone) + input_sizes.append((3, hf_config.image_sizes[idx], hf_config.image_sizes[idx])) + interpolations.append(data_cfg["interpolation"]) + means.append(data_cfg["mean"]) + stds.append(data_cfg["std"]) + + # Patch `LayerScale` because of HF annoying `fix_key` overwrite... 
+ for module in timm_vision_backbone.modules(): + if isinstance(module, LayerScale): + ls_apply_patch(module) + + # Create PrismaticImageProcessor (`transformers.ImageProcessingMixin`) + hf_image_processor = PrismaticImageProcessor( + use_fused_vision_backbone=hf_config.use_fused_vision_backbone, + image_resize_strategy=hf_config.image_resize_strategy, + input_sizes=input_sizes, + interpolations=interpolations, + means=means, + stds=stds, + ) + + # Create top-level PrismaticProcessor (`transformers.ProcessorMixin` =>> enables registry w/ AutoProcessor) + print("[*] Creating PrismaticProcessor Instance from Tokenizer and PrismaticImageProcessor") + hf_processor = PrismaticProcessor(image_processor=hf_image_processor, tokenizer=tokenizer) + + # Load Prismatic Model State Dictionary (in preparation for conversion) + print("[*] Loading Prismatic VLM State Dictionary from Checkpoint") + model_state_dict = torch.load(checkpoint_pt, map_location="cpu")["model"] + assert ("downsampler" not in model_state_dict) or (len(model_state_dict["downsampler"]) == 0), "Downsampler?" + assert ("projector" in model_state_dict) and ("llm_backbone" in model_state_dict), "Missing keys!" 
+ + # Convert + print("[*] Running Conversion") + converted_state_dict = remap_state_dicts_for_hf( + model_state_dict["projector"], + model_state_dict["llm_backbone"], + vision_backbone_state_dicts=[vb.state_dict() for vb in timm_vision_backbones], + ) + + # Create PrismaticForConditionalGeneration =>> Note that we can't initialize on `meta` device because TIMM + print("[*] Building (Randomly Initialized) Model =>> PrismaticForConditionalGeneration") + hf_model = PrismaticForConditionalGeneration(hf_config) + hf_model.load_state_dict(converted_state_dict, strict=True, assign=True) + + # Cast Model to BF16 before Saving + hf_model.to(torch.bfloat16) + + # Save Pretrained Versions to Local Path + print("[*] Saving Model & Processor to Local Path") + hf_model.save_pretrained(cfg.output_hf_model_local_path, max_shard_size="7GB") + hf_image_processor.save_pretrained(cfg.output_hf_model_local_path) + hf_processor.save_pretrained(cfg.output_hf_model_local_path) + + # Register AutoClasses + PrismaticConfig.register_for_auto_class() + PrismaticImageProcessor.register_for_auto_class("AutoImageProcessor") + PrismaticProcessor.register_for_auto_class("AutoProcessor") + PrismaticForConditionalGeneration.register_for_auto_class("AutoModelForVision2Seq") + + # Push to Hub + print("[*] Pushing Model & Processor to HF Hub") + hf_config.push_to_hub(cfg.output_hf_model_hub_path) + hf_model.push_to_hub(cfg.output_hf_model_hub_path, max_shard_size="7GB") + hf_image_processor.push_to_hub(cfg.output_hf_model_hub_path) + hf_processor.push_to_hub(cfg.output_hf_model_hub_path) + + +if __name__ == "__main__": + convert_prismatic_weights_to_hf() diff --git a/scripts/extern/verify_prismatic.py b/scripts/extern/verify_prismatic.py new file mode 100644 index 0000000000000000000000000000000000000000..742b873d3ef7489d1acf6c11705fa1ef84a51c3f --- /dev/null +++ b/scripts/extern/verify_prismatic.py @@ -0,0 +1,134 @@ +""" +verify_prismatic.py + +Given an HF-exported Prismatic model, attempt to load 
via AutoClasses, and verify forward() and generate(). +""" + +import time + +import requests +import torch +from PIL import Image +from transformers import AutoModelForVision2Seq, AutoProcessor + +# === Verification Arguments === +MODEL_PATH = "TRI-ML/prismatic-siglip-224px-7b" +DEFAULT_IMAGE_URL = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png" +) + +if "-prism-" in MODEL_PATH: + SAMPLE_PROMPTS_FOR_GENERATION = [ + "In: What is sitting in the coffee?\nOut:", + "In: What's the name of the food on the plate?\nOut:", + "In: caption.\nOut:", + "In: how many beinets..?\nOut:", + "In: Can you give me a lyrical description of the scene\nOut:", + ] +else: + SYSTEM_PROMPT = ( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ) + SAMPLE_PROMPTS_FOR_GENERATION = [ + f"{SYSTEM_PROMPT} USER: What is sitting in the coffee? ASSISTANT:", + f"{SYSTEM_PROMPT} USER: What's the name of the food on the plate? ASSISTANT:", + f"{SYSTEM_PROMPT} USER: caption. ASSISTANT:", + f"{SYSTEM_PROMPT} USER: how many beinets..? 
ASSISTANT:", + f"{SYSTEM_PROMPT} USER: Can you give me a lyrical description of the scene ASSISTANT:", + ] + + +@torch.inference_mode() +def verify_prismatic() -> None: + print(f"[*] Verifying PrismaticForConditionalGeneration using Model `{MODEL_PATH}`") + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + # Load Processor & VLM + print("[*] Instantiating Processor and Pretrained VLM") + processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True) + + # === AUTOCAST MODE === + # print("[*] Loading in BF16 Autocast Mode") + # vlm = AutoModelForVision2Seq.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True, trust_remote_code=True).to( + # device, dtype=torch.bfloat16 + # ) + + # === NATIVE BFLOAT16 MODE === + # print("[*] Loading in BF16") + # vlm = AutoModelForVision2Seq.from_pretrained( + # MODEL_PATH, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True + # ).to(device) + + # === BFLOAT16 + FLASH-ATTN MODE :: [~14GB of VRAM Passive || 18GB of VRAM Active] === + print("[*] Loading in BF16 with Flash-Attention Enabled") + vlm = AutoModelForVision2Seq.from_pretrained( + MODEL_PATH, + attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + ).to(device) + + # === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] === + # print("[*] Loading in 8-Bit Quantization Mode") + # vlm = AutoModelForVision2Seq.from_pretrained( + # MODEL_PATH, + # attn_implementation="flash_attention_2", + # torch_dtype=torch.float16, + # quantization_config=BitsAndBytesConfig(load_in_8bit=True), + # low_cpu_mem_usage=True, + # trust_remote_code=True, + # ) + + # === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] === + # print("[*] Loading in 4-Bit Quantization Mode") + # vlm = AutoModelForVision2Seq.from_pretrained( + # MODEL_PATH, + # 
attn_implementation="flash_attention_2", + # torch_dtype=torch.float16, + # quantization_config=BitsAndBytesConfig(load_in_4bit=True), + # low_cpu_mem_usage=True, + # trust_remote_code=True, + # ) + + # Iterate over Sample Prompts =>> Generate + image = Image.open(requests.get(DEFAULT_IMAGE_URL, stream=True).raw).convert("RGB") + num_tokens, total_time = 0, 0.0 + + print("[*] Iterating over Sample Prompts\n===\n") + for idx, prompt in enumerate(SAMPLE_PROMPTS_FOR_GENERATION): + # === AUTOCAST MODE (Reproduces Prismatic `scripts/generate.py`) === + # inputs = processor(prompt, image).to(device) + # + # # Using "autocast" to evaluate bit-wise equivalence to `scripts/generate.py` + # # =>> Running in native BF16 is also fine (but leads to slightly different generations) + # with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + # gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512) + + # === BFLOAT16 MODE === + inputs = processor(prompt, image).to(device, dtype=torch.bfloat16) + + # === 8-BIT/4-BIT QUANTIZATION MODE === + # inputs = processor(prompt, image).to(device, dtype=torch.float16) + + # Run Inference + gen_ids = None + for _ in range(5): + start_time = time.time() + gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512) + total_time += time.time() - start_time + + gen_ids = gen_ids[0, inputs.input_ids.shape[1] :] + num_tokens += len(gen_ids) + + # === + gen_text = processor.decode(gen_ids, skip_special_tokens=True).strip() + print(f"[{idx + 1}] Input Prompt => {prompt}\n Generated => {gen_text}\n") + + # Compute Tokens / Second + print(f"[*] Generated Tokens per Second = {num_tokens / total_time} w/ {num_tokens = } and {total_time = }") + + +if __name__ == "__main__": + verify_prismatic() diff --git a/vla-scripts/deploy.py b/vla-scripts/deploy.py new file mode 100644 index 0000000000000000000000000000000000000000..3c15d9536e8d7001ff4e660ca90dae1360215e2e --- /dev/null +++ b/vla-scripts/deploy.py 
@@ -0,0 +1,156 @@ +""" +deploy.py + +Starts VLA server which the client can query to get robot actions. +""" + +import os.path + +# ruff: noqa: E402 +import json_numpy + +json_numpy.patch() +import json +import logging +import numpy as np +import traceback +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import draccus +import torch +import uvicorn +from fastapi import FastAPI +from fastapi.responses import JSONResponse +from PIL import Image +from transformers import AutoModelForVision2Seq, AutoProcessor + +from experiments.robot.openvla_utils import ( + get_vla, + get_vla_action, + get_action_head, + get_processor, + get_proprio_projector, +) +from experiments.robot.robot_utils import ( + get_image_resize_size, +) +from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX + + +def get_openvla_prompt(instruction: str, openvla_path: Union[str, Path]) -> str: + return f"In: What action should the robot take to {instruction.lower()}?\nOut:" + + +# === Server Interface === +class OpenVLAServer: + def __init__(self, cfg) -> None: + """ + A simple server for OpenVLA models; exposes `/act` to predict an action for a given observation + instruction. + """ + self.cfg = cfg + + # Load model + self.vla = get_vla(cfg) + + # Load proprio projector + self.proprio_projector = None + if cfg.use_proprio: + self.proprio_projector = get_proprio_projector(cfg, self.vla.llm_dim, PROPRIO_DIM) + + # Load continuous action head + self.action_head = None + if cfg.use_l1_regression or cfg.use_diffusion: + self.action_head = get_action_head(cfg, self.vla.llm_dim) + + # Check that the model contains the action un-normalization key + assert cfg.unnorm_key in self.vla.norm_stats, f"Action un-norm key {cfg.unnorm_key} not found in VLA `norm_stats`!" 
+ + # Get Hugging Face processor + self.processor = None + self.processor = get_processor(cfg) + + # Get expected image dimensions + self.resize_size = get_image_resize_size(cfg) + + + def get_server_action(self, payload: Dict[str, Any]) -> str: + try: + if double_encode := "encoded" in payload: + # Support cases where `json_numpy` is hard to install, and numpy arrays are "double-encoded" as strings + assert len(payload.keys()) == 1, "Only uses encoded payload!" + payload = json.loads(payload["encoded"]) + + observation = payload + instruction = observation["instruction"] + + action = get_vla_action( + self.cfg, self.vla, self.processor, observation, instruction, action_head=self.action_head, proprio_projector=self.proprio_projector, use_film=self.cfg.use_film, + ) + + if double_encode: + return JSONResponse(json_numpy.dumps(action)) + else: + return JSONResponse(action) + except: # noqa: E722 + logging.error(traceback.format_exc()) + logging.warning( + "Your request threw an error; make sure your request complies with the expected format:\n" + "{'observation': dict, 'instruction': str}\n" + ) + return "error" + + def run(self, host: str = "0.0.0.0", port: int = 8777) -> None: + self.app = FastAPI() + self.app.post("/act")(self.get_server_action) + uvicorn.run(self.app, host=host, port=port) + + +@dataclass +class DeployConfig: + # fmt: off + + # Server Configuration + host: str = "0.0.0.0" # Host IP Address + port: int = 8777 # Host Port + + ################################################################################################################# + # Model-specific parameters + ################################################################################################################# + model_family: str = "openvla" # Model family + pretrained_checkpoint: Union[str, Path] = "" # Pretrained checkpoint path + + use_l1_regression: bool = True # If True, uses continuous action head with L1 regression objective + use_diffusion: bool = False # If True, uses 
continuous action head with diffusion modeling objective (DDIM) + num_diffusion_steps_train: int = 50 # (When `diffusion==True`) Number of diffusion steps used for training + num_diffusion_steps_inference: int = 50 # (When `diffusion==True`) Number of diffusion steps used for inference + use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features + num_images_in_input: int = 3 # Number of images in the VLA input (default: 3) + use_proprio: bool = True # Whether to include proprio state in input + + center_crop: bool = True # Center crop? (if trained w/ random crop image aug) + + lora_rank: int = 32 # Rank of LoRA weight matrix (MAKE SURE THIS MATCHES TRAINING!) + + unnorm_key: Union[str, Path] = "" # Action un-normalization key + use_relative_actions: bool = False # Whether to use relative actions (delta joint angles) + + load_in_8bit: bool = False # (For OpenVLA only) Load with 8-bit quantization + load_in_4bit: bool = False # (For OpenVLA only) Load with 4-bit quantization + + ################################################################################################################# + # Utils + ################################################################################################################# + seed: int = 7 # Random Seed (for reproducibility) + # fmt: on + + +@draccus.wrap() +def deploy(cfg: DeployConfig) -> None: + server = OpenVLAServer(cfg) + server.run(cfg.host, port=cfg.port) + + +if __name__ == "__main__": + deploy() diff --git a/vla-scripts/extern/convert_openvla_weights_to_hf.py b/vla-scripts/extern/convert_openvla_weights_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..5774954898d2e3c40b0417e477aca582f6471c3b --- /dev/null +++ b/vla-scripts/extern/convert_openvla_weights_to_hf.py @@ -0,0 +1,272 @@ +""" +convert_openvla_weights_to_hf.py + +Utility script for converting full OpenVLA VLA weights (from this repository, in the default "Prismatic" format) to +the HuggingFace 
"AutoClasses" (e.g., those defined in `prismatic.extern.hf_*`) for "native" use in `transformers`` +via `trust_remote_code = True`. + +Theoretically, these changes should be fully compatible with directly merging the models into `transformers` down the +line, with first-class support. + +Usage: + python vla-scripts/extern/convert_openvla_weights_to_hf.py \ + --openvla_model_path_or_id \ + --output_hf_model_local_path +""" + +import json +import os +import shutil +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Union + +import draccus +import timm +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from timm.models.vision_transformer import LayerScale +from transformers import AutoTokenizer + +from prismatic.conf import ModelConfig +from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig +from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction +from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor + + +@dataclass +class HFConvertConfig: + # fmt: off + openvla_model_path_or_id: Union[str, Path] = ( # Path to Pretrained VLA (on disk or HF Hub) + "runs/prism-dinosiglip-224px+mx-oxe-magic-soup-plus+n8+b32+x7" + ) + output_hf_model_local_path: Path = Path( # Path to Local Path to save HF model + "hf-convert/openvla-7b" + ) + output_hf_model_hub_path: str = "openvla/openvla-7b" # (Optional) Path to HF Hub Path to push + # model to + + # HF Hub Credentials (required for Gated Models like LLaMa-2) + hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token + + def __post_init__(self) -> None: + self.hf_token = self.hf_token.read_text().strip() if isinstance(self.hf_token, Path) else self.hf_token + + # fmt: on + + +# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale. 
+# =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109 +# =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960 +def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor + + +def ls_apply_patch(ls_module: LayerScale): + ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone()) + ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale) + del ls_module.gamma + + +# === Conversion Constants === +PROJECTOR_KEY_MAPPING = { + "projector.0.weight": "projector.fc1.weight", + "projector.0.bias": "projector.fc1.bias", + "projector.2.weight": "projector.fc2.weight", + "projector.2.bias": "projector.fc2.bias", + "projector.4.weight": "projector.fc3.weight", + "projector.4.bias": "projector.fc3.bias", +} + + +def remap_state_dicts_for_hf( + prismatic_vision_backbone_state_dict: Dict[str, torch.Tensor], + projector_state_dict: Dict[str, torch.Tensor], + llm_backbone_state_dict: Dict[str, torch.Tensor], + use_fused_vision_backbone: bool = False, +) -> Dict[str, torch.Tensor]: + """Iterate through Prismatic component state dictionaries and unify / fix key mapping for HF conversion.""" + hf_state_dict = {} + + # Iterate through Projector =>> use `PROJECTOR_KEY_MAPPING` + for key, value in projector_state_dict.items(): + hf_state_dict[PROJECTOR_KEY_MAPPING[key]] = value + + # Iterate through LLM Backbone =>> replace `llm.` with `language_model.` + for key, value in llm_backbone_state_dict.items(): + hf_state_dict[key.replace("llm.", "language_model.")] = value + + # Iterate through Vision Backbone =>> add "vision_backbone." 
prefix + if not use_fused_vision_backbone: + for key, value in prismatic_vision_backbone_state_dict.items(): + hf_state_dict[key.replace("featurizer.", "vision_backbone.featurizer.")] = value + else: + # Note =>> Assumes that backbones are always DINO + SigLIP... + for key, value in prismatic_vision_backbone_state_dict.items(): + if key.startswith("dino_featurizer"): + if key.endswith(".gamma"): + # Handle `LayerScale gamma` =>> DINOv2 only! + key = key.replace(".gamma", ".scale_factor") + hf_state_dict[key.replace("dino_featurizer.", "vision_backbone.featurizer.")] = value + elif key.startswith("siglip_featurizer"): + hf_state_dict[key.replace("siglip_featurizer.", "vision_backbone.fused_featurizer.")] = value + + return hf_state_dict + + +@draccus.wrap() +def convert_openvla_weights_to_hf(cfg: HFConvertConfig) -> None: + print(f"[*] Converting OpenVLA Model `{cfg.openvla_model_path_or_id}` to HF Transformers Format") + torch.set_default_dtype(torch.bfloat16) + + # Get `config.json`, `dataset_statistics.json`, and `checkpoint_pt` -- mirrors logic in `prismatic.models.load.py` + if os.path.isdir(cfg.openvla_model_path_or_id): + print(f"[*] Loading from Local Path `{(run_dir := Path(cfg.openvla_model_path_or_id))}`") + config_json, checkpoint_pt = run_dir / "config.json", run_dir / "checkpoints" / "latest-checkpoint.pt" + dataset_statistics_json = run_dir / "dataset_statistics.json" + + assert config_json.exists(), f"Missing `config.json` for `{run_dir = }`" + assert checkpoint_pt.exists(), f"Missing checkpoint for `{run_dir = }`" + assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir = }`" + else: + print(f"[*] Downloading OpenVLA Checkpoint from HF Hub :: `openvla/openvla-dev/{cfg.openvla_model_path_or_id}`") + config_json = hf_hub_download("openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/config.json") + checkpoint_pt = hf_hub_download( + "openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/checkpoints/latest-checkpoint.pt" + ) 
+ dataset_statistics_json = hf_hub_download( + "openvla/openvla-dev", f"{cfg.openvla_model_path_or_id}/dataset_statistics.json" + ) + + # Load "Native" Config JSON =>> Create LLM Config & Instantiate Tokenizer + with open(config_json, "r") as f: + vla_cfg = json.load(f)["vla"] + prismatic_config = ModelConfig.get_choice_class(vla_cfg["base_vlm"])().__dict__ + + # Load Normalization Statistics + with open(dataset_statistics_json, "r") as f: + norm_stats = json.load(f) + + # Create HF OpenVLAConfig (`transformers.PretrainedConfig`) + hf_config = OpenVLAConfig( + vision_backbone_id=prismatic_config["vision_backbone_id"], + llm_backbone_id=prismatic_config["llm_backbone_id"], + arch_specifier=prismatic_config["arch_specifier"], + image_resize_strategy=prismatic_config["image_resize_strategy"], + llm_max_length=prismatic_config["llm_max_length"], + torch_dtype=torch.bfloat16, + norm_stats=norm_stats, + ) + + # Instantiate & Add Pad to Tokenizer =>> following `prismatic.models.materialize.get_llm_backbone_and_tokenizer` + # TODO (siddk) :: Implement batched generation -- in which case this should set `padding_side = "left"`! + print("[*] Instantiating and Patching Tokenizer, LLM Config") + tokenizer = AutoTokenizer.from_pretrained( + hf_config.hf_llm_id, model_max_length=hf_config.llm_max_length, token=cfg.hf_token, padding_side="right" + ) + tokenizer.add_special_tokens({"pad_token": "<PAD>"}) + tokenizer.init_kwargs.pop("add_prefix_space", None) # Pop to prevent unnecessary warning on reload... + assert tokenizer.pad_token_id == hf_config.pad_token_id, "Incorrect Pad Token ID!" + assert len(tokenizer) > hf_config.text_config.vocab_size, "Tokenizer vocabulary must be larger than LLM vocabulary!" 
+ + # Patch LLM Config in `hf_config` with vocab_size (+ `hf_config.pad_to_multiple_of`), pad_token_id + validate + hf_config.text_config.vocab_size += hf_config.pad_to_multiple_of + hf_config.text_config.pad_token_id = hf_config.pad_token_id + hf_config.text_config.torch_dtype = torch.bfloat16 + assert hf_config.text_config.use_cache, "LLM config `use_cache` should be True for inference (set default)!" + + # Create Vision Backbone & Transform =>> following `prismatic.models.materialize.get_vision_backbone_and_transform` + # =>> Deviates a bit from existing code; as such, explicitly tested in `tests/test_image_transforms.py` + print("[*] Loading TIMM Vision Backbone(s) and Image Transform(s) =>> Initializing PrismaticImageProcessor") + input_sizes, interpolations, means, stds = [], [], [], [] + for idx, timm_model_id in enumerate(hf_config.timm_model_ids): + timm_vision_backbone = timm.create_model( + timm_model_id, + pretrained=True, + num_classes=0, + img_size=hf_config.image_sizes[idx], + act_layer=hf_config.timm_override_act_layers[idx], + ) + + # Get Per-Backbone Image Processing + data_cfg = timm.data.resolve_model_data_config(timm_vision_backbone) + input_sizes.append((3, hf_config.image_sizes[idx], hf_config.image_sizes[idx])) + interpolations.append(data_cfg["interpolation"]) + means.append(data_cfg["mean"]) + stds.append(data_cfg["std"]) + + # Patch `LayerScale` because of HF annoying `fix_key` overwrite... 
+ for module in timm_vision_backbone.modules(): + if isinstance(module, LayerScale): + ls_apply_patch(module) + + # Create PrismaticImageProcessor (`transformers.ImageProcessingMixin`) + hf_image_processor = PrismaticImageProcessor( + use_fused_vision_backbone=hf_config.use_fused_vision_backbone, + image_resize_strategy=hf_config.image_resize_strategy, + input_sizes=input_sizes, + interpolations=interpolations, + means=means, + stds=stds, + ) + + # Create top-level PrismaticProcessor (`transformers.ProcessorMixin` =>> enables registry w/ AutoProcessor) + print("[*] Creating PrismaticProcessor Instance from Tokenizer and PrismaticImageProcessor") + hf_processor = PrismaticProcessor(image_processor=hf_image_processor, tokenizer=tokenizer) + + # Load Prismatic Model State Dictionary (in preparation for conversion) + print("[*] Loading Prismatic VLM State Dictionary from Checkpoint") + model_state_dict = torch.load(checkpoint_pt, map_location="cpu")["model"] + assert ("downsampler" not in model_state_dict) or (len(model_state_dict["downsampler"]) == 0), "Downsampler?" + assert all([k in model_state_dict for k in ["vision_backbone", "projector", "llm_backbone"]]), "Missing keys!" 
+ + # Convert + print("[*] Running Conversion") + converted_state_dict = remap_state_dicts_for_hf( + model_state_dict["vision_backbone"], + model_state_dict["projector"], + model_state_dict["llm_backbone"], + use_fused_vision_backbone=hf_config.use_fused_vision_backbone, + ) + + # Create OpenVLAForActionPrediction =>> Note that we can't initialize on `meta` device because TIMM + print("[*] Building (Randomly Initialized) Model =>> OpenVLAForActionPrediction") + hf_model = OpenVLAForActionPrediction(hf_config) + hf_model.load_state_dict(converted_state_dict, strict=True, assign=True) + + # Cast Model to BF16 before Saving + hf_model.to(torch.bfloat16) + + # Save Pretrained Versions to Local Path + print("[*] Saving Model & Processor to Local Path") + hf_model.save_pretrained(cfg.output_hf_model_local_path, max_shard_size="7GB") + hf_image_processor.save_pretrained(cfg.output_hf_model_local_path) + hf_processor.save_pretrained(cfg.output_hf_model_local_path) + + # Copy `dataset_statistics.json` File to Converted Checkpoint Directory + output_dataset_statistics_json = cfg.output_hf_model_local_path / "dataset_statistics.json" + shutil.copyfile(dataset_statistics_json, output_dataset_statistics_json) + + print(f"[*] Saving Complete! 
Saved converted checkpoint to: {cfg.output_hf_model_local_path}") + + ##################################################################################### + # Optional: Push Model to Hugging Face Hub + ##################################################################################### + + # # Register AutoClasses + # OpenVLAConfig.register_for_auto_class() + # PrismaticImageProcessor.register_for_auto_class("AutoImageProcessor") + # PrismaticProcessor.register_for_auto_class("AutoProcessor") + # OpenVLAForActionPrediction.register_for_auto_class("AutoModelForVision2Seq") + + # # Push to HF Hub + # print("[*] Pushing Model & Processor to HF Hub") + # hf_config.push_to_hub(cfg.output_hf_model_hub_path) + # hf_model.push_to_hub(cfg.output_hf_model_hub_path, max_shard_size="7GB") + # hf_image_processor.push_to_hub(cfg.output_hf_model_hub_path) + # hf_processor.push_to_hub(cfg.output_hf_model_hub_path) + + +if __name__ == "__main__": + convert_openvla_weights_to_hf() diff --git a/vla-scripts/extern/verify_openvla.py b/vla-scripts/extern/verify_openvla.py new file mode 100644 index 0000000000000000000000000000000000000000..d76fafb8592871e201d7c9c536b573561ee86919 --- /dev/null +++ b/vla-scripts/extern/verify_openvla.py @@ -0,0 +1,89 @@ +""" +verify_openvla.py + +Given an HF-exported OpenVLA model, attempt to load via AutoClasses, and verify forward() and predict_action(). +""" + +import time + +import numpy as np +import torch +from PIL import Image +from transformers import AutoModelForVision2Seq, AutoProcessor + +# === Verification Arguments +MODEL_PATH = "openvla/openvla-7b" +SYSTEM_PROMPT = ( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." +) +INSTRUCTION = "put spoon on towel" + + +def get_openvla_prompt(instruction: str) -> str: + if "v01" in MODEL_PATH: + return f"{SYSTEM_PROMPT} USER: What action should the robot take to {instruction.lower()}? 
ASSISTANT:" + else: + return f"In: What action should the robot take to {instruction.lower()}?\nOut:" + + +@torch.inference_mode() +def verify_openvla() -> None: + print(f"[*] Verifying OpenVLAForActionPrediction using Model `{MODEL_PATH}`") + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + # Load Processor & VLA + print("[*] Instantiating Processor and Pretrained OpenVLA") + processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True) + + # === BFLOAT16 + FLASH-ATTN MODE === + print("[*] Loading in BF16 with Flash-Attention Enabled") + vla = AutoModelForVision2Seq.from_pretrained( + MODEL_PATH, + attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + ).to(device) + + # === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] === + # print("[*] Loading in 8-Bit Quantization Mode") + # vla = AutoModelForVision2Seq.from_pretrained( + # MODEL_PATH, + # attn_implementation="flash_attention_2", + # torch_dtype=torch.float16, + # quantization_config=BitsAndBytesConfig(load_in_8bit=True), + # low_cpu_mem_usage=True, + # trust_remote_code=True, + # ) + + # === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] === + # print("[*] Loading in 4-Bit Quantization Mode") + # vla = AutoModelForVision2Seq.from_pretrained( + # MODEL_PATH, + # attn_implementation="flash_attention_2", + # torch_dtype=torch.float16, + # quantization_config=BitsAndBytesConfig(load_in_4bit=True), + # low_cpu_mem_usage=True, + # trust_remote_code=True, + # ) + + print("[*] Iterating with Randomly Generated Images") + for _ in range(100): + prompt = get_openvla_prompt(INSTRUCTION) + image = Image.fromarray(np.asarray(np.random.rand(256, 256, 3) * 255, dtype=np.uint8)) + + # === BFLOAT16 MODE === + inputs = processor(prompt, image).to(device, dtype=torch.bfloat16) + + # === 
8-BIT/4-BIT QUANTIZATION MODE === + # inputs = processor(prompt, image).to(device, dtype=torch.float16) + + # Run OpenVLA Inference + start_time = time.time() + action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False) + print(f"\t=>> Time: {time.time() - start_time:.4f} || Action: {action}") + + +if __name__ == "__main__": + verify_openvla() diff --git a/vla-scripts/finetune.py b/vla-scripts/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..940e4850648c3af04de9ae474722b97113acddc3 --- /dev/null +++ b/vla-scripts/finetune.py @@ -0,0 +1,1142 @@ +""" +finetune.py + +Fine-tunes OpenVLA via LoRA. +""" + +import os +import time +from collections import deque +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Optional, Tuple, Type + +import draccus +import torch +import torch.distributed as dist +import torch.nn as nn +import tqdm +from accelerate import PartialState +from huggingface_hub import HfApi, snapshot_download +from peft import LoraConfig, PeftModel, get_peft_model +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import AdamW +from torch.optim.lr_scheduler import MultiStepLR +from torch.utils.data import DataLoader +from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor +from transformers.modeling_outputs import CausalLMOutputWithPast + +import wandb + +from experiments.robot.openvla_utils import ( + check_model_logic_mismatch, + model_is_on_hf_hub, + update_auto_map, +) + +from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig +from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction +from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor +from prismatic.models.action_heads import DiffusionActionHead, L1RegressionActionHead +from prismatic.models.backbones.llm.prompting import PurePromptBuilder +from 
prismatic.models.film_vit_wrapper import FiLMedPrismaticVisionBackbone +from prismatic.models.projectors import ( + NoisyActionProjector, + ProprioProjector, +) +from prismatic.training.train_utils import ( + compute_actions_l1_loss, + compute_token_accuracy, + get_current_action_mask, + get_next_actions_mask, +) +from prismatic.util.data_utils import PaddedCollatorForActionPrediction +from prismatic.vla.action_tokenizer import ActionTokenizer +from prismatic.vla.constants import ( + ACTION_DIM, + ACTION_PROPRIO_NORMALIZATION_TYPE, + NUM_ACTIONS_CHUNK, + PROPRIO_DIM, +) +from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset +from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics + +# Sane Defaults +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@dataclass +class FinetuneConfig: + # fmt: off + vla_path: str = "openvla/openvla-7b" # Path to OpenVLA model (on HuggingFace Hub or stored locally) + + # Dataset + data_root_dir: Path = Path("datasets/rlds") # Directory containing RLDS datasets + dataset_name: str = "aloha_scoop_x_into_bowl" # Name of fine-tuning dataset (e.g., `aloha_scoop_x_into_bowl`) + run_root_dir: Path = Path("runs") # Path to directory to store logs & checkpoints + shuffle_buffer_size: int = 100_000 # Dataloader shuffle buffer size (can reduce if OOM errors occur) + + # Algorithm and architecture + use_l1_regression: bool = True # If True, trains continuous action head with L1 regression objective + use_diffusion: bool = False # If True, trains continuous action head with diffusion modeling objective (DDIM) + num_diffusion_steps_train: int = 50 # (When `use_diffusion==True`) Number of diffusion steps used for training + use_film: bool = False # If True, uses FiLM to infuse language inputs into visual features + num_images_in_input: int = 1 # Number of images in the VLA input (default: 1) + use_proprio: bool = False # If True, includes robot proprioceptive state in input + + # Training configuration + 
batch_size: int = 8 # Batch size per device (total batch size = batch_size * num GPUs) + learning_rate: float = 5e-4 # Learning rate + lr_warmup_steps: int = 0 # Number of steps to warm up learning rate (from 10% to 100%) + num_steps_before_decay: int = 100_000 # Number of steps before LR decays by 10x + grad_accumulation_steps: int = 1 # Number of gradient accumulation steps + max_steps: int = 200_000 # Max number of training steps + use_val_set: bool = False # If True, uses validation set and logs validation metrics + val_freq: int = 10_000 # (When `use_val_set==True`) Validation set logging frequency in steps + val_time_limit: int = 180 # (When `use_val_set==True`) Time limit in seconds for computing validation metrics + save_freq: int = 10_000 # Checkpoint saving frequency in steps + save_latest_checkpoint_only: bool = False # If True, saves only 1 checkpoint, overwriting latest checkpoint + # (If False, saves all checkpoints) + resume: bool = False # If True, resumes from checkpoint + resume_step: Optional[int] = None # (When `resume==True`) Step number that we are resuming from + image_aug: bool = True # If True, trains with image augmentations (HIGHLY RECOMMENDED) + diffusion_sample_freq: int = 50 # (When `use_diffusion==True`) Frequency for sampling in steps + + # LoRA + use_lora: bool = True # If True, uses LoRA fine-tuning + lora_rank: int = 32 # Rank of LoRA weight matrix + lora_dropout: float = 0.0 # Dropout applied to LoRA weights + merge_lora_during_training: bool = True # If True, merges LoRA weights and saves result during training + # Note: Merging can be very slow on some machines. If so, set to + # False and merge final checkpoint offline! 
+ + # Logging + wandb_entity: str = "your-wandb-entity" # Name of WandB entity + wandb_project: str = "your-wandb-project" # Name of WandB project + run_id_note: Optional[str] = None # Extra note to add to end of run ID for logging + run_id_override: Optional[str] = None # Optional string to override the run ID with + wandb_log_freq: int = 10 # WandB logging frequency in steps + + # fmt: on + + +def remove_ddp_in_checkpoint(state_dict) -> dict: + """ + Removes the 'module.' prefix from parameter names in a PyTorch model state dictionary that was saved using + DistributedDataParallel (DDP). + + When a model is trained using PyTorch's DistributedDataParallel, the saved state dictionary contains parameters + prefixed with 'module.'. This function removes these prefixes to make the state dictionary compatible when + loading into models that are not yet wrapped in DDP. + + Args: + state_dict (dict): PyTorch model state dictionary. + + Returns: + dict: A new state dictionary with the same contents but with 'module.' prefixes removed from parameter names. + Parameters without the 'module.' prefix remain unchanged. + """ + new_state_dict = {} + for k, v in state_dict.items(): + if k[:7] == "module.": + new_state_dict[k[7:]] = v + else: + new_state_dict[k] = v + return new_state_dict + + +def get_run_id(cfg) -> str: + """ + Generates or retrieves an identifier string for an experiment run. + + Args: + cfg (FinetuneConfig): Training configuration. + + Returns: + str: Experiment run ID. 
+ """ + if cfg.run_id_override is not None: + # Override the run ID with the user-provided ID + run_id = cfg.run_id_override + elif cfg.resume: + # Override run ID with the previous resumed run's ID + run_id = cfg.vla_path.split("/")[-1] + # Remove the "--XXX_chkpt" suffix from the run ID if it exists + if "chkpt" in run_id.split("--")[-1]: + run_id = "--".join(run_id.split("--")[:-1]) + else: + run_id = ( + f"{cfg.vla_path.split('/')[-1]}+{cfg.dataset_name}" + f"+b{cfg.batch_size * cfg.grad_accumulation_steps}" + f"+lr-{cfg.learning_rate}" + ) + if cfg.use_lora: + run_id += f"+lora-r{cfg.lora_rank}+dropout-{cfg.lora_dropout}" + if cfg.image_aug: + run_id += "--image_aug" + if cfg.run_id_note is not None: + run_id += f"--{cfg.run_id_note}" + return run_id + + +def load_checkpoint(module_name: str, path: str, step: int, device: str = "cpu") -> dict: + """ + Loads a checkpoint for a given module. + + Args: + module_name (str): Name of model component to load checkpoint for. + path (str): Path to checkpoint directory. + step (int): Gradient step number of saved checkpoint. + device (str): String specifying how to remap storage locations (default = "cpu"). + + Returns: + dict: PyTorch model state dictionary. + """ + checkpoint_path = os.path.join(path, f"{module_name}--{step}_checkpoint.pt") + print(f"Loading checkpoint: {checkpoint_path}") + state_dict = torch.load(checkpoint_path, weights_only=True, map_location=device) + return remove_ddp_in_checkpoint(state_dict) + + +def wrap_ddp(module: nn.Module, device_id: int, find_unused: bool = False) -> DDP: + """ + Wrap a module with DistributedDataParallel. + + Args: + module (nn.Module): PyTorch module. + device_id (int): Device ID. + find_unused (bool): Whether to detect parameters without gradients in distributed training. + + Returns: + DistributedDataParallel: PyTorch module wrapped with DDP. 
+ """ + return DDP(module, device_ids=[device_id], find_unused_parameters=find_unused, gradient_as_bucket_view=True) + + +def count_parameters(module: nn.Module, name: str) -> None: + """ + Counts and prints the number of trainable parameters in a module. + + Args: + module (nn.Module): PyTorch module. + name (str): Name of model component. + + Returns: + None. + """ + num_params = sum(p.numel() for p in module.parameters() if p.requires_grad) + print(f"# trainable params in {name}: {num_params}") + + +def init_module( + module_class: Type[nn.Module], + module_name: str, + cfg: FinetuneConfig, + device_id: int, + module_args: dict, + to_bf16: bool = False, + find_unused_params: bool = False, +) -> DDP: + """ + Initializes a module, optionally loads checkpoint, moves to device, and wraps with DDP. + + Args: + module_class (Type[nn.Module]): Class of PyTorch module to initialize. + module_name (str): Name of model component to load checkpoint for. + cfg (FinetuneConfig): Training configuration. + device_id (int): Device ID. + module_args (dict): Args for initializing the module. + to_bf16 (bool): Whether to convert to torch.bfloat16 data type. + find_unused_params (bool): Whether to detect parameters without gradients in distributed training. + + Returns: + DistributedDataParallel: PyTorch module wrapped with DDP. 
+ """ + module = module_class(**module_args) + count_parameters(module, module_name) + + if cfg.resume: + state_dict = load_checkpoint(module_name, cfg.vla_path, cfg.resume_step) + module.load_state_dict(state_dict) + + if to_bf16: + module = module.to(torch.bfloat16) + module = module.to(device_id) + + return wrap_ddp(module, device_id, find_unused_params) + + +def run_forward_pass( + vla, + action_head, + noisy_action_projector, + proprio_projector, + batch, + action_tokenizer, + device_id, + use_l1_regression, + use_diffusion, + use_proprio, + use_film, + num_patches, + compute_diffusion_l1=False, + num_diffusion_steps_train=None, +) -> Tuple[torch.Tensor, Dict[str, float]]: + """ + Compute model forward pass and metrics for both training and validation. + + Args: + vla (OpenVLAForActionPrediction): Vision-language-action policy. + action_head (nn.Module): Action head module. + noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion). + proprio_projector (nn.Module): Proprioceptive state projector module. + batch (dict): Input batch. + action_tokenizer (ActionTokenizer): Action tokenizer. + device_id (str): Device ID. + use_l1_regression (bool): Whether to use L1 regression. + use_diffusion (bool): Whether to use diffusion. + use_proprio (bool): Whether to use proprioceptive state as input. + use_film (bool): Whether to use FiLM for better language following. + num_patches (int): Number of vision patches. + compute_diffusion_l1 (bool): Whether to sample actions and compute L1 loss for diffusion (do this once every + diffusion_sample_freq steps during training; do it every batch for validation) + num_diffusion_steps_train (int): Number of diffusion steps for training (only used for diffusion). + + Returns: + tuple: (loss, metrics_dict) + loss: The loss tensor with gradient for backpropagation. + metrics_dict: Dictionary of computed metrics (detached values for logging). 
+ """ + metrics = {} + + # Get ground-truth action labels + ground_truth_actions = batch["actions"].to(device_id).to(torch.bfloat16) + + # [Only for diffusion] Sample noisy actions used as input for noise predictor network + if use_diffusion: + noisy_dict = action_head.module.sample_noisy_actions(ground_truth_actions) + noise, noisy_actions, diffusion_timestep_embeddings = ( + noisy_dict["noise"], + noisy_dict["noisy_actions"], + noisy_dict["diffusion_timestep_embeddings"], + ) + else: + noise, noisy_actions, diffusion_timestep_embeddings = None, None, None + + # VLA forward pass + with torch.autocast("cuda", dtype=torch.bfloat16): + output: CausalLMOutputWithPast = vla( + input_ids=batch["input_ids"].to(device_id), + attention_mask=batch["attention_mask"].to(device_id), + pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id), + labels=batch["labels"], + output_hidden_states=True, + proprio=batch["proprio"] if use_proprio else None, + proprio_projector=proprio_projector if use_proprio else None, + noisy_actions=noisy_actions if use_diffusion else None, + noisy_action_projector=noisy_action_projector if use_diffusion else None, + diffusion_timestep_embeddings=diffusion_timestep_embeddings if use_diffusion else None, + use_film=use_film, + ) + + # Get action masks needed for logging + ground_truth_token_ids = batch["labels"][:, 1:].to(device_id) + current_action_mask = get_current_action_mask(ground_truth_token_ids) + next_actions_mask = get_next_actions_mask(ground_truth_token_ids) + + # Compute metrics for discrete action representation (next-token prediction) + if not (use_l1_regression or use_diffusion): + loss = output.loss + predicted_token_ids = output.logits[:, num_patches:-1].argmax(dim=2) + curr_action_accuracy = compute_token_accuracy( + predicted_token_ids, ground_truth_token_ids, mask=current_action_mask + ) + curr_action_l1_loss = compute_actions_l1_loss( + action_tokenizer, predicted_token_ids, ground_truth_token_ids, 
mask=current_action_mask + ) + next_actions_accuracy = compute_token_accuracy( + predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask + ) + next_actions_l1_loss = compute_actions_l1_loss( + action_tokenizer, predicted_token_ids, ground_truth_token_ids, mask=next_actions_mask + ) + metrics.update( + { + "loss_value": loss.item(), # Detached value for logging + "curr_action_accuracy": curr_action_accuracy.item(), + "curr_action_l1_loss": curr_action_l1_loss.item(), + "next_actions_accuracy": next_actions_accuracy.item(), + "next_actions_l1_loss": next_actions_l1_loss.item(), + } + ) + # Compute metrics for continuous action representations (L1 regression | diffusion) + else: + # Get last layer hidden states + last_hidden_states = output.hidden_states[-1] # (B, seq_len, D) + # Get hidden states for text portion of prompt+response (after the vision patches) + text_hidden_states = last_hidden_states[:, num_patches:-1] + # Get hidden states for action portion of response + batch_size = batch["input_ids"].shape[0] + actions_hidden_states = ( + text_hidden_states[current_action_mask | next_actions_mask] + .reshape(batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1) + .to(torch.bfloat16) + ) # (B, act_chunk_len, D) + + if use_l1_regression: + # Predict action + predicted_actions = action_head.module.predict_action(actions_hidden_states) + # Get full L1 loss + loss = torch.nn.L1Loss()(ground_truth_actions, predicted_actions) + + if use_diffusion: + # Predict noise + noise_pred = action_head.module.predict_noise(actions_hidden_states) + # Get diffusion noise prediction MSE loss + noise_pred = noise_pred.reshape(noise.shape) + loss = nn.functional.mse_loss(noise_pred, noise, reduction="mean") + + # Only sample actions and compute L1 losses if specified + if compute_diffusion_l1: + with torch.no_grad(): + predicted_actions = run_diffusion_sampling( + vla=vla, + action_head=action_head, + noisy_action_projector=noisy_action_projector, + 
proprio_projector=proprio_projector, + batch=batch, + batch_size=batch_size, + num_patches=num_patches, + actions_shape=ground_truth_actions.shape, + device_id=device_id, + current_action_mask=current_action_mask, + next_actions_mask=next_actions_mask, + use_proprio=use_proprio, + use_film=use_film, + ) + + metrics.update( + { + "loss_value": loss.item(), # Detached value for logging + } + ) + + # Get detailed L1 losses for logging + should_log_l1_loss = not use_diffusion or (use_diffusion and compute_diffusion_l1) + if should_log_l1_loss: + ground_truth_curr_action = ground_truth_actions[:, 0] + predicted_curr_action = predicted_actions[:, 0] + ground_truth_next_actions = ground_truth_actions[:, 1:] + predicted_next_actions = predicted_actions[:, 1:] + curr_action_l1_loss = torch.nn.L1Loss()(ground_truth_curr_action, predicted_curr_action) + next_actions_l1_loss = torch.nn.L1Loss()(ground_truth_next_actions, predicted_next_actions) + metrics.update( + { + "curr_action_l1_loss": curr_action_l1_loss.item(), + "next_actions_l1_loss": next_actions_l1_loss.item(), + } + ) + + # Return both the loss tensor (with gradients) and the metrics dictionary (with detached values) + return loss, metrics + + +def run_diffusion_sampling( + vla, + action_head, + noisy_action_projector, + proprio_projector, + batch, + batch_size, + num_patches, + actions_shape, + device_id, + current_action_mask, + next_actions_mask, + use_proprio, + use_film, +) -> torch.Tensor: + """ + Run diffusion sampling (reverse diffusion) to generate actions. + + Args: + vla (OpenVLAForActionPrediction): Vision-language-action policy. + action_head (nn.Module): Action head module. + noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion). + proprio_projector (nn.Module): Proprioceptive state projector module. + batch (dict): Input batch. + batch_size (int): Batch size. + num_patches (int): Number of vision patches. + actions_shape (tuple): Shape of ground-truth actions. 
+ device_id (str): Device ID. + current_action_mask (torch.Tensor): Mask for current action. + next_actions_mask (torch.Tensor): Mask for next actions. + use_proprio (bool): Whether to use proprioceptive state as input. + use_film (bool): Whether to use FiLM for better language following. + + Returns: + torch.Tensor: Predicted actions. + """ + # Sample random noisy action, used as the starting point for reverse diffusion + noise = torch.randn( + size=(batch_size, NUM_ACTIONS_CHUNK, ACTION_DIM), + device=device_id, + dtype=torch.bfloat16, + ) # (B, chunk_len, action_dim) + + # Set diffusion timestep values + action_head.module.noise_scheduler.set_timesteps(action_head.module.num_diffusion_steps_train) + + # Reverse diffusion: Iteratively denoise to generate action, conditioned on observation + curr_noisy_actions = noise + for t in action_head.module.noise_scheduler.timesteps: + # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action embedding, + # and diffusion timestep embedding) + timesteps = torch.Tensor([t]).repeat(batch_size).to(device_id) + diffusion_timestep_embeddings = ( + action_head.module.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device) + ) # (B, llm_dim) + diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim) + + with torch.autocast("cuda", dtype=torch.bfloat16): + output = vla( + input_ids=batch["input_ids"].to(device_id), + attention_mask=batch["attention_mask"].to(device_id), + pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id), + labels=batch["labels"], + output_hidden_states=True, + proprio=batch["proprio"] if use_proprio else None, + proprio_projector=proprio_projector if use_proprio else None, + noisy_actions=curr_noisy_actions, + noisy_action_projector=noisy_action_projector, + diffusion_timestep_embeddings=diffusion_timestep_embeddings, + use_film=use_film, + ) + # Get last layer hidden states + 
last_hidden_states = output.hidden_states[-1] # (B, seq_len, D) + # Get hidden states for text portion of prompt+response (after the vision patches) + text_hidden_states = last_hidden_states[:, num_patches:-1] + # Get hidden states for action portion of response + actions_hidden_states = text_hidden_states[current_action_mask | next_actions_mask].reshape( + batch_size, NUM_ACTIONS_CHUNK * ACTION_DIM, -1 + ) # (B, act_chunk_len, D) + actions_hidden_states = actions_hidden_states.to(torch.bfloat16) + # Predict noise + noise_pred = action_head.module.predict_noise(actions_hidden_states) + + # Compute the action at the previous diffusion timestep: x_t -> x_{t-1} + curr_noisy_actions = action_head.module.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample + + return curr_noisy_actions.reshape(actions_shape) + + +def compute_smoothened_metrics(metrics_deques) -> dict: + """ + Compute smoothened metrics from recent deques. + + Args: + metrics_deques (dict): Dictionary of deques containing recent metrics. + + Returns: + dict: Dictionary of smoothened metrics. + """ + smoothened_metrics = {} + # Use a loop variable that does not shadow the imported `collections.deque` + for name, metric_deque in metrics_deques.items(): + if metric_deque: + smoothened_metrics[name] = sum(metric_deque) / len(metric_deque) + return smoothened_metrics + + +def log_metrics_to_wandb(metrics, prefix, step, wandb_entity) -> None: + """ + Log metrics to Weights & Biases. + + Args: + metrics (dict): Dictionary of metrics to log + prefix (str): Prefix for metric names + step (int): Training step + wandb_entity: W&B module or run object that provides `log()` (not an entity name string) + + Returns: + None. 
+ """ + log_dict = {} + for name, value in metrics.items(): + # Map loss_value to Loss for better readability in W&B + if name == "loss_value": + log_dict[f"{prefix}/Loss"] = value + # Keep other metrics as is + else: + log_dict[f"{prefix}/{name.replace('_', ' ').title()}"] = value + wandb_entity.log(log_dict, step=step) + + +def save_training_checkpoint( + cfg, + run_dir, + log_step, + vla, + processor, + proprio_projector, + noisy_action_projector, + action_head, + train_dataset, + distributed_state, +) -> None: + """ + Save all training checkpoints including model components, LoRA adapter, and dataset statistics. + + Args: + cfg (FinetuneConfig): Training configuration. + run_dir (Path): Experiment run directory path. + log_step (int): Current logging step. + vla (OpenVLAForActionPrediction): Vision-language-action policy. + processor (PrismaticProcessor): OpenVLA inputs processor. + proprio_projector (nn.Module): Proprioceptive state projector module. + noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion). + action_head (nn.Module): Action head module. + train_dataset (RLDSDataset): Training dataset. + distributed_state (PartialState): Distributed training state. + + Returns: + None. 
+ """ + # Determine checkpoint paths and naming + if cfg.save_latest_checkpoint_only: + checkpoint_dir = run_dir + checkpoint_name_suffix = "latest_checkpoint.pt" + else: + checkpoint_dir = Path(str(run_dir) + f"--{log_step}_chkpt") + checkpoint_name_suffix = f"{log_step}_checkpoint.pt" + + adapter_dir = checkpoint_dir / "lora_adapter" + + # Create directories and save dataset statistics (main process only) + if distributed_state.is_main_process: + os.makedirs(checkpoint_dir, exist_ok=True) + os.makedirs(adapter_dir, exist_ok=True) + save_dataset_statistics(train_dataset.dataset_statistics, checkpoint_dir) + print(f"Saving Model Checkpoint for Step {log_step}") + + # Wait for directories to be created + dist.barrier() + + # Save model components (main process only) + if distributed_state.is_main_process: + # Save processor and LoRA adapter + processor.save_pretrained(checkpoint_dir) + vla.module.save_pretrained(adapter_dir) + + # Save other components + if cfg.use_proprio and proprio_projector is not None: + torch.save(proprio_projector.state_dict(), checkpoint_dir / f"proprio_projector--{checkpoint_name_suffix}") + + if cfg.use_diffusion and noisy_action_projector is not None: + torch.save( + noisy_action_projector.state_dict(), checkpoint_dir / f"noisy_action_projector--{checkpoint_name_suffix}" + ) + + if (cfg.use_l1_regression or cfg.use_diffusion) and action_head is not None: + torch.save(action_head.state_dict(), checkpoint_dir / f"action_head--{checkpoint_name_suffix}") + + if cfg.use_film: + # To be safe, just save the entire vision backbone (not just FiLM components) + torch.save( + vla.module.vision_backbone.state_dict(), checkpoint_dir / f"vision_backbone--{checkpoint_name_suffix}" + ) + + # Wait for model components to be saved + dist.barrier() + + # Merge LoRA weights into base model and save resulting model checkpoint + # Note: Can be very slow on some devices; if so, we recommend merging offline + if cfg.use_lora and cfg.merge_lora_during_training: + 
base_vla = AutoModelForVision2Seq.from_pretrained( + cfg.vla_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True + ) + merged_vla = PeftModel.from_pretrained(base_vla, adapter_dir) + merged_vla = merged_vla.merge_and_unload() + + if distributed_state.is_main_process: + merged_vla.save_pretrained(checkpoint_dir) + print(f"Saved merged model for Step {log_step} at: {checkpoint_dir}") + + # Wait for merged model to be saved + dist.barrier() + + +def run_validation( + vla, + action_head, + noisy_action_projector, + proprio_projector, + val_dataloader, + action_tokenizer, + device_id, + cfg, + num_patches, + log_step, + distributed_state, + val_time_limit, +) -> None: + """ + Compute validation set metrics for logging. + + Args: + vla (OpenVLAForActionPrediction): Vision-language-action policy. + action_head (nn.Module): Action head module. + noisy_action_projector (nn.Module): Noisy action projector module (only used for diffusion). + proprio_projector (nn.Module): Proprioceptive state projector module. + val_dataloader (DataLoader): Validation data loader. + action_tokenizer (ActionTokenizer): Action tokenizer. + device_id (str): Device ID. + cfg (FinetuneConfig): Training configuration. + num_patches (int): Number of vision patches. + log_step (int): Current logging step. + distributed_state (PartialState): Distributed training state. + val_time_limit (int): Time limit for computing validation metrics. + + Returns: + None. 
+ """ + val_start_time = time.time() + vla.eval() + val_batches_count = 0 + + # List to store validation metrics + all_val_metrics = [] + + with torch.no_grad(): + for batch in val_dataloader: + # Always compute L1 loss for validation, even for diffusion + _, metrics = run_forward_pass( + vla=vla, + action_head=action_head, + noisy_action_projector=noisy_action_projector, + proprio_projector=proprio_projector, + batch=batch, + action_tokenizer=action_tokenizer, + device_id=device_id, + use_l1_regression=cfg.use_l1_regression, + use_diffusion=cfg.use_diffusion, + use_proprio=cfg.use_proprio, + use_film=cfg.use_film, + num_patches=num_patches, + compute_diffusion_l1=True, + num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None, + ) + + # Add the loss value to the metrics + metrics["loss"] = metrics["loss_value"] + all_val_metrics.append(metrics) + val_batches_count += 1 + + # Cut testing on validation set short if it exceeds time limit + if time.time() - val_start_time > val_time_limit: + break + + # Compute average validation metrics + avg_val_metrics = {} + for metric_name in all_val_metrics[0].keys(): + values = [metrics[metric_name] for metrics in all_val_metrics if metric_name in metrics] + if values: + avg_val_metrics[metric_name] = sum(values) / len(values) + + # Add batch count to metrics + avg_val_metrics["val_batches_count"] = val_batches_count + + # Log validation metrics to W&B + if distributed_state.is_main_process: + log_metrics_to_wandb(avg_val_metrics, "VLA Val", log_step, wandb) + + +@draccus.wrap() +def finetune(cfg: FinetuneConfig) -> None: + """ + Fine-tunes base VLA on demonstration dataset via LoRA. + + Allows toggling different action representations (discrete vs. continuous), different learning objectives + (next-token prediction vs. L1 regression vs. diffusion), FiLM. Also allows for additional model inputs, + such as additional camera images and robot proprioceptive state. 
Assumes parallel action generation with + action chunking. + + Args: + cfg (FinetuneConfig): Training configuration. + + Returns: + None. + """ + assert cfg.use_lora, "Only LoRA fine-tuning is supported. Please set --use_lora=True!" + assert not (cfg.use_l1_regression and cfg.use_diffusion), ( + "Cannot do both L1 regression and diffusion. Please pick one of them!" + ) + + # Trim trailing forward slash ('/') in VLA path if it exists + cfg.vla_path = cfg.vla_path.rstrip("/") + print(f"Fine-tuning OpenVLA Model `{cfg.vla_path}` on `{cfg.dataset_name}`") + + # Get experiment run ID + run_id = get_run_id(cfg) + + # Create experiment run directory + run_dir = cfg.run_root_dir / run_id + os.makedirs(run_dir, exist_ok=True) + + # GPU setup + distributed_state = PartialState() + device_id = distributed_state.local_process_index + torch.cuda.set_device(device_id) + torch.cuda.empty_cache() + + # Initialize wandb logging + if distributed_state.is_main_process: + wandb.init(entity=cfg.wandb_entity, project=cfg.wandb_project, name=f"ft+{run_id}") + + # Print detected constants + print( + "Detected constants:\n" + f"\tNUM_ACTIONS_CHUNK: {NUM_ACTIONS_CHUNK}\n" + f"\tACTION_DIM: {ACTION_DIM}\n" + f"\tPROPRIO_DIM: {PROPRIO_DIM}\n" + f"\tACTION_PROPRIO_NORMALIZATION_TYPE: {ACTION_PROPRIO_NORMALIZATION_TYPE}" + ) + + # Two options: + # (1) Base model is on Hugging Face Hub + # - Then download it and record the path to the download directory + # (2) Base model is stored locally + # - Then register model config in HF Auto Classes + # In both cases, we want to check whether any changes have been made to + # the `modeling_prismatic.py` file in this codebase; if so, we will copy + # the file to the downloaded or locally stored checkpoint directory so + # that the user's changes to the VLA class logic go into effect + if model_is_on_hf_hub(cfg.vla_path): + # Download model directly from Hugging Face Hub + vla_download_path = snapshot_download(repo_id=cfg.vla_path) + # Overwrite VLA path + 
cfg.vla_path = vla_download_path + else: + # Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub) + AutoConfig.register("openvla", OpenVLAConfig) + AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor) + AutoProcessor.register(OpenVLAConfig, PrismaticProcessor) + AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction) + + # Update config.json and sync model files + if distributed_state.is_main_process: + update_auto_map(cfg.vla_path) + check_model_logic_mismatch(cfg.vla_path) + + # Wait for model files to be synced + dist.barrier() + + # Load processor and VLA + processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True) + vla = AutoModelForVision2Seq.from_pretrained( + cfg.vla_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + ).to(device_id) + + # Set number of images in VLA input + vla.vision_backbone.set_num_images_in_input(cfg.num_images_in_input) + + # LoRA setup + if cfg.use_lora: + lora_config = LoraConfig( + r=cfg.lora_rank, + lora_alpha=min(cfg.lora_rank, 16), + lora_dropout=cfg.lora_dropout, + target_modules="all-linear", + init_lora_weights="gaussian", + ) + vla = get_peft_model(vla, lora_config) + vla.print_trainable_parameters() + + # FiLM setup + if cfg.use_film: + count_parameters(vla.vision_backbone, "vla.vision_backbone (original)") + # Wrap vision backbone with FiLM wrapper + # Important: For this, must specify `vla.model.vision_backbone` instead of just `vla.vision_backbone`, since the + # latter would cause the new wrapped backbone to be saved as a new attribute of `vla` instead of overwriting the + # original one (due to the LoRA wrapper) + vla.model.vision_backbone = FiLMedPrismaticVisionBackbone( + vision_backbone=vla.model.vision_backbone, + llm_dim=vla.llm_dim, + ) + count_parameters(vla.vision_backbone, "vla.vision_backbone (post-wrap)") + if cfg.resume: + state_dict = load_checkpoint("vision_backbone", cfg.vla_path, 
cfg.resume_step) + vla.model.vision_backbone.load_state_dict(state_dict) + vla.model.vision_backbone = vla.model.vision_backbone.to(device_id) + + # Wrap VLA with DDP + vla = wrap_ddp(vla, device_id, find_unused=True) + + # If applicable, instantiate proprio projector + if cfg.use_proprio: + proprio_projector = init_module( + ProprioProjector, + "proprio_projector", + cfg, + device_id, + {"llm_dim": vla.module.llm_dim, "proprio_dim": PROPRIO_DIM}, + ) + + # If applicable, instantiate continuous action head for L1 regression + if cfg.use_l1_regression: + action_head = init_module( + L1RegressionActionHead, + "action_head", + cfg, + device_id, + {"input_dim": vla.module.llm_dim, "hidden_dim": vla.module.llm_dim, "action_dim": ACTION_DIM}, + to_bf16=True, + ) + + # If applicable, instantiate diffusion action head and noisy action projector + if cfg.use_diffusion: + action_head = init_module( + DiffusionActionHead, + "action_head", + cfg, + device_id, + { + "input_dim": vla.module.llm_dim, + "hidden_dim": vla.module.llm_dim, + "action_dim": ACTION_DIM, + "num_diffusion_steps_train": cfg.num_diffusion_steps_train, + }, + to_bf16=True, + ) + noisy_action_projector = init_module( + NoisyActionProjector, "noisy_action_projector", cfg, device_id, {"llm_dim": vla.module.llm_dim} + ) + + # Get number of vision patches + NUM_PATCHES = vla.module.vision_backbone.get_num_patches() * vla.module.vision_backbone.get_num_images_in_input() + # If we have proprio inputs, a single proprio embedding is appended to the end of the vision patch embeddings + if cfg.use_proprio: + NUM_PATCHES += 1 + # For diffusion, a single diffusion timestep embedding is appended to the end of the vision patch embeddings + if cfg.use_diffusion: + NUM_PATCHES += 1 + + # Instantiate optimizer + trainable_params = [param for param in vla.parameters() if param.requires_grad] + if cfg.use_l1_regression or cfg.use_diffusion: + trainable_params += [param for param in action_head.parameters() if 
param.requires_grad] + if cfg.use_diffusion: + trainable_params += [param for param in noisy_action_projector.parameters() if param.requires_grad] + if cfg.use_proprio: + trainable_params += [param for param in proprio_projector.parameters() if param.requires_grad] + print(f"# total trainable params: {sum(p.numel() for p in trainable_params)}") + optimizer = AdamW(trainable_params, lr=cfg.learning_rate) + + # Record original learning rate + original_lr = optimizer.param_groups[0]["lr"] + + # Create learning rate scheduler + scheduler = MultiStepLR( + optimizer, + milestones=[cfg.num_steps_before_decay], # Number of steps after which LR will change + gamma=0.1, # Multiplicative factor of learning rate decay + ) + + # Create Action Tokenizer + action_tokenizer = ActionTokenizer(processor.tokenizer) + + # Load Fine-tuning Dataset =>> note that we use an RLDS-formatted dataset following Open X-Embodiment by default. + # =>> If you want to use a non-RLDS dataset (e.g., a standard PyTorch Dataset) see the following commented block. + # =>> Note that our training code does not loop over epochs because the RLDS loader does this implicitly; if using + # your own Dataset, make sure to add the appropriate logic to the training loop! 
+ # + # --- + # from prismatic.vla.datasets import DummyDataset + # + # train_dataset = DummyDataset( + # action_tokenizer, + # processor.tokenizer, + # image_transform=processor.image_processor.apply_transform, + # prompt_builder_fn=PurePromptBuilder, + # ) + # --- + + # We assume that the model takes as input one third-person camera image and 1 or 2 optional wrist camera image(s) + use_wrist_image = cfg.num_images_in_input > 1 + + # Create training and optional validation datasets + batch_transform = RLDSBatchTransform( + action_tokenizer, + processor.tokenizer, + image_transform=processor.image_processor.apply_transform, + prompt_builder_fn=PurePromptBuilder, + use_wrist_image=use_wrist_image, + use_proprio=cfg.use_proprio, + ) + train_dataset = RLDSDataset( + cfg.data_root_dir, + cfg.dataset_name, + batch_transform, + resize_resolution=tuple(vla.module.config.image_sizes), + shuffle_buffer_size=cfg.shuffle_buffer_size, + image_aug=cfg.image_aug, + ) + if cfg.use_val_set: + val_dataset = RLDSDataset( + cfg.data_root_dir, + cfg.dataset_name, + batch_transform, + resize_resolution=tuple(vla.module.config.image_sizes), + shuffle_buffer_size=cfg.shuffle_buffer_size // 10, + image_aug=cfg.image_aug, + train=False, + ) + + # [Important] Save dataset statistics so that we can unnormalize actions during inference + if distributed_state.is_main_process: + save_dataset_statistics(train_dataset.dataset_statistics, run_dir) + + # Create collator and dataloader + collator = PaddedCollatorForActionPrediction( + processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right" + ) + dataloader = DataLoader( + train_dataset, + batch_size=cfg.batch_size, + sampler=None, + collate_fn=collator, + num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism + ) + if cfg.use_val_set: + val_batch_size = cfg.batch_size + val_dataloader = DataLoader( + val_dataset, + batch_size=val_batch_size, + sampler=None, + collate_fn=collator, + 
num_workers=0, # Important: Set to 0 if using RLDS, which uses its own parallelism + ) + + # Deque to store recent train metrics (used for computing smoothened metrics for gradient accumulation) + recent_metrics = { + "loss_value": deque(maxlen=cfg.grad_accumulation_steps), + "curr_action_accuracy": deque(maxlen=cfg.grad_accumulation_steps), + "curr_action_l1_loss": deque(maxlen=cfg.grad_accumulation_steps), + "next_actions_accuracy": deque(maxlen=cfg.grad_accumulation_steps), + "next_actions_l1_loss": deque(maxlen=cfg.grad_accumulation_steps), + } + + # Start training + with tqdm.tqdm(total=cfg.max_steps, leave=False) as progress: + vla.train() + optimizer.zero_grad() + for batch_idx, batch in enumerate(dataloader): + # Compute training metrics and loss + compute_diffusion_l1 = cfg.use_diffusion and batch_idx % cfg.diffusion_sample_freq == 0 + loss, metrics = run_forward_pass( + vla=vla, + action_head=action_head, + noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None, + proprio_projector=proprio_projector if cfg.use_proprio else None, + batch=batch, + action_tokenizer=action_tokenizer, + device_id=device_id, + use_l1_regression=cfg.use_l1_regression, + use_diffusion=cfg.use_diffusion, + use_proprio=cfg.use_proprio, + use_film=cfg.use_film, + num_patches=NUM_PATCHES, + compute_diffusion_l1=compute_diffusion_l1, + num_diffusion_steps_train=cfg.num_diffusion_steps_train if cfg.use_diffusion else None, + ) + + # Normalize loss to account for gradient accumulation + normalized_loss = loss / cfg.grad_accumulation_steps + + # Backward pass + normalized_loss.backward() + + # Store recent train metrics + for metric_name, value in metrics.items(): + if metric_name in recent_metrics: + recent_metrics[metric_name].append(value) + + # Compute gradient step index + gradient_step_idx = batch_idx // cfg.grad_accumulation_steps + + # Compute smoothened train metrics + smoothened_metrics = compute_smoothened_metrics(recent_metrics) + + # Push Metrics to W&B 
(every wandb_log_freq gradient steps) + log_step = gradient_step_idx if not cfg.resume else cfg.resume_step + gradient_step_idx + if distributed_state.is_main_process and log_step % cfg.wandb_log_freq == 0: + log_metrics_to_wandb(smoothened_metrics, "VLA Train", log_step, wandb) + + # [If applicable] Linearly warm up learning rate from 10% to 100% of original + if cfg.lr_warmup_steps > 0: + lr_progress = min((gradient_step_idx + 1) / cfg.lr_warmup_steps, 1.0) # Cap at 1.0 + current_lr = original_lr * (0.1 + 0.9 * lr_progress) + for param_group in optimizer.param_groups: + param_group["lr"] = current_lr + + if distributed_state.is_main_process and gradient_step_idx % cfg.wandb_log_freq == 0: + # Log the learning rate + # Make sure to do this AFTER any learning rate modifications (e.g., warmup/decay) + wandb.log( + { + "VLA Train/Learning Rate": scheduler.get_last_lr()[0], + }, + step=log_step, + ) + + # Optimizer and LR scheduler step + if (batch_idx + 1) % cfg.grad_accumulation_steps == 0: + optimizer.step() + scheduler.step() + optimizer.zero_grad() + progress.update() + + # Save model checkpoint: either keep latest checkpoint only or all checkpoints + if gradient_step_idx > 0 and log_step % cfg.save_freq == 0: + save_training_checkpoint( + cfg=cfg, + run_dir=run_dir, + log_step=log_step, + vla=vla, + processor=processor, + proprio_projector=proprio_projector if cfg.use_proprio else None, + noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None, + action_head=action_head if (cfg.use_l1_regression or cfg.use_diffusion) else None, + train_dataset=train_dataset, + distributed_state=distributed_state, + ) + + # Test model on validation set + if cfg.use_val_set and log_step > 0 and log_step % cfg.val_freq == 0: + run_validation( + vla=vla, + action_head=action_head, + noisy_action_projector=noisy_action_projector if cfg.use_diffusion else None, + proprio_projector=proprio_projector if cfg.use_proprio else None, + val_dataloader=val_dataloader, + 
action_tokenizer=action_tokenizer, + device_id=device_id, + cfg=cfg, + num_patches=NUM_PATCHES, + log_step=log_step, + distributed_state=distributed_state, + val_time_limit=cfg.val_time_limit, + ) + # Set model back to training mode after validation + vla.train() + + # Stop training when max_steps is reached + if log_step == cfg.max_steps: + print(f"Max step {cfg.max_steps} reached! Stopping training...") + break + + +if __name__ == "__main__": + finetune() diff --git a/vla-scripts/merge_lora_weights_and_save.py b/vla-scripts/merge_lora_weights_and_save.py new file mode 100644 index 0000000000000000000000000000000000000000..8c38c10e95853d900a70b35f3112337be49b579e --- /dev/null +++ b/vla-scripts/merge_lora_weights_and_save.py @@ -0,0 +1,73 @@ +""" +Loads a checkpoint that only has a LoRA adapter (no merged model) and merges the adapter +into the base OpenVLA model. Saves the final checkpoint in the same directory. + +Make sure to specify the correct base checkpoint when running this script. 
For example, +- if you fine-tuned the default OpenVLA-7B model without modifications, then `--base_checkpoint="openvla/openvla-7b"` +- if you fine-tuned a different model or resumed fine-tuning from a different checkpoint, then specify that base checkpoint +- if you fine-tuned the default OpenVLA-7B model with modifications to `modeling_prismatic.py` (OpenVLA class definition), + then the base checkpoint path should point to the checkpoint containing the modifications + +Usage: + python vla-scripts/merge_lora_weights_and_save.py \ + --base_checkpoint openvla/openvla-7b \ + --lora_finetuned_checkpoint_dir /PATH/TO/CHECKPOINT/DIR/ +""" + +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Union + +import draccus +import torch +from peft import PeftModel +from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor + +from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig +from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction +from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor + + +@dataclass +class ConvertConfig: + # fmt: off + + base_checkpoint: Union[str, Path] = "" # Base model checkpoint path/dir (either openvla/openvla-7b or whichever model you fine-tuned / resumed training from) + lora_finetuned_checkpoint_dir: Union[str, Path] = "" # Checkpoint directory containing the LoRA adapter + + # fmt: on + + +@draccus.wrap() +def main(cfg: ConvertConfig) -> None: + # Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub) + AutoConfig.register("openvla", OpenVLAConfig) + AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor) + AutoProcessor.register(OpenVLAConfig, PrismaticProcessor) + AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction) + + # Load Model using HF AutoClasses + print(f"Loading base model: {cfg.base_checkpoint}") + vla =
AutoModelForVision2Seq.from_pretrained( + cfg.base_checkpoint, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + ) + + # Load LoRA weights and merge into base model, then save final checkpoint + print("Merging LoRA weights into base model...") + start_time = time.time() + merged_vla = PeftModel.from_pretrained(vla, os.path.join(cfg.lora_finetuned_checkpoint_dir, "lora_adapter")).to( + "cuda" + ) + merged_vla = merged_vla.merge_and_unload() + merged_vla.save_pretrained(cfg.lora_finetuned_checkpoint_dir) + print(f"\nMerging complete! Time elapsed (sec): {time.time() - start_time}") + print(f"\nSaved merged model checkpoint at:\n{cfg.lora_finetuned_checkpoint_dir}") + + +if __name__ == "__main__": + main()
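
Note on the merge step in `merge_lora_weights_and_save.py`: the script loads the adapter with `PeftModel.from_pretrained` and folds it into the base weights with `merge_and_unload()`. The toy sketch below illustrates what such a merge amounts to for a single linear layer. It assumes the standard LoRA update `W + (lora_alpha / r) * B @ A`. It uses plain Python lists rather than `torch` or `peft`, and every name in it (`merge_lora`, `lora_linear`, the toy matrices) is hypothetical and not part of this repository.

```python
# Illustrative sketch only: a toy LoRA merge in pure Python (no torch/peft).
# Assumption: the adapter update is scaled by lora_alpha / r (standard LoRA
# scaling). All function names and values here are hypothetical.


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def linear(weight, x):
    """Apply y = W x for a single input vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def merge_lora(weight, lora_a, lora_b, lora_alpha, r):
    """Return the merged dense weight W + (lora_alpha / r) * (B @ A)."""
    scaling = lora_alpha / r
    delta = matmul(lora_b, lora_a)  # (out, r) @ (r, in) -> (out, in)
    return [
        [w + scaling * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


def lora_linear(weight, lora_a, lora_b, lora_alpha, r, x):
    """Apply the unmerged form y = W x + (lora_alpha / r) * B (A x)."""
    scaling = lora_alpha / r
    low_rank = linear(lora_b, linear(lora_a, x))
    return [base + scaling * lr for base, lr in zip(linear(weight, x), low_rank)]


# Toy shapes: out=2, in=3, r=1
W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
A = [[0.5, -1.0, 2.0]]  # (r, in)
B = [[2.0], [-4.0]]     # (out, r)
x = [1.0, 2.0, 3.0]

merged = merge_lora(W, A, B, lora_alpha=2, r=1)
y_merged = linear(merged, x)                               # merged weights only
y_adapter = lora_linear(W, A, B, lora_alpha=2, r=1, x=x)   # base + adapter path
```

Both paths give the same output for this input, which is why a merged checkpoint can be served without the adapter at inference time. In the fine-tuning script above, `lora_alpha` is set to `min(cfg.lora_rank, 16)`, so the scaling factor in the merge depends on the rank used during training.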