diff --git a/external/peract_bimanual/.gitignore b/external/peract_bimanual/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..68bc17f9ff2104a9d7b6777058bb4c343ca72609 --- /dev/null +++ b/external/peract_bimanual/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. 
+#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ diff --git a/external/peract_bimanual/ARM_LICENSE b/external/peract_bimanual/ARM_LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..5c596260a5f028d308fb28c2f6d0680b4e02c9df --- /dev/null +++ b/external/peract_bimanual/ARM_LICENSE @@ -0,0 +1,196 @@ +Q-attention: Enabling Efficient Learning for Vision-based Robotic Manipulation + +LICENCE AGREEMENT + +WE (Imperial College of Science, Technology and Medicine, (“Imperial College London”)) +ARE WILLING TO LICENSE THIS SOFTWARE TO YOU (a licensee “You”) ONLY ON THE +CONDITION THAT YOU ACCEPT ALL OF THE TERMS CONTAINED IN THE +FOLLOWING AGREEMENT. PLEASE READ THE AGREEMENT CAREFULLY BEFORE +DOWNLOADING THE SOFTWARE. BY EXERCISING THE OPTION TO DOWNLOAD +THE SOFTWARE YOU AGREE TO BE BOUND BY THE TERMS OF THE AGREEMENT. +SOFTWARE LICENCE AGREEMENT (EXCLUDING BSD COMPONENTS) + +1.This Agreement pertains to a worldwide, non-exclusive, temporary, fully paid-up, royalty +free, non-transferable, non-sub- licensable licence (the “Licence”) to use the Q-attention +source code, including any modification, part or derivative (the “Software”). +Ownership and Licence. Your rights to use and download the Software onto your computer, +and all other copies that You are authorised to make, are specified in this Agreement. +However, we (or our licensors) retain all rights, including but not limited to all copyright and +other intellectual property rights anywhere in the world, in the Software not expressly +granted to You in this Agreement. + +2. Permitted use of the Licence: + +(a) You may download and install the Software onto one computer or server for use in +accordance with Clause 2(b) of this Agreement provided that You ensure that the Software is +not accessible by other users unless they have themselves accepted the terms of this licence +agreement. 
+ +(b) You may use the Software solely for non-commercial, internal or academic research +purposes and only in accordance with the terms of this Agreement. You may not use the +Software for commercial purposes, including but not limited to (1) integration of all or part of +the source code or the Software into a product for sale or licence by or on behalf of You to +third parties or (2) use of the Software or any derivative of it for research to develop software +products for sale or licence to a third party or (3) use of the Software or any derivative of it +for research to develop non-software products for sale or licence to a third party, or (4) use of +the Software to provide any service to an external organisation for which payment is +received. + +Should You wish to use the Software for commercial purposes, You shall +email researchcontracts.engineering@imperial.ac.uk . + +(c) Right to Copy. You may copy the Software for back-up and archival purposes, provided +that each copy is kept in your possession and provided You reproduce our copyright notice +(set out in Schedule 1) on each copy. + +(d) Transfer and sub-licensing. You may not rent, lend, or lease the Software and You may +not transmit, transfer or sub-license this licence to use the Software or any of your rights or +obligations under this Agreement to another party. + +(e) Identity of Licensee. The licence granted herein is personal to You. You shall not permit +any third party to access, modify or otherwise use the Software nor shall You access modify +or otherwise use the Software on behalf of any third party. If You wish to obtain a licence for +mutiple users or a site licence for the Software please contact us +at researchcontracts.engineering@imperial.ac.uk . + +(f) Publications and presentations. 
You may make public, results or data obtained from, +dependent on or arising from research carried out using the Software, provided that any such +presentation or publication identifies the Software as the source of the results or the data, +including the Copyright Notice given in each element of the Software, and stating that the +Software has been made available for use by You under licence from Imperial College London +and You provide a copy of any such publication to Imperial College London. + +3. Prohibited Uses. You may not, without written permission from us +at researchcontracts.engineering@imperial.ac.uk : + +(a) Use, copy, modify, merge, or transfer copies of the Software or any documentation +provided by us which relates to the Software except as provided in this Agreement; + +(b) Use any back-up or archival copies of the Software (or allow anyone else to use such +copies) for any purpose other than to replace the original copy in the event it is destroyed or +becomes defective; or + +(c) Disassemble, decompile or "unlock", reverse translate, or in any manner decode the +Software for any reason. + +4. Warranty Disclaimer + +(a) Disclaimer. The Software has been developed for research purposes only. You +acknowledge that we are providing the Software to You under this licence agreement free of +charge and on condition that the disclaimer set out below shall apply. We do not represent or +warrant that the Software as to: (i) the quality, accuracy or reliability of the Software; (ii) the +suitability of the Software for any particular use or for use under any specific conditions; and +(iii) whether use of the Software will infringe third-party rights. +You acknowledge that You have reviewed and evaluated the Software to determine that it +meets your needs and that You assume all responsibility and liability for determining the +suitability of the Software as fit for your particular purposes and requirements. 
Subject to +Clause 4(b), we exclude and expressly disclaim all express and implied representations, +warranties, conditions and terms not stated herein (including the implied conditions or +warranties of satisfactory quality, merchantable quality, merchantability and fitness for +purpose). + +(b) Savings. Some jurisdictions may imply warranties, conditions or terms or impose +obligations upon us which cannot, in whole or in part, be excluded, restricted or modified or +otherwise do not allow the exclusion of implied warranties, conditions or terms, in which +case the above warranty disclaimer and exclusion will only apply to You to the extent +permitted in the relevant jurisdiction and does not in any event exclude any implied +warranties, conditions or terms which may not under applicable law be excluded. + +(c) Imperial College London disclaims all responsibility for the use which is made of the +Software and any liability for the outcomes arising from using the Software. + +5. Limitation of Liability + +(a) You acknowledge that we are providing the Software to You under this licence agreement +free of charge and on condition that the limitation of liability set out below shall apply. +Accordingly, subject to Clause 5(b), we exclude all liability whether in contract, tort, +negligence or otherwise, in respect of the Software and/or any related documentation +provided to You by us including, but not limited to, liability for loss or corruption of data, +loss of contracts, loss of income, loss of profits, loss of cover and any consequential or indirect +loss or damage of any kind arising out of or in connection with this licence agreement, +however caused. This exclusion shall apply even if we have been advised of the possibility of +such loss or damage. 
+ +(b) You agree to indemnify Imperial College London and hold it harmless from and against +any and all claims, damages and liabilities asserted by third parties (including claims for +negligence) which arise directly or indirectly from the use of the Software or any derivative +of it or the sale of any products based on the Software. You undertake to make no liability +claim against any employee, student, agent or appointee of Imperial College London, in +connection with this Licence or the Software. + +(c) Nothing in this Agreement shall have the effect of excluding or limiting our statutory +liability. + +(d) Some jurisdictions do not allow these limitations or exclusions either wholly or in part, +and, to that extent, they may not apply to you. Nothing in this licence agreement will affect +your statutory rights or other relevant statutory provisions which cannot be excluded, +restricted or modified, and its terms and conditions must be read and construed subject to any +such statutory rights and/or provisions. + +6. Confidentiality. You agree not to disclose any confidential information provided to You by +us pursuant to this Agreement to any third party without our prior written consent. The +obligations in this Clause 6 shall survive the termination of this Agreement for any reason. + +7. Termination. + +(a) We may terminate this licence agreement and your right to use the Software at any time +with immediate effect upon written notice to You. + +(b) This licence agreement and your right to use the Software automatically terminate if You: +(i) fail to comply with any provisions of this Agreement; or +(ii) destroy the copies of the Software in your possession, or voluntarily return the Software +to us. + +(c) Upon termination You will destroy all copies of the Software. + +(d) Otherwise, the restrictions on your rights to use the Software will expire 10 (ten) years +after first use of the Software under this licence agreement. + +8. Miscellaneous Provisions. 
+ +(a) This Agreement will be governed by and construed in accordance with the substantive +laws of England and Wales whose courts shall have exclusive jurisdiction over all disputes +which may arise between us. + +(b) This is the entire agreement between us relating to the Software, and supersedes any prior +purchase order, communications, advertising or representations concerning the Software. + +(c) No change or modification of this Agreement will be valid unless it is in writing, and is +signed by us. + +(d) The unenforceability or invalidity of any part of this Agreement will not affect the +enforceability or validity of the remaining parts. + +BSD Elements of the Software + +For BSD elements of the Software, the following terms shall apply: + +Copyright as indicated in the header of the individual element of the Software. + +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are +permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of +conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of +conditions and the following disclaimer in the documentation and/or other materials +provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to +endorse or promote products derived from this software without specific prior written +permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/external/peract_bimanual/Dockerfile b/external/peract_bimanual/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..6cce278b6f85a0a0461b695ebc64a6f61bdc63cf --- /dev/null +++ b/external/peract_bimanual/Dockerfile @@ -0,0 +1,68 @@ +# Use the NVIDIA base image for CUDA +FROM nvcr.io/nvidia/cuda:12.3.2-cudnn9-devel-ubuntu20.04 + +# Set environment variables +ENV COPPELIASIM_ROOT=${HOME}/code/coppelia_sim +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT +ENV QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT +ENV DEBIAN_FRONTEND=noninteractive +ENV TZ=America/Los_Angeles +ENV CONDA_ALWAYS_YES=true +ENV FORCE_CUDA=1 +ENV TORCH_CUDA_ARCH_LIST="5.0;5.2;5.3;6.0;6.1;6.2;7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0+PTX" + +# Create necessary directories +RUN mkdir -p ${HOME}/code + +# Install dependencies and essential tools +RUN apt-get update && apt-get install -y \ + tzdata sudo curl git vim htop tar bzip2 pigz rsync less mlocate \ + build-essential gdb ca-certificates stress sysstat itop \ + xauth xvfb mesa-utils mesa-utils-extra x11-apps \ + xorg xserver-xorg-core libxv1 x11-xserver-utils libxcb-randr0-dev \ + libxrender-dev libxkbcommon-dev libxkbcommon-x11-0 libavcodec-dev \ + libavformat-dev libswscale-dev '^libxcb.*-dev' libx11-xcb-dev \ + libglu1-mesa-dev libxrender-dev libxi-dev libxkbcommon-dev \ + libxkbcommon-x11-dev libegl1-mesa libarchive-dev libarchive13 \ + && rm -rf /var/lib/apt/lists/* + +# Install VirtualGL +RUN 
TEMP_DIR=$(mktemp -d -p /) && cd $TEMP_DIR && \ + curl -L -o virtualgl.deb https://sourceforge.net/projects/virtualgl/files/3.1/virtualgl_3.1_amd64.deb/download && \ + dpkg -i virtualgl.deb && \ + /opt/VirtualGL/bin/vglserver_config +glx +egl +s +f +t && \ + rm -rf $TEMP_DIR + +RUN mkdir ${HOME}/.ssh && chmod -R 700 ${HOME}/.ssh + +RUN ssh-keyscan github.com >> ${HOME}/.ssh/known_hosts + +RUN curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh +RUN bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda +RUN export PATH=/opt/conda/bin:${PATH} + +# Install code and dependencies + +WORKDIR ${HOME}/code + +RUN eval "$(/opt/conda/bin/conda shell.bash hook)" && conda init bash +RUN eval "$(/opt/conda/bin/conda shell.bash hook)" && conda install mamba -c conda-forge +#RUN conda config --set auto_activate_base false + + +RUN git clone https://github.com/markusgrotz/peract_bimanual.git ${HOME}/code/peract_bimanual + + +RUN eval "$(/opt/conda/bin/conda shell.bash hook)" && ${HOME}/code/peract_bimanual/scripts/install_dependencies.sh + + +# Activate the environment by default +RUN echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ + echo "conda activate rlbench" >> ~/.bashrc + + +WORKDIR /root/code/peract_bimanual + +# Default command +CMD ["/bin/bash"] + diff --git a/external/peract_bimanual/INSTALLATION.md b/external/peract_bimanual/INSTALLATION.md new file mode 100644 index 0000000000000000000000000000000000000000..54ef61dc31fa18f262235c1195712bdf94bd5d8c --- /dev/null +++ b/external/peract_bimanual/INSTALLATION.md @@ -0,0 +1,87 @@ +# INSTALLATION + +To install the dependencies execute the `scripts/install_dependencies.sh` + +```bash +scripts/install_conda.sh # Skip this step if you already have conda installed. +scripts/install_dependencies.sh +``` + +Please see the [README](README.md) for a quick start instruction. + + +Alternatively, you can follow the detailed instructions to setup the software from scratch + +#### 2. 
PyRep and Coppelia Simulator + +Follow instructions from my [PyRep fork](https://github.com/markusgrotz/PyRep); reproduced here for convenience: + +PyRep requires version **4.1** of CoppeliaSim. Download: +- [Ubuntu 20.04](https://www.coppeliarobotics.com/files/V4_1_0/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz) + +Once you have downloaded CoppeliaSim, you can pull PyRep from git: + +```bash +cd +git clone https://github.com/markusgrotz/PyRep.git +cd PyRep +``` + +Add the following to your *~/.bashrc* file: (__NOTE__: replace the placeholder path in the first line with your CoppeliaSim install directory) + +```bash +export COPPELIASIM_ROOT=/PATH/TO/COPPELIASIM/INSTALL/DIR +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT +export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT +``` + +Remember to source your bashrc (`source ~/.bashrc`) or +zshrc (`source ~/.zshrc`) after this. + +**Warning**: CoppeliaSim might cause conflicts with ROS workspaces. + +Finally, install the Python library: + +```bash +pip install -e . +``` + +You should be good to go! +You could try running one of the examples in the *examples/* folder. + +#### 3. RLBench + +PerAct uses my [RLBench fork](https://github.com/markusgrotz/RLBench/tree/peract). + +```bash +cd +git clone https://github.com/markusgrotz/RLBench.git + +cd RLBench +pip install -e . +``` + +For [running in headless mode](https://github.com/MohitShridhar/RLBench/tree/peract#running-headless), task setups, and other issues, please refer to the [official repo](https://github.com/stepjam/RLBench). + +#### 4. YARR + +PerAct uses my [YARR fork](https://github.com/markusgrotz/YARR/). + +```bash +cd +git clone https://github.com/markusgrotz/YARR.git + +cd YARR +pip install -e . +``` + + + +#### RVT baseline + +pip install git+https://github.com/NVlabs/RVT.git +pip install -e . 
+ + + + diff --git a/external/peract_bimanual/agents/__init__.py b/external/peract_bimanual/agents/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/external/peract_bimanual/agents/act_bc_lang/__init__.py b/external/peract_bimanual/agents/act_bc_lang/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c952cc7dab9c82b060211fbca21861ea3153e20c --- /dev/null +++ b/external/peract_bimanual/agents/act_bc_lang/__init__.py @@ -0,0 +1 @@ +import agents.act_bc_lang.launch_utils diff --git a/external/peract_bimanual/agents/act_bc_lang/act_bc_lang_agent.py b/external/peract_bimanual/agents/act_bc_lang/act_bc_lang_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..8582ef64361bb00faef9cf8c7e6458f986c0720c --- /dev/null +++ b/external/peract_bimanual/agents/act_bc_lang/act_bc_lang_agent.py @@ -0,0 +1,381 @@ +import copy +import logging +from functools import lru_cache +import pickle +import os +from typing import List +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from yarr.agents.agent import Agent, Summary, ActResult, ScalarSummary, HistogramSummary + +from helpers import utils +from helpers.utils import stack_on_channel + +from helpers.clip.core.clip import build_model, load_clip + +NAME = "ActBCLangAgent" + + +class ActBCLangAgent(Agent): + def __init__( + self, + actor_network: nn.Module, + camera_names: List[str], + lr: float = 0.01, + weight_decay: float = 1e-5, + grad_clip: float = 20.0, + episode_length: int = 400, + train_demo_path=None, + task_name=None, + ): + self._camera_names = camera_names + self._actor = actor_network + self._lr = lr + self._weight_decay = weight_decay + self._grad_clip = grad_clip + self._episode_length = episode_length + self.train_demo_path = train_demo_path + self.task_name = task_name + + def build(self, training: bool, device: torch.device = None): + if 
device is None: + device = torch.device("cpu") + self._actor = self._actor.to(device).train(training) + self._actor_optimizer = self._actor.configure_optimizers() + + self._device = device + + def reset(self): + super(ActBCLangAgent, self).reset() + + self._timestep = 0 + # .. input_dim = input_dim * 2 for bimanual + self._all_time_actions = torch.zeros( + [ + self._episode_length, + self._episode_length + self._actor.model.num_queries, + self._actor.model.input_dim, + ] + ).to(self._device) + self._all_actions = None + + def _grad_step(self, loss, opt, model_params=None, clip=None): + opt.zero_grad() + loss.backward() + if clip is not None and model_params is not None: + nn.utils.clip_grad_value_(model_params, clip) + opt.step() + + @lru_cache() + def train_stats(self): + right_joint_positions = [] + left_joint_positions = [] + + right_gripper_positions = [] + left_gripper_positions = [] + + episodes_dir = ( + f"{self.train_demo_path}/{self.task_name}/all_variations/episodes/" + ) + + for episode in os.listdir(episodes_dir): + with open( + os.path.join(episodes_dir, episode, "low_dim_obs.pkl"), "br" + ) as f: + d = pickle.load(f) + + for o in d: + right_joint_positions.append(o.right.joint_positions) + left_joint_positions.append(o.left.joint_positions) + + right_gripper_positions.append([o.right.gripper_joint_positions[0]]) + left_gripper_positions.append([o.left.gripper_joint_positions[0]]) + + right_joint_positions = np.asarray(right_joint_positions, dtype=np.float32) + left_joint_positions = np.asarray(left_joint_positions, dtype=np.float32) + + right_gripper_positions = np.asarray(right_gripper_positions, dtype=np.float32) + left_gripper_positions = np.asarray(left_gripper_positions, dtype=np.float32) + + stats = { + "right_joints_mean": right_joint_positions.mean(axis=0), + "right_joints_std": right_joint_positions.std(axis=0), + "left_joints_mean": left_joint_positions.mean(axis=0), + "left_joints_std": left_joint_positions.std(axis=0), + 
"right_gripper_mean": right_gripper_positions.mean(axis=0), + "right_gripper_std": right_gripper_positions.std(axis=0), + "left_gripper_mean": left_gripper_positions.mean(axis=0), + "left_gripper_std": left_gripper_positions.std(axis=0), + } + + return {k: torch.from_numpy(v).to(self._device) for k, v in stats.items()} + + def normalize_z(self, data, mean, std): + return (data - mean) / std + + def unnormalize_z(self, data, mean, std): + return data * std + mean + + def preprocess_qpos(self, observation: dict): + stats = self.train_stats() + + right_qrev = self.normalize_z( + observation["right_joint_positions"][:, 0], + stats["right_joints_mean"], + stats["right_joints_std"], + ) + right_qgripper = self.normalize_z( + observation["right_gripper_joint_positions"][:, 0], + stats["right_gripper_mean"], + stats["right_gripper_std"], + ) + left_qrev = self.normalize_z( + observation["left_joint_positions"][:, 0], + stats["left_joints_mean"], + stats["left_joints_std"], + ) + left_qgripper = self.normalize_z( + observation["left_gripper_joint_positions"][:, 0], + stats["left_gripper_mean"], + stats["left_gripper_std"], + ) + qpos = torch.cat( + [ + right_qrev, + right_qgripper[:, 0].unsqueeze(-1), + left_qrev, + left_qgripper[:, 0].unsqueeze(-1), + ], + dim=-1, + ) + + return qpos + + def preprocess_action(self, replay_sample: dict): + stats = self.train_stats() + + right_qrev = self.normalize_z( + replay_sample["right_prev_joint_positions"][:, 0], + stats["right_joints_mean"], + stats["right_joints_std"], + ) + right_qgripper = self.normalize_z( + replay_sample["right_prev_gripper_joint_positions"][:, 0], + stats["right_gripper_mean"], + stats["right_gripper_std"], + ) + left_qrev = self.normalize_z( + replay_sample["left_prev_joint_positions"][:, 0], + stats["left_joints_mean"], + stats["left_joints_std"], + ) + left_qgripper = self.normalize_z( + replay_sample["left_prev_gripper_joint_positions"][:, 0], + stats["left_gripper_mean"], + stats["left_gripper_std"], + ) + 
qpos = torch.cat( + [ + right_qrev, + right_qgripper[:, 0].unsqueeze(-1), + left_qrev, + left_qgripper[:, 0].unsqueeze(-1), + ], + dim=-1, + ) + + right_action_rev = self.normalize_z( + replay_sample["right_next_joint_positions"], + stats["right_joints_mean"], + stats["right_joints_std"], + ) + right_action_gripper = self.normalize_z( + replay_sample["right_next_gripper_joint_positions"], + stats["right_gripper_mean"], + stats["right_gripper_std"], + ) + left_action_rev = self.normalize_z( + replay_sample["left_next_joint_positions"], + stats["left_joints_mean"], + stats["left_joints_std"], + ) + left_action_gripper = self.normalize_z( + replay_sample["left_next_gripper_joint_positions"], + stats["left_gripper_mean"], + stats["left_gripper_std"], + ) + action_seq = torch.cat( + [ + right_action_rev, + right_action_gripper[:, :, 0].unsqueeze(-1), + left_action_rev, + left_action_gripper[:, :, 0].unsqueeze(-1), + ], + dim=-1, + ) + + return qpos, action_seq + + def preprocess_images(self, replay_sample: dict): + stacked_rgb = [] + stacked_point_cloud = [] + + for camera in self._camera_names: + rgb = replay_sample["%s_rgb" % camera] + rgb = rgb if rgb.dim() == 4 else rgb[:, 0] + stacked_rgb.append(rgb) + + point_cloud = replay_sample["%s_point_cloud" % camera] + point_cloud = point_cloud if point_cloud.dim() == 4 else point_cloud[:, 0] + stacked_point_cloud.append(point_cloud) + + stacked_rgb = torch.stack(stacked_rgb, dim=1) + stacked_point_cloud = torch.stack(stacked_point_cloud, dim=1) + + return stacked_rgb, stacked_point_cloud + + def update(self, step: int, replay_sample: dict) -> dict: + lang_goal_emb = replay_sample["lang_goal_emb"] # TODO use language + robot_state = replay_sample["low_dim_state"] + + # preprocess input + qpos, action_seq = self.preprocess_action(replay_sample) + stacked_rgb, stacked_point_cloud = self.preprocess_images(replay_sample) + is_pad = replay_sample["is_pad"].bool() + + # forward pass + loss_dict = self._actor(qpos, stacked_rgb, 
action_seq, is_pad) + + # gradient step + loss = loss_dict["total_losses"] + loss.backward() + self._actor_optimizer.step() + self._actor_optimizer.zero_grad() + + self._summaries = { + "loss": loss_dict["total_losses"], + "l1": loss_dict["l1"], + "right_l1": loss_dict["right_l1"], + "left_l1": loss_dict["left_l1"], + "kl": loss_dict["kl"], + } + + return loss_dict + + def _normalize_quat(self, x): + return x / x.square().sum(dim=1).sqrt().unsqueeze(-1) + + def _normalize_revolute_joints(self, x): + # normalize joint angles + # input ranges from -pi to pi + # out ranges from 0 to 1 + return (x + np.pi) / (2 * np.pi) + + def _unnormalize_revolute_joints(self, x): + # map input with range 0 to 1 to -pi to pi + x = (x - 0.5) * 2.0 * np.pi + x = torch.clamp(x, -np.pi, np.pi) + return x + + def _normalize_gripper_joints(self, x): + gripper_min = 0 + gripper_max = 0.04 + # normalize gripper joint angles between 0 and 1, the input ranges from 0 to 0.04 + return (x - gripper_min) / (gripper_max - gripper_min) + + def _unnormalize_gripper_joints(self, x): + gripper_min = 0 + gripper_max = 0.04 + + x = x * (gripper_max - gripper_min) + gripper_min + x = torch.clamp(x, gripper_min, gripper_max) + return torch.unsqueeze(x, dim=0) + + def act(self, step: int, observation: dict, deterministic=False) -> ActResult: + # lang_goal_tokens = observation.get('lang_goal_tokens', None).long() + # with torch.no_grad(): + # lang_goal_tokens = lang_goal_tokens.to(device=self._device) + # lang_goal_emb, _ = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0]) + # lang_goal_emb = lang_goal_emb.to(device=self._device) + + action_horizon = self._actor.model.num_queries + query_freq = 1 + + stats = self.train_stats() + + if self._timestep % query_freq == 0: + with torch.no_grad(): + # preprocess input + qpos = self.preprocess_qpos(observation) + stacked_rgb, stacked_point_cloud = self.preprocess_images(observation) + + # forward pass + self._all_actions = self._actor( + qpos, 
stacked_rgb, actions=None, is_pad=None + ) + + # temporal aggregation + t = self._timestep + + self._all_time_actions[[t], t : t + action_horizon] = self._all_actions + actions_for_curr_step = self._all_time_actions[:, t] + actions_populated = torch.all(actions_for_curr_step != 0, axis=1) + actions_for_curr_step = actions_for_curr_step[actions_populated] + k = 0.01 + exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) + exp_weights = exp_weights / exp_weights.sum() + exp_weights = torch.from_numpy(exp_weights).to(self._device).unsqueeze(dim=1) + raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True) + raw_action = raw_action[0] + + right_a_rev = self.unnormalize_z( + raw_action[0:7], stats["right_joints_mean"], stats["right_joints_std"] + ) + right_a_gripper = self.unnormalize_z( + raw_action[7], stats["right_gripper_mean"], stats["right_gripper_std"] + ) + + left_a_rev = self.unnormalize_z( + raw_action[8:15], stats["left_joints_mean"], stats["left_joints_std"] + ) + left_a_gripper = self.unnormalize_z( + raw_action[15], stats["left_gripper_mean"], stats["left_gripper_std"] + ) + + raw_action = torch.cat( + [right_a_rev, right_a_gripper, left_a_rev, left_a_gripper], dim=-1 + ) + + self._timestep += 1 + + return ActResult(raw_action.detach().cpu().numpy()) + + def update_summaries(self) -> List[Summary]: + summaries = [] + for n, v in self._summaries.items(): + summaries.append(ScalarSummary("%s/%s" % (NAME, n), v)) + + # for tag, param in self._actor.named_parameters(): + # summaries.append( + # + # summaries.append( + # HistogramSummary('%s/weight/%s' % (NAME, tag), param.data)) + + return summaries + + def act_summaries(self) -> List[Summary]: + return [] + + def load_weights(self, savedir: str): + self._actor.load_state_dict( + torch.load( + os.path.join(savedir, "bc_actor.pt"), map_location=torch.device("cpu") + ) + ) + print("Loaded weights from %s" % savedir) + + def save_weights(self, savedir: str): + 
class ACTPolicy(nn.Module):
    """ACT policy wrapper around the model built by ``build_ACT_model_and_optimizer``.

    Called with ``actions`` it returns a dict of training losses; called
    without, it returns the action chunk predicted by the model.
    """

    def __init__(self, args):
        super().__init__()
        # ``model`` is the CVAE decoder; its optimizer is built alongside it.
        self.model, self.optimizer = build_ACT_model_and_optimizer(args)
        self.kl_weight = args.kl_weight
        print(f"KL Weight {self.kl_weight}")

    @staticmethod
    def _masked_l1(prediction, target, keep):
        """Mean of the element-wise L1 error after multiplying by ``keep``."""
        return (F.l1_loss(prediction, target, reduction="none") * keep).mean()

    def forward(self, qpos, image, actions=None, is_pad=None):
        """Compute training losses, or predict actions when ``actions`` is None.

        Returns:
            Training: dict with ``right_l1``, ``left_l1``, ``l1``,
            ``gripper_l1``, ``kl`` and ``total_losses``
            (``l1 + kl * kl_weight``).
            Inference: the model's predicted actions.
        """
        env_state = None

        if actions is None:
            # Inference: no target actions, so the model samples from the prior.
            a_hat, _, (_, _) = self.model(qpos, image, env_state)
            return a_hat

        # Training: truncate targets and padding mask to the query horizon.
        actions = actions[:, : self.model.num_queries]
        is_pad = is_pad[:, : self.model.num_queries]

        a_hat, _, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
        total_kld, _, _ = kl_divergence(mu, logvar)

        # NOTE(review): the "joint" slices are 8 wide, so they include the
        # gripper column (7 / 15) that is also scored separately below;
        # only the 8-wide terms enter ``total_losses`` -- confirm intended.
        right_l1 = self._masked_l1(
            a_hat[:, :, 0:8], actions[:, :, 0:8], ~is_pad.unsqueeze(-1)
        )
        left_l1 = self._masked_l1(
            a_hat[:, :, 8:16], actions[:, :, 8:16], ~is_pad.unsqueeze(-1)
        )
        right_gripper_l1 = self._masked_l1(a_hat[:, :, 7], actions[:, :, 7], ~is_pad)
        left_gripper_l1 = self._masked_l1(a_hat[:, :, 15], actions[:, :, 15], ~is_pad)

        loss_dict = {
            "right_l1": right_l1,
            "left_l1": left_l1,
            "l1": right_l1 + left_l1,
            "gripper_l1": right_gripper_l1 + left_gripper_l1,
            "kl": total_kld[0],
        }
        loss_dict["total_losses"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
        return loss_dict

    def configure_optimizers(self):
        """Return the optimizer created together with the model."""
        return self.optimizer


class CNNMLPPolicy(nn.Module):
    """CNN+MLP baseline policy trained with an MSE loss on the first action."""

    def __init__(self, args):
        super().__init__()
        # ``model`` is the decoder network.
        self.model, self.optimizer = build_CNNMLP_model_and_optimizer(args)

    def forward(self, qpos, image, actions=None, is_pad=None):
        """Return ``{"mse", "loss"}`` when training, else the predicted action."""
        env_state = None  # TODO

        if actions is None:
            # Inference: no target action is supplied.
            return self.model(qpos, image, env_state)

        # Training: only the first action of the sequence is regressed.
        actions = actions[:, 0]
        a_hat = self.model(qpos, image, env_state, actions)
        mse = F.mse_loss(actions, a_hat)
        return {"mse": mse, "loss": mse}

    def configure_optimizers(self):
        """Return the optimizer created together with the model."""
        return self.optimizer
def _build_adamw(model, args):
    """Create an AdamW optimizer with a separate learning rate for the backbone.

    Trainable parameters whose name contains ``"backbone"`` use
    ``args.lr_backbone``; all other trainable parameters use ``args.lr``.
    Frozen parameters (``requires_grad == False``) are left out.
    """
    head_params, backbone_params = [], []
    # Single pass over the parameters; relative order within each group
    # is the order ``named_parameters()`` yields them.
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "backbone" in name:
            backbone_params.append(param)
        else:
            head_params.append(param)

    param_dicts = [
        {"params": head_params},
        {"params": backbone_params, "lr": args.lr_backbone},
    ]
    return torch.optim.AdamW(
        param_dicts, lr=args.lr, weight_decay=args.weight_decay
    )


def build_ACT_model_and_optimizer(args):
    """Build the ACT model and its AdamW optimizer.

    Args:
        args: config exposing ``lr``, ``lr_backbone`` and ``weight_decay``
            plus whatever ``build_ACT_model`` reads.

    Returns:
        ``(model, optimizer)``.
    """
    model = build_ACT_model(args)
    return model, _build_adamw(model, args)


def build_CNNMLP_model_and_optimizer(args):
    """Build the CNN+MLP model and its AdamW optimizer.

    Args:
        args: config exposing ``lr``, ``lr_backbone`` and ``weight_decay``
            plus whatever ``build_CNNMLP_model`` reads.

    Returns:
        ``(model, optimizer)``.
    """
    model = build_CNNMLP_model(args)
    return model, _build_adamw(model, args)
new file mode 100644 index 0000000000000000000000000000000000000000..168f9979a4623806934b0ff1102ac166704e7dec --- /dev/null +++ b/external/peract_bimanual/agents/act_bc_lang/detr/util/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/external/peract_bimanual/agents/act_bc_lang/launch_utils.py b/external/peract_bimanual/agents/act_bc_lang/launch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cabe83734a5e542ecb38ea925d130a8814b333fb --- /dev/null +++ b/external/peract_bimanual/agents/act_bc_lang/launch_utils.py @@ -0,0 +1,456 @@ +# Adapted from ARM +# Source: https://github.com/stepjam/ARM +# License: https://github.com/stepjam/ARM/LICENSE + +import logging +from typing import List + +import numpy as np +from omegaconf import DictConfig +from rlbench.backend.observation import Observation +from rlbench.observation_config import ObservationConfig +import rlbench.utils as rlbench_utils +from rlbench.demo import Demo +from yarr.replay_buffer.prioritized_replay_buffer import ( + PrioritizedReplayBuffer, + ObservationElement, +) +from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer +from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer +from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer + +from helpers import utils +from helpers import observation_utils +from agents.act_bc_lang.act_bc_lang_agent import ActBCLangAgent +from helpers.custom_rlbench_env import CustomRLBenchEnv +from helpers.preprocess_agent import PreprocessAgent +from agents.act_bc_lang.act_policy import ACTPolicy, CNNMLPPolicy + +import torch +from torch.multiprocessing import Process, Value, Manager +from helpers.clip.core.clip import build_model, load_clip, tokenize + +LOW_DIM_SIZE = 8 + + +def create_replay( + batch_size: int, + timesteps: int, + prioritisation: bool, + task_uniform: bool, + save_dir: str, + cameras: list, + image_size=[128, 
128], + replay_size=3e5, + prev_action_horizon: int = 1, + next_action_horizon: int = 1, +): + lang_feat_dim = 1024 + + # low_dim_state + observation_elements = [] + observation_elements.append( + ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32) + ) + + # action sequences + action_seq_sizes = { + "right_prev_joint_positions": 7, + "right_prev_gripper_joint_positions": 2, + "right_prev_gripper_poses": 7, + "right_next_joint_positions": 7, + "right_next_gripper_joint_positions": 2, + "right_next_gripper_poses": 7, + "left_prev_joint_positions": 7, + "left_prev_gripper_joint_positions": 2, + "left_prev_gripper_poses": 7, + "left_next_joint_positions": 7, + "left_next_gripper_joint_positions": 2, + "left_next_gripper_poses": 7, + } + + for seq_name, seq_size in action_seq_sizes.items(): + horizon = prev_action_horizon if "prev" in seq_name else next_action_horizon + observation_elements.append( + ObservationElement( + seq_name, + ( + horizon, + seq_size, + ), + np.float32, + ) + ) + + # action is_pad + observation_elements.append( + ObservationElement("is_pad", (next_action_horizon,), np.int32) + ) + + # rgb, depth, point cloud, intrinsics, extrinsics + for cname in cameras: + observation_elements.append( + ObservationElement( + "%s_rgb" % cname, + ( + 3, + *image_size, + ), + np.float32, + ) + ) + observation_elements.append( + ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32) + ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames + observation_elements.append( + ObservationElement( + "%s_camera_extrinsics" % cname, + ( + 4, + 4, + ), + np.float32, + ) + ) + observation_elements.append( + ObservationElement( + "%s_camera_intrinsics" % cname, + ( + 3, + 3, + ), + np.float32, + ) + ) + + observation_elements.extend( + [ + ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32), + ReplayElement("task", (), str), + ReplayElement( + "lang_goal", (1,), object + ), # language goal string for 
debugging and visualization + ] + ) + + extra_replay_elements = [ + ReplayElement("demo", (), bool), + ] + + replay_buffer = TaskUniformReplayBuffer( + save_dir=save_dir, + batch_size=batch_size, + timesteps=timesteps, + replay_capacity=int(replay_size), + action_shape=(8 * 2,), + action_dtype=np.float32, + reward_shape=(), + reward_dtype=np.float32, + update_horizon=1, + observation_elements=observation_elements, + extra_replay_elements=extra_replay_elements, + ) + return replay_buffer + + +def _get_action(obs_tp1: Observation): + quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:]) + if quat[-1] < 0: + quat = -quat + return np.concatenate( + [obs_tp1.gripper_pose[:3], quat, [float(obs_tp1.gripper_open)]] + ) + + +def _get_action_seq( + demo: Demo, + timestep: int, + prev_action_horizon: int, + next_action_horizon: int, + robot_name: str, +): + action_seq = { + "right_prev_joint_positions": [], + "right_prev_gripper_joint_positions": [], + "right_prev_gripper_poses": [], + "left_prev_joint_positions": [], + "left_prev_gripper_joint_positions": [], + "left_prev_gripper_poses": [], + "right_next_joint_positions": [], + "right_next_gripper_joint_positions": [], + "right_next_gripper_poses": [], + "left_next_joint_positions": [], + "left_next_gripper_joint_positions": [], + "left_next_gripper_poses": [], + "is_pad": [], + } + + for prev_t in list(reversed(range(prev_action_horizon))): + t = timestep - prev_t + t = max(0, t) + obs = demo[t] + + action_seq["right_prev_joint_positions"].append(obs.right.joint_positions) + action_seq["right_prev_gripper_joint_positions"].append( + obs.right.gripper_joint_positions + ) + action_seq["right_prev_gripper_poses"].append(obs.right.gripper_pose) + action_seq["left_prev_joint_positions"].append(obs.left.joint_positions) + action_seq["left_prev_gripper_joint_positions"].append( + obs.left.gripper_joint_positions + ) + action_seq["left_prev_gripper_poses"].append(obs.left.gripper_pose) + + action_seq["is_pad"] = 
np.zeros(next_action_horizon) + for idx, next_t in enumerate(range(0, next_action_horizon)): + t = timestep + next_t + t = min(t, len(demo) - 1) + obs = demo[t] + + if timestep + next_t > len(demo) - 1: + action_seq["is_pad"][idx] = 1 + + action_seq["right_next_joint_positions"].append(obs.right.joint_positions) + action_seq["right_next_gripper_joint_positions"].append( + obs.right.gripper_joint_positions + ) + action_seq["right_next_gripper_poses"].append(obs.right.gripper_pose) + action_seq["left_next_joint_positions"].append(obs.left.joint_positions) + action_seq["left_next_gripper_joint_positions"].append( + obs.left.gripper_joint_positions + ) + action_seq["left_next_gripper_poses"].append(obs.left.gripper_pose) + + # convert to numpy arrays + return {k: np.array(v) for k, v in action_seq.items()} + + +def _add_keypoints_to_replay( + step: int, + cfg: DictConfig, + task: str, + replay: ReplayBuffer, + inital_obs: Observation, + demo: Demo, + description: str = "", + clip_model=None, + device="cpu", +): + cameras = cfg.rlbench.cameras + robot_name = cfg.method.robot_name + + prev_action = None + obs = inital_obs + all_actions = [] + k = step + k_tp1 = min(k + 1, len(demo) - 1) + obs_tp1 = demo[k_tp1] + + if obs_tp1.is_bimanual and robot_name == "bimanual": + right_action = _get_action(obs_tp1.right) + left_action = _get_action(obs_tp1.left) + action = np.append(right_action, left_action) + elif robot_name == "unimanual": + action = _get_action(obs_tp1) + elif obs_tp1.is_bimanual and robot_name == "right": + action = _get_action(obs_tp1.right) + elif obs_tp1.is_bimanual and robot_name == "left": + action = _get_action(obs_tp1.left) + else: + logging.error("Invalid robot name %s", cfg.method.robot_name) + raise Exception("Invalid robot name.") + + all_actions.append(action) + + terminal = k == len(demo) - 1 + reward = float(terminal) if terminal else 0 + + obs_dict = observation_utils.extract_obs( + obs, + t=k, + prev_action=prev_action, + cameras=cameras, + 
def fill_replay(
    cfg: DictConfig,
    obs_config: ObservationConfig,
    rank: int,
    replay: ReplayBuffer,
    task: str,
    num_demos: int,
    demo_augmentation: bool,
    demo_augmentation_every_n: int,
    cameras: List[str],
    clip_model=None,
    device="cpu",
):
    """Add every step of ``num_demos`` stored demos of ``task`` to ``replay``.

    Each demo is loaded from ``cfg.rlbench.demo_path`` and every step except
    the last is pushed through ``_add_keypoints_to_replay`` with the demo's
    first language description. When ``clip_model`` is None a CLIP RN50
    text encoder is built on ``device``.

    ``demo_augmentation``, ``demo_augmentation_every_n`` and ``cameras`` are
    accepted but not used here.

    Returns:
        list of all actions that were added.
    """
    if clip_model is None:
        model, _ = load_clip("RN50", jit=False, device=device)
        clip_model = build_model(model.state_dict())
        clip_model.to(device)
        del model

    logging.debug("Filling %s replay ..." % task)
    all_actions = []
    for d_idx in range(num_demos):
        # load demo from disk
        stored = rlbench_utils.get_stored_demos(
            amount=1,
            image_paths=False,
            dataset_root=cfg.rlbench.demo_path,
            variation_number=-1,
            task_name=task,
            obs_config=obs_config,
            random_selection=False,
            from_episode_number=d_idx,
        )
        demo = stored[0]
        descs = demo._observations[0].misc["descriptions"]

        if rank == 0:
            logging.info(f"Loading Demo({d_idx})")

        for step in range(len(demo) - 1):
            step_actions = _add_keypoints_to_replay(
                step,
                cfg,
                task,
                replay,
                demo[step],
                demo,
                description=descs[0],
                clip_model=clip_model,
                device=device,
            )
            all_actions.extend(step_actions)

    logging.debug("Replay filled with demos.")
    return all_actions
def create_agent(cfg: DictConfig) -> Agent:
    """Build the agent selected by ``cfg.method.name`` / ``cfg.method.agent_type``.

    ``bimanual`` and ``unimanual`` return the method's agent directly.
    ``leader_follower`` and ``independent`` build one agent per arm and wrap
    them; while doing so ``cfg`` is mutated in place:
    ``cfg.method.robot_name`` ends as ``"bimanual"``,
    ``cfg.framework.checkpoint_name_prefix`` keeps the last per-arm value,
    and for ``leader_follower`` ``cfg.method.low_dim_size`` grows by 8.

    Raises:
        AssertionError: the method is not supported for the agent type.
        Exception: the agent type is not one of the handled values.
    """
    method_name = cfg.method.name
    agent_type = cfg.method.agent_type

    logging.info("Using method %s with type %s", method_name, agent_type)

    assert method_name in supported_agents[agent_type]

    agent_fn = agent_fn_by_name(method_name)

    if agent_type in ("bimanual", "unimanual"):
        return agent_fn(cfg)
    if agent_type not in ("leader_follower", "independent"):
        raise Exception("invalid agent type")

    base_prefix = cfg.framework.checkpoint_name_prefix
    method_tag = method_name.lower()

    def build_arm(robot_name, role):
        # Point cfg at one arm and its own checkpoint prefix, then build.
        cfg.method.robot_name = robot_name
        cfg.framework.checkpoint_name_prefix = f"{base_prefix}_{method_tag}_{role}"
        return agent_fn(cfg)

    if agent_type == "leader_follower":
        leader_agent = build_arm("right", "leader")
        # also add the action size
        cfg.method.low_dim_size = cfg.method.low_dim_size + 8
        follower_agent = build_arm("left", "follower")
        cfg.method.robot_name = "bimanual"
        return LeaderFollowerAgent(leader_agent, follower_agent)

    right_agent = build_arm("right", "right")
    left_agent = build_arm("left", "left")
    cfg.method.robot_name = "bimanual"
    return BimanualAgent(right_agent, left_agent)
def create_replay(
    batch_size: int,
    timesteps: int,
    prioritisation: bool,
    save_dir: str,
    cameras: list,
    env: CustomRLBenchEnv,
):
    """Create the replay buffer used by the ARM agent.

    The buffer stores the environment's observation elements plus one
    ``"<camera>_pixel_coord"`` entry (2 x int32) per camera, 8-D float32
    actions, scalar float32 rewards and a boolean ``demo`` flag.

    Args:
        batch_size: batch size handed to the buffer.
        timesteps: number of timesteps handed to the buffer.
        prioritisation: use ``PrioritizedReplayBuffer`` when truthy,
            otherwise ``UniformReplayBuffer``.
        save_dir: directory handed to the buffer.
        cameras: camera names that get a pixel-coordinate element.
        env: environment providing ``observation_elements``.

    Returns:
        The constructed replay buffer.
    """
    # Work on a copy so the per-camera entries added below are never
    # appended to a list owned by ``env``.
    observation_elements = list(env.observation_elements)
    for cname in cameras:
        observation_elements.append(
            ObservationElement("%s_pixel_coord" % cname, (2,), np.int32)
        )

    replay_class = PrioritizedReplayBuffer if prioritisation else UniformReplayBuffer
    return replay_class(
        save_dir=save_dir,
        batch_size=batch_size,
        timesteps=timesteps,
        replay_capacity=int(1e5),
        action_shape=(8,),
        action_dtype=np.float32,
        reward_shape=(),
        reward_dtype=np.float32,
        update_horizon=1,
        observation_elements=observation_elements,
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        extra_replay_elements=[ReplayElement("demo", (), bool)],
    )
class SharedNet(nn.Module):
    """Shared feature stem that encodes RGB and point-cloud inputs separately.

    Layers are created in :meth:`build`, not in the constructor.
    """

    def __init__(self, activation: str, norm: str = None):
        super().__init__()
        self._activation = activation
        self._norm = norm

    def _make_stem(self):
        # One 3 -> 32 channel conv block (kernel 3, stride 1).
        return nn.Sequential(
            Conv2DBlock(3, 32, 3, 1, activation=self._activation, norm=self._norm),
        )

    def build(self):
        """Instantiate the RGB and point-cloud pre-processing stems."""
        self._rgb_pre = self._make_stem()
        self._pcd_pre = self._make_stem()

    def forward(self, observations):
        """Encode ``observations[0]`` (RGB) and ``observations[1]`` (point cloud).

        Returns the two feature maps concatenated along ``dim=1``.
        """
        rgb_feats = self._rgb_pre(observations[0])
        pcd_feats = self._pcd_pre(observations[1])
        return torch.cat([rgb_feats, pcd_feats], dim=1)
def build(self): + self._convs = nn.Sequential( + Conv2DBlock( + 64 + self._low_dim_size, + 64, + 1, + 1, + activation=self._activation, + norm=self._norm, + ), + Conv2DBlock(64, 64, 3, 1, activation=self._activation, norm=self._norm), + ) + self._fcs = nn.Sequential( + DenseBlock(64, 64, activation=self._activation), + DenseBlock(64, 64, activation=self._activation), + DenseBlock(64, 8 * 2), + ) + self._maxp = nn.AdaptiveMaxPool2d(1) + + def forward(self, observation_feats, low_dim_ins): + low_dim_feats = low_dim_ins + _, _, h, w = observation_feats.shape + low_dim_feats = low_dim_feats.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w) + x = torch.cat([observation_feats, low_dim_feats], dim=1) + x = self._convs(x) + x = self._maxp(x).squeeze(-1).squeeze(-1) + x = self._fcs(x) + return x + + +class CriticNet(nn.Module): + def __init__( + self, activation: str, low_dim_size: int, norm: str = None, q_conf: bool = True + ): + super(CriticNet, self).__init__() + self._activation = activation + self._low_dim_size = low_dim_size + self._norm = norm + self._q_conf = q_conf + + def build(self): + self._convs = nn.Sequential( + Conv2DBlock( + 64 + self._low_dim_size, 128, 3, 1, self._norm, self._activation + ), + Conv2DBlock(128, 128, 3, 1, self._norm, self._activation), + Conv2DBlock(128, 128, 3, 1, self._norm, self._activation), + Conv2DBlock(128, 128, 3, 1, self._norm, self._activation), + ) + if self._q_conf: + self._final_conv = Conv2DBlock(128, 2, 3, 1) + else: + self._maxp = nn.AdaptiveMaxPool2d(1) + self._fcs = nn.Sequential( + DenseBlock(128, 64, activation=self._activation), + DenseBlock(64, 1), + ) + + def forward(self, observation_feats, low_dim_ins): + low_dim_feats = low_dim_ins + _, _, h, w = observation_feats.shape + low_dim_feats = low_dim_feats.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w) + x = torch.cat([observation_feats, low_dim_feats], dim=1) + x = self._convs(x) + if self._q_conf: + x = self._final_conv(x) + x[:, 1] = torch.sigmoid(x[:, 1]) + else: + x 
class Qattention2DNet(nn.Module):
    """U-Net style pixel-wise Q network placed on top of a siamese encoder.

    The network is assembled in :meth:`build` (which may be called exactly
    once): a stack of down-sampling conv blocks, a mirrored stack of
    up-sampling blocks with optional skip connections, and a final 3x3 conv
    producing ``output_channels`` maps.
    """

    def __init__(
        self,
        siamese_net: SiameseNet,
        filters: List[int],
        kernel_sizes: List[int],
        strides: List[int],
        low_dim_state_len: int,
        norm: str = None,
        activation: str = "relu",
        output_channels: int = 1,
        skip_connections: bool = True,
    ):
        super(Qattention2DNet, self).__init__()
        # Private copy so that several networks can be built from one template.
        self._siamese_net = copy.deepcopy(siamese_net)
        self._input_channels = self._siamese_net.output_channels + low_dim_state_len
        self._filters = filters
        self._kernel_sizes = kernel_sizes
        self._strides = strides
        self._norm = norm
        self._activation = activation
        self._output_channels = output_channels
        self._skip_connections = skip_connections
        self._build_calls = 0

    def build(self):
        """Create all sub-modules.

        Raises:
            RuntimeError: if called more than once on the same instance.
        """
        self._build_calls += 1
        if self._build_calls != 1:
            raise RuntimeError("Build needs to be called once.")
        self._siamese_net.build()

        down_blocks = []
        ch = self._input_channels
        for filt, ksize, stride in zip(
            self._filters, self._kernel_sizes, self._strides
        ):
            down_blocks.append(
                Conv2DBlock(
                    ch,
                    filt,
                    ksize,
                    stride,
                    self._norm,
                    self._activation,
                    padding_mode="replicate",
                )
            )
            ch = filt
        self._down = nn.ModuleList(down_blocks)

        reverse_conv_data = list(zip(self._filters, self._kernel_sizes, self._strides))
        reverse_conv_data.reverse()

        up_blocks = []
        for i, (filt, ksize, stride) in enumerate(reverse_conv_data):
            if i > 0 and self._skip_connections:
                # forward() concatenates layers_for_skip[i], i.e. the output of
                # the down block that produced reverse_conv_data[i][0] channels.
                # The previous index (-i - 1) picked filters[i] instead, which
                # is only correct when every entry of `filters` is identical.
                ch += reverse_conv_data[i][0]
            up_blocks.append(
                Conv2DUpsampleBlock(
                    ch, filt, ksize, stride, self._norm, self._activation
                )
            )
            ch = filt
        self._up = nn.ModuleList(up_blocks)

        self._final_conv = Conv2DBlock(
            ch, self._output_channels, 3, 1, padding_mode="replicate"
        )

    def forward(self, observations, low_dim_ins):
        """Run the encoder/decoder and return the final conv output.

        ``low_dim_ins`` may be ``None``; otherwise it is tiled over the spatial
        dimensions and concatenated to the siamese features. The intermediate
        activations are kept on ``self.downs``, ``self.ups`` and ``self.latent``
        (overwritten on every call).
        """
        x = self._siamese_net(observations)
        _, _, h, w = x.shape
        if low_dim_ins is not None:
            low_dim_latents = low_dim_ins.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w)
            x = torch.cat([x, low_dim_latents], dim=1)
        self.ups = []
        self.downs = []
        layers_for_skip = []
        for layer in self._down:
            x = layer(x)
            layers_for_skip.append(x)
            self.downs.append(x)
        self.latent = x
        layers_for_skip.reverse()
        for i, layer in enumerate(self._up):
            if i > 0 and self._skip_connections:
                # Skip connections. Skip the first up layer.
                x = torch.cat([layers_for_skip[i], x], 1)
            x = layer(x)
            self.ups.append(x)
        x = self._final_conv(x)
        return x
shared_network=shared_net, + critic_network=critic_net, + actor_network=actor_net, + action_min_max=action_min_max, + camera_name=camera_name, + alpha=alpha, + alpha_lr=alpha_lr, + alpha_auto_tune=alpha_auto_tune, + critic_lr=critic_lr, + actor_lr=actor_lr, + critic_weight_decay=next_best_pose_critic_weight_decay, + actor_weight_decay=next_best_pose_actor_weight_decay, + crop_shape=crop_shape, + critic_tau=next_best_pose_tau, + critic_grad_clip=next_best_pose_critic_grad_clip, + actor_grad_clip=next_best_pose_actor_grad_clip, + q_conf=q_conf, + ) + + return PreprocessAgent(pose_agent=next_best_pose_agent) diff --git a/external/peract_bimanual/agents/arm/next_best_pose_agent.py b/external/peract_bimanual/agents/arm/next_best_pose_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..e2f9517bcd4032ed351e142a855bf9eef812a581 --- /dev/null +++ b/external/peract_bimanual/agents/arm/next_best_pose_agent.py @@ -0,0 +1,526 @@ +import copy +import logging +import os +from typing import List + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from yarr.agents.agent import ( + Agent, + Summary, + ActResult, + ScalarSummary, + ImageSummary, + HistogramSummary, +) + +from helpers import utils +from helpers.utils import stack_on_channel +from agents.arm.qattention_agent import QAttentionAgent + +NAME = "NextBestPoseAgent" +LOG_STD_MAX = 4 +LOG_STD_MIN = -40 +REPLAY_ALPHA = 0.7 +REPLAY_BETA = 0.5 + + +class QFunction(nn.Module): + def __init__(self, critic: nn.Module, shared: nn.Module, q_conf: bool): + super(QFunction, self).__init__() + self._q_conf = q_conf + self._q1 = copy.deepcopy(critic) + self._q2 = copy.deepcopy(critic) + self.shared = copy.deepcopy(shared) + self._q1.build() + self._q2.build() + self.shared.build() + + def forward(self, observations, robot_state, action): + obs_feats = self.shared(observations) + combined = torch.cat([robot_state, action.float()], dim=1) + q1 = self._q1(obs_feats, 
class Actor(nn.Module):
    """Tanh-squashed Gaussian policy wrapped around an actor network.

    The wrapped network is deep-copied and built here, so the instance passed
    in is left untouched. ``action_min_max`` is indexed as ``[0]`` (lower
    bounds) and ``[1]`` (upper bounds) when rescaling actions.
    """

    def __init__(self, actor_network: nn.Module, action_min_max: torch.Tensor):
        super(Actor, self).__init__()
        self._action_min_max = action_min_max
        self._actor_network = copy.deepcopy(actor_network)
        self._actor_network.build()

    def _rescale_actions(self, x):
        """Map values from ``[-1, 1]`` linearly onto ``[min, max]``."""
        return (
            0.5 * (x + 1.0) * (self._action_min_max[1] - self._action_min_max[0])
            + self._action_min_max[0]
        )

    def _normalize(self, x):
        """Scale each row of ``x`` to unit L2 norm (along ``dim=1``).

        The norm is clamped to a tiny positive value so that an all-zero row
        yields zeros instead of NaN from a 0/0 division.
        """
        norm = x.square().sum(dim=1).sqrt().unsqueeze(-1)
        return x / norm.clamp_min(1e-12)

    def _gaussian_logprob(self, noise, log_std):
        """Log-density of a diagonal Gaussian sample expressed via its noise."""
        residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True)
        return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1)

    def forward(self, observations, robot_state):
        """Sample an action and return ``(mu, pi, log_pi, log_std)``.

        ``mu`` is the deterministic action, ``pi`` the sampled one; both are
        tanh-squashed, rescaled to the action bounds and have columns 3:7
        re-normalised. ``log_pi`` includes the tanh change-of-variables
        correction.
        """
        mu_and_logstd = self._actor_network(observations, robot_state)
        # The network output is split into two chunks of 8 columns each.
        mu, log_std = torch.split(mu_and_logstd, 8, dim=1)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)

        std = log_std.exp()
        noise = torch.randn_like(mu)
        pi = mu + noise * std
        log_pi = self._gaussian_logprob(noise, log_std)
        mu = torch.tanh(mu)
        pi = torch.tanh(pi)
        # relu + epsilon keeps the log argument strictly positive.
        log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)

        pi = self._rescale_actions(pi)
        mu = self._rescale_actions(mu)

        pi = torch.cat([pi[:, :3], self._normalize(pi[:, 3:7]), pi[:, 7:]], dim=-1)
        mu = torch.cat([mu[:, :3], self._normalize(mu[:, 3:7]), mu[:, 7:]], dim=-1)
        return mu, pi, log_pi, log_std
actor_network: nn.Module, + action_min_max: tuple, + camera_name: str, + alpha: float = 0.2, + alpha_auto_tune: bool = True, + alpha_lr: float = 0.001, + critic_lr: float = 0.01, + actor_lr: float = 0.01, + critic_weight_decay: float = 1e-5, + actor_weight_decay: float = 1e-5, + crop_shape: tuple = (16, 16), + critic_tau: float = 0.005, + critic_grad_clip: float = 20.0, + actor_grad_clip: float = 20.0, + gamma: float = 0.99, + nstep: int = 1, + q_conf: bool = True, + ): + self._qattention_agent = qattention_agent + self._alpha = alpha + self._alpha_auto_tune = alpha_auto_tune + self._crop_shape = crop_shape + self._critic_tau = critic_tau + self._critic_grad_clip = critic_grad_clip + self._actor_grad_clip = actor_grad_clip + self._camera_name = camera_name + self._gamma = gamma + self._nstep = nstep + self._target_entropy = -8 + self._shared_network = shared_network + self._critic_network = critic_network + self._actor_network = actor_network + self._action_min_max = action_min_max + self._critic_lr = critic_lr + self._actor_lr = actor_lr + self._alpha_lr = alpha_lr + self._critic_weight_decay = critic_weight_decay + self._actor_weight_decay = actor_weight_decay + self._q_conf = q_conf + self._crop_augmentation = False + + def build(self, training: bool, device: torch.device = None): + if device is None: + device = torch.device("cpu") + self._qattention_agent.build(training, device) + action_min_max = torch.tensor(self._action_min_max).to(device) + self._actor = ( + Actor(self._actor_network, action_min_max).to(device).train(training) + ) + + self._action_min_max_t = torch.tensor(self._action_min_max).to(device) + + grid_for_crop = ( + torch.arange(0, self._crop_shape[0], device=device) + .unsqueeze(0) + .repeat(self._crop_shape[0], 1) + .unsqueeze(-1) + ) + self._grid_for_crop = torch.cat( + [grid_for_crop.transpose(1, 0), grid_for_crop], dim=2 + ).unsqueeze(0) + self._q = ( + QFunction(self._critic_network, self._shared_network, self._q_conf) + .to(device) + 
def _extract_crop(self, pixel_action, observation):
    """Cut a crop of ``self._crop_shape`` around ``pixel_action``.

    ``observation`` is first passed through ``stack_on_channel``; the crop is
    then gathered with nearest-neighbour ``grid_sample``. The crop's top-left
    corner is clamped so the window stays inside the image.

    NOTE(review): the size of the last dimension is used for both axes, so the
    image is presumably square - confirm against callers.
    """
    # Pixel action will now be (B, 2)
    observation = stack_on_channel(observation)
    h = observation.shape[-1]
    top_left_corner = torch.clamp(
        pixel_action - self._crop_shape[0] // 2, 0, h - self._crop_shape[1]
    )
    grid = self._grid_for_crop + top_left_corner.unsqueeze(1).unsqueeze(1)
    # With align_corners=True, -1 and +1 address the centres of pixel 0 and
    # pixel h - 1, so pixel index i corresponds to 2 * i / (h - 1) - 1.
    # Dividing by h (as before) maps i to i * (h - 1) / h, which duplicates or
    # shifts rows/columns by one pixel in the upper half of the image.
    grid = ((grid / float(max(h - 1, 1))) * 2.0) - 1.0
    # grid_sample expects (x, y) ordering; the stored coordinates are (y, x).
    grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1)
    crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True)
    return crop
pixel_action, pixel_action_tp1): + observations = [ + self._extract_crop( + pixel_action, replay_sample["%s_rgb" % self._camera_name] + ), + self._extract_crop( + pixel_action, replay_sample["%s_point_cloud" % self._camera_name] + ), + ] + tp1_observations = [ + self._extract_crop( + pixel_action_tp1, replay_sample["%s_rgb_tp1" % self._camera_name] + ), + self._extract_crop( + pixel_action_tp1, + replay_sample["%s_point_cloud_tp1" % self._camera_name], + ), + ] + return observations, tp1_observations + + def _clip_action(self, a): + return torch.min( + torch.max(a, self._action_min_max_t[0:1]), self._action_min_max_t[1:2] + ) + + def _update_critic(self, replay_sample: dict) -> None: + action = replay_sample["action"] + reward = replay_sample["reward"] + + robot_state = stack_on_channel(replay_sample["low_dim_state"][:, -1:]) + robot_state_tp1 = stack_on_channel(replay_sample["low_dim_state_tp1"][:, -1:]) + + # Get last of time stack and first of plan stack + pixel_action = replay_sample["%s_pixel_coord" % self._camera_name][:, -1] + pixel_action_tp1 = replay_sample["%s_pixel_coord_tp1" % self._camera_name][ + :, -1 + ] + + if self._crop_augmentation: + shifted = ( + (torch.rand_like(pixel_action.float()) * self._crop_shape_t).int() + - self._crop_shape_t // 2 + ) * replay_sample["demo"].int().unsqueeze(1) + pixel_action += shifted + pixel_action_tp1 += shifted + + # Don't want timeouts to be classed as terminals + terminal = replay_sample["terminal"].float() - replay_sample["timeout"].float() + + observations, tp1_observations = self._preprocess_inputs( + replay_sample, pixel_action, pixel_action_tp1 + ) + + q1, q2, _, _ = self._q(observations, robot_state, action) + + with torch.no_grad(): + obs_feats = self._q.shared(tp1_observations) + _, pi_tp1, logp_pi_tp1, _ = self._actor(obs_feats, robot_state_tp1) + + q1_pi_tp1_targ, q2_pi_tp1_targ, _, _ = self._q_target( + tp1_observations, robot_state_tp1, pi_tp1 + ) + + min_q_pi_targ = torch.min(q1_pi_tp1_targ[:, 0], 
q2_pi_tp1_targ[:, 0]) + next_value = min_q_pi_targ - self.alpha * logp_pi_tp1 + q_backup = ( + reward.unsqueeze(-1) + + (self._gamma**self._nstep) + * (1.0 - terminal.unsqueeze(-1)) + * next_value + ) + + loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA) + + self._critic_summaries = {} + if self._q_conf: + w = 1.0 + q1_delta = ( + F.smooth_l1_loss(q1[:, 0], q_backup, reduction="none") * q1[:, 1] + - w * q1[:, 1].log() + ) + q2_delta = ( + F.smooth_l1_loss(q2[:, 0], q_backup, reduction="none") * q2[:, 1] + - w * q2[:, 1].log() + ) + self._critic_summaries = { + "q_conf_loss": -(w * q1[:, 1].log()).mean(), + "q_conf_mean": q1[:, 1].mean(), + } + else: + q1_delta = F.smooth_l1_loss(q1[:, 0], q_backup, reduction="none") + q2_delta = F.smooth_l1_loss(q2[:, 0], q_backup, reduction="none") + + q1_delta, q2_delta = q1_delta.mean(1), q2_delta.mean(1) + q1_bellman_loss = (q1_delta * loss_weights).mean() + q2_bellman_loss = (q2_delta * loss_weights).mean() + + critic_loss = q1_bellman_loss + q2_bellman_loss + + self._critic_summaries.update( + { + "q1_bellman_loss": q1_bellman_loss, + "q2_bellman_loss": q2_bellman_loss, + "q1_mean": q1[:, 0].mean().item(), + "q2_mean": q2[:, 0].mean().item(), + "alpha": self.alpha, + } + ) + self._crop_summary = observations + self._crop_summary_tp1 = tp1_observations + + new_pri = torch.sqrt((q1_delta + q2_delta) / 2.0 + 1e-10) + self._new_priority = (new_pri / torch.max(new_pri)).detach() + self._grad_step( + critic_loss, + self._critic_optimizer, + self._q.parameters(), + self._critic_grad_clip, + ) + + def _update_actor(self, replay_sample: dict) -> None: + robot_state = stack_on_channel(replay_sample["low_dim_state"][:, -1:]) + pixel_action = replay_sample["%s_pixel_coord" % self._camera_name][:, -1] + + if self._crop_augmentation: + shifted = ( + (torch.rand_like(pixel_action.float()) * self._crop_shape_t).int() + - self._crop_shape_t // 2 + ) * replay_sample["demo"].int().unsqueeze(1) + pixel_action += shifted + + # Crop 
the observations + observations = [ + self._extract_crop( + pixel_action, replay_sample["%s_rgb" % self._camera_name] + ), + self._extract_crop( + pixel_action, replay_sample["%s_point_cloud" % self._camera_name] + ), + ] + + with torch.no_grad(): + obs_feats = self._q.shared(observations) + + mu, pi, self._logp_pi, log_scale_diag = self._actor(obs_feats, robot_state) + + _, _, q1_pi, q2_pi = self._q(observations, robot_state, pi) + + min_q_pi = torch.min(q1_pi, q2_pi)[:, 0] + pi_loss = self.alpha * self._logp_pi - min_q_pi + + loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA) + pi_loss = (pi_loss * loss_weights).mean() + + self._actor_summaries = { + "pi/loss": pi_loss, + "pi/q1_pi_mean": q1_pi.mean(), + "pi/q2_pi_mean": q2_pi.mean(), + "pi/mu": mu.mean(), + "pi/pi": pi.mean(), + "pi/log_pi": self._logp_pi.mean(), + "pi/log_scale_diag": log_scale_diag.mean(), + } + self._grad_step( + pi_loss, + self._actor_optimizer, + self._actor.parameters(), + self._actor_grad_clip, + ) + + def _update_alpha(self): + alpha_loss = -( + self.alpha * (self._logp_pi + self._target_entropy).detach() + ).mean() + self._grad_step(alpha_loss, self._alpha_optimizer) + + def _grad_step(self, loss, opt, model_params=None, clip=None): + opt.zero_grad() + loss.backward() + if clip is not None and model_params is not None: + nn.utils.clip_grad_value_(model_params, clip) + opt.step() + + def update(self, step: int, replay_sample: dict) -> dict: + info = self._qattention_agent.update(step, replay_sample) + + self._update_critic(replay_sample) + + # Freeze critic so you don't waste computational effort + # computing gradients for them during the policy learning step. + for p in self._q.parameters(): + p.requires_grad = False + + self._update_actor(replay_sample) + if self._alpha_auto_tune: + self._update_alpha() + + # UnFreeze critic. 
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
    """Pick a pixel with the Q-attention agent, then predict the next pose.

    The pixel chosen by the attention agent is used to crop the RGB and point
    cloud images; the actor is evaluated on the shared features of those
    crops. The returned ``ActResult`` is the attention agent's result with its
    ``action`` replaced by the actor output (``mu`` when ``deterministic``,
    otherwise the sample ``pi``) and the RGB crop added to ``info``.
    """
    with torch.no_grad():
        result = self._qattention_agent.act(step, observation, deterministic)
        rgb_crop = self._extract_crop(
            result.action.unsqueeze(0),
            observation["%s_rgb" % self._camera_name],
        )
        pcd_crop = self._extract_crop(
            result.action.unsqueeze(0),
            observation["%s_point_cloud" % self._camera_name],
        )
        crops = [rgb_crop, pcd_crop]
        # Kept for act_summaries().
        self._act_crop_summaries = crops

        # Only the most recent low-dim state of the time stack is used.
        latest_state = stack_on_channel(observation["low_dim_state"][:, -1:])
        features = self._q.shared(crops)
        mu, pi, _, _ = self._actor(features, latest_state)

        chosen = mu if deterministic else pi
        result.action = chosen[0]
        result.info.update({"rgb_crop": crops[0]})
        return result
"%s/crops/act/rgb" % NAME, (self._act_crop_summaries[0] + 1.0) / 2.0 + ), + ImageSummary( + "%s/crops/act/point_cloud" % NAME, self._act_crop_summaries[1] + ), + ] + return summaries + self._qattention_agent.act_summaries() + + def load_weights(self, savedir: str): + self._qattention_agent.load_weights(savedir) + self._actor.load_state_dict( + torch.load( + os.path.join(savedir, "pose_actor.pt"), map_location=torch.device("cpu") + ) + ) + self._q.load_state_dict( + torch.load( + os.path.join(savedir, "pose_q.pt"), map_location=torch.device("cpu") + ) + ) + + def save_weights(self, savedir: str): + self._qattention_agent.save_weights(savedir) + torch.save(self._actor.state_dict(), os.path.join(savedir, "pose_actor.pt")) + torch.save(self._q.state_dict(), os.path.join(savedir, "pose_q.pt")) diff --git a/external/peract_bimanual/agents/arm/qattention_agent.py b/external/peract_bimanual/agents/arm/qattention_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..f0deb576907a2cfab654ec0c18a0c7f2f0c5c21e --- /dev/null +++ b/external/peract_bimanual/agents/arm/qattention_agent.py @@ -0,0 +1,247 @@ +import copy +import logging +import os +from typing import List + +import PIL +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms + +from yarr.agents.agent import ( + Agent, + ActResult, + ScalarSummary, + HistogramSummary, + ImageSummary, + Summary, +) + +from helpers import utils +from helpers.utils import stack_on_channel + +NAME = "QAttentionAgent" +REPLAY_BETA = 1.0 + + +class QFunction(nn.Module): + def __init__(self, unet: nn.Module): + super(QFunction, self).__init__() + self._qnet = copy.deepcopy(unet) + self._qnet2 = copy.deepcopy(unet) + self._qnet.build() + self._qnet2.build() + + def _argmax_2d(self, tensor): + t_shape = tensor.shape + m = tensor.view(t_shape[0], -1).argmax(1).view(-1, 1) + indices = torch.cat((m // t_shape[-1], m % t_shape[-1]), dim=1) + return indices + + def 
forward(self, x, robot_state): + q = self._qnet(x, robot_state)[:, 0] + q2 = self._qnet2(x, robot_state)[:, 0] + coords = self._argmax_2d(torch.min(q, q2)) + return q, q2, coords + + +class QAttentionAgent(Agent): + def __init__( + self, + pixel_unet: nn.Module, + camera_name: str, + tau: float = 0.005, + gamma: float = 0.99, + nstep: int = 1, + lr: float = 0.0001, + weight_decay: float = 1e-5, + lambda_qreg: float = 1e-6, + grad_clip: float = 20.0, + include_low_dim_state: bool = False, + ): + self._pixel_unet = pixel_unet + self._camera_name = camera_name + self._tau = tau + self._gamma = gamma + self._nstep = nstep + self._lr = lr + self._weight_decay = weight_decay + self._lambda_qreg = lambda_qreg + self._grad_clip = grad_clip + self._include_low_dim_state = include_low_dim_state + + def build(self, training: bool, device: torch.device = None): + if device is None: + device = torch.device("cpu") + self._q = QFunction(self._pixel_unet).to(device).train(training) + self._q_target = None + if training: + self._q_target = QFunction(self._pixel_unet).to(device).train(False) + for p in self._q_target.parameters(): + p.requires_grad = False + utils.soft_updates(self._q, self._q_target, 1.0) + self._optimizer = torch.optim.Adam( + self._q.parameters(), lr=self._lr, weight_decay=self._weight_decay + ) + logging.info( + "# Q-attention Params: %d" + % sum(p.numel() for p in self._q.parameters() if p.requires_grad) + ) + else: + for p in self._q.parameters(): + p.requires_grad = False + self._device = device + + def _get_q_from_pixel_coord(self, q, coord): + b, h, w = q.shape + flat_indicies = (coord[:, 0] * w + coord[:, 1])[:, None].long() + return q.view(b, h * w).gather(1, flat_indicies) + + def _preprocess_inputs(self, replay_sample): + observations = [ + stack_on_channel(replay_sample["%s_rgb" % self._camera_name]), + stack_on_channel(replay_sample["%s_point_cloud" % self._camera_name]), + ] + tp1_observations = [ + stack_on_channel(replay_sample["%s_rgb_tp1" % 
def update(self, step: int, replay_sample: dict) -> dict:
    """Run one double-Q learning step on a replay batch.

    Computes a clipped Bellman target from the target network (evaluated at
    the online network's arg-max pixel), applies a smooth-L1 loss to both Q
    heads plus an L2 penalty on the Q maps, steps the optimiser and
    soft-updates the target network.

    Returns:
        ``{"priority": tensor}`` - per-sample priorities, normalised by their
        maximum and detached from the autograd graph.
    """
    pixel_action = replay_sample["%s_pixel_coord" % self._camera_name][:, -1].int()
    reward = replay_sample["reward"]
    # Negative rewards are zeroed.
    reward = torch.where(reward > 0, reward, torch.zeros_like(reward))

    robot_state = robot_state_tp1 = None
    if self._include_low_dim_state:
        robot_state = stack_on_channel(replay_sample["low_dim_state"])
        robot_state_tp1 = stack_on_channel(replay_sample["low_dim_state_tp1"])

    # Don't want timeouts to be classed as terminals
    terminal = replay_sample["terminal"].float() - replay_sample["timeout"].float()

    obs, obs_tp1 = self._preprocess_inputs(replay_sample)
    q, q2, _ = self._q(obs, robot_state)

    with torch.no_grad():
        # (B, h, w)
        _, _, coords_tp1 = self._q(obs_tp1, robot_state_tp1)
        q_tp1_targ, q2_tp1_targ, _ = self._q_target(obs_tp1, robot_state_tp1)
        q_tp1_targ = torch.min(q_tp1_targ, q2_tp1_targ)
        q_tp1_targ = self._get_q_from_pixel_coord(q_tp1_targ, coords_tp1)
        target = (
            reward.unsqueeze(1)
            + (self._gamma**self._nstep)
            * (1 - terminal.unsqueeze(1))
            * q_tp1_targ
        )
        target = torch.clamp(target, 0.0, 100.0)

    q_pred = self._get_q_from_pixel_coord(q, pixel_action)
    delta = F.smooth_l1_loss(q_pred, target, reduction="none").mean(1)

    delta += F.smooth_l1_loss(
        self._get_q_from_pixel_coord(q2, pixel_action), target, reduction="none"
    ).mean(1)
    q_reg = (
        (0.5 * torch.sum(q**2)) + (0.5 * torch.sum(q2**2))
    ) * self._lambda_qreg

    loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA)
    total_loss = ((delta) * loss_weights).mean() + q_reg
    # Priorities are bookkeeping only: detach them so the returned tensor does
    # not keep this step's autograd graph alive.
    new_priority = (delta.detach() + 1e-10).sqrt()
    new_priority = new_priority / new_priority.max()

    self._summaries = {
        "losses/bellman": delta.mean(),
        "losses/qreg": q_reg.mean(),
        "q/mean": q.mean(),
        "q/action_q": q_pred.mean(),
    }
    self._qvalues = q[:1]
    # Use the image of the camera this agent attends to, so the heat-map
    # summary overlays the Q-values on the matching view (previously the
    # "front" camera was hard-coded).
    self._rgb_observation = replay_sample["%s_rgb" % self._camera_name][0, -1]
    self._optimizer.zero_grad()
    total_loss.backward()
    if self._grad_clip is not None:
        nn.utils.clip_grad_value_(self._q.parameters(), self._grad_clip)
    self._optimizer.step()
    utils.soft_updates(self._q, self._q_target, self._tau)

    return {
        "priority": new_priority,
    }
def act_summaries(self) -> List[Summary]:
    """Return a single image summary: the acting-time Q heat-map.

    The stored RGB observation is shifted from ``[-1, 1]`` to ``[0, 1]``
    before being blended with the Q-values recorded by the last ``act`` call.
    """
    q_values = self._act_qvalues.cpu()
    rgb_image = ((self._rgb_observation + 1) / 2.0).cpu()
    heatmap = QAttentionAgent.generate_heatmap(q_values, rgb_image)
    return [ImageSummary("%s/Q_act" % NAME, heatmap)]
class Actor(nn.Module):
    """Thin wrapper owning a private, built copy of the actor network."""

    def __init__(self, actor_network: nn.Module):
        super().__init__()
        # Deep-copy first so the caller's network is never built or mutated.
        private_network = copy.deepcopy(actor_network)
        private_network.build()
        self._actor_network = private_network

    def forward(self, observations, robot_state, lang_goal_emb):
        """Forward all inputs to the wrapped network and return its output."""
        return self._actor_network(observations, robot_state, lang_goal_emb)
utils.loss_weights(replay_sample, REPLAY_BETA) + delta = F.mse_loss(mu, replay_sample["action"], reduction="none").mean(1) + loss = (delta * loss_weights).mean() + self._grad_step( + loss, self._actor_optimizer, self._actor.parameters(), self._grad_clip + ) + self._summaries = { + "pi/loss": loss, + "pi/mu": mu.mean(), + } + return {"total_losses": loss} + + def _normalize_quat(self, x): + return x / x.square().sum(dim=1).sqrt().unsqueeze(-1) + + def act(self, step: int, observation: dict, deterministic=False) -> ActResult: + lang_goal_tokens = observation.get("lang_goal_tokens", None).long() + + with torch.no_grad(): + lang_goal_tokens = lang_goal_tokens.to(device=self._device) + lang_goal_emb, _ = self._clip_rn50.encode_text_with_embeddings( + lang_goal_tokens[0] + ) + lang_goal_emb = lang_goal_emb.to(device=self._device) + + observations = [ + observation["%s_rgb" % self._camera_name][0].to(self._device), + observation["%s_point_cloud" % self._camera_name][0].to(self._device), + ] + robot_state = observation["low_dim_state"][0].to(self._device) + + mu = self._actor(observations, robot_state, lang_goal_emb) + mu = torch.cat([mu[:, :3], self._normalize_quat(mu[:, 3:7]), mu[:, 7:]], dim=-1) + ignore_collisions = torch.Tensor([1.0]).to(mu.device) + mu0 = torch.cat([mu[0], ignore_collisions]) + return ActResult(mu0.detach().cpu()) + + def update_summaries(self) -> List[Summary]: + summaries = [] + for n, v in self._summaries.items(): + summaries.append(ScalarSummary("%s/%s" % (NAME, n), v)) + + for tag, param in self._actor.named_parameters(): + summaries.append( + HistogramSummary("%s/gradient/%s" % (NAME, tag), param.grad) + ) + summaries.append(HistogramSummary("%s/weight/%s" % (NAME, tag), param.data)) + + return summaries + + def act_summaries(self) -> List[Summary]: + return [] + + def load_weights(self, savedir: str): + self._actor.load_state_dict( + torch.load( + os.path.join(savedir, "bc_actor.pt"), map_location=torch.device("cpu") + ) + ) + print("Loaded 
weights from %s" % savedir) + + def save_weights(self, savedir: str): + torch.save(self._actor.state_dict(), os.path.join(savedir, "bc_actor.pt")) diff --git a/external/peract_bimanual/agents/baselines/bc_lang/launch_utils.py b/external/peract_bimanual/agents/baselines/bc_lang/launch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f52b412f49d1fb0bd2f8f3ed39e5f9e7ce1211e7 --- /dev/null +++ b/external/peract_bimanual/agents/baselines/bc_lang/launch_utils.py @@ -0,0 +1,368 @@ +# Adapted from ARM +# Source: https://github.com/stepjam/ARM +# License: https://github.com/stepjam/ARM/LICENSE + +import logging +from typing import List + +import numpy as np +from omegaconf import DictConfig +from rlbench.backend.observation import Observation +from rlbench.observation_config import ObservationConfig +import rlbench.utils as rlbench_utils +from rlbench.demo import Demo +from yarr.replay_buffer.prioritized_replay_buffer import ( + PrioritizedReplayBuffer, + ObservationElement, +) +from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer +from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer +from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer + +from helpers import demo_loading_utils, utils +from helpers import observation_utils +from agents.baselines.bc_lang.bc_lang_agent import BCLangAgent +from helpers.custom_rlbench_env import CustomRLBenchEnv +from helpers.network_utils import SiameseNet, CNNLangAndFcsNet +from helpers.preprocess_agent import PreprocessAgent + +import torch +from torch.multiprocessing import Process, Value, Manager +from helpers.clip.core.clip import build_model, load_clip, tokenize + +LOW_DIM_SIZE = 4 + + +def create_replay( + batch_size: int, + timesteps: int, + prioritisation: bool, + task_uniform: bool, + save_dir: str, + cameras: list, + image_size=[128, 128], + replay_size=3e5, +): + lang_feat_dim = 1024 + + # low_dim_state + observation_elements = [] + 
observation_elements.append( + ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32) + ) + + # rgb, depth, point cloud, intrinsics, extrinsics + for cname in cameras: + observation_elements.append( + ObservationElement( + "%s_rgb" % cname, + ( + 3, + *image_size, + ), + np.float32, + ) + ) + observation_elements.append( + ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32) + ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames + observation_elements.append( + ObservationElement( + "%s_camera_extrinsics" % cname, + ( + 4, + 4, + ), + np.float32, + ) + ) + observation_elements.append( + ObservationElement( + "%s_camera_intrinsics" % cname, + ( + 3, + 3, + ), + np.float32, + ) + ) + + observation_elements.extend( + [ + ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32), + ReplayElement("task", (), str), + ReplayElement( + "lang_goal", (1,), object + ), # language goal string for debugging and visualization + ] + ) + + extra_replay_elements = [ + ReplayElement("demo", (), np.bool), + ] + + replay_buffer = TaskUniformReplayBuffer( + save_dir=save_dir, + batch_size=batch_size, + timesteps=timesteps, + replay_capacity=int(replay_size), + action_shape=(8,), + action_dtype=np.float32, + reward_shape=(), + reward_dtype=np.float32, + update_horizon=1, + observation_elements=observation_elements, + extra_replay_elements=extra_replay_elements, + ) + return replay_buffer + + +def _get_action(obs_tp1: Observation): + quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:]) + if quat[-1] < 0: + quat = -quat + return np.concatenate( + [obs_tp1.gripper_pose[:3], quat, [float(obs_tp1.gripper_open)]] + ) + + +def _add_keypoints_to_replay( + cfg: DictConfig, + task: str, + replay: ReplayBuffer, + inital_obs: Observation, + demo: Demo, + episode_keypoints: List[int], + cameras: List[str], + description: str = "", + clip_model=None, + device="cpu", +): + prev_action = None + obs = inital_obs + all_actions 
= [] + for k, keypoint in enumerate(episode_keypoints): + obs_tp1 = demo[keypoint] + action = _get_action(obs_tp1) + all_actions.append(action) + terminal = k == len(episode_keypoints) - 1 + reward = float(terminal) if terminal else 0 + + obs_dict = observation_utils.extract_obs( + obs, + t=k, + prev_action=prev_action, + cameras=cameras, + episode_length=cfg.rlbench.episode_length, + robot_name=cfg.method.robot_name, + ) + del obs_dict["ignore_collisions"] + tokens = tokenize([description]).numpy() + token_tensor = torch.from_numpy(tokens).to(device) + lang_feats, lang_embs = clip_model.encode_text_with_embeddings(token_tensor) + obs_dict["lang_goal_emb"] = lang_feats[0].float().detach().cpu().numpy() + + final_obs = { + "task": task, + "lang_goal": np.array([description], dtype=object), + } + + prev_action = np.copy(action) + others = {"demo": True} + others.update(final_obs) + others.update(obs_dict) + timeout = False + replay.add(action, reward, terminal, timeout, **others) + obs = obs_tp1 # Set the next obs + # Final step + obs_dict_tp1 = observation_utils.extract_obs( + obs_tp1, + t=k + 1, + prev_action=prev_action, + cameras=cameras, + episode_length=cfg.rlbench.episode_length, + robot_name=cfg.method.robot_name, + ) + obs_dict_tp1["lang_goal_emb"] = lang_feats[0].float().detach().cpu().numpy() + # del obs_dict_tp1['lang_goal_tokens'] + del obs_dict_tp1["ignore_collisions"] + # obs_dict_tp1['task'] = task + obs_dict_tp1.update(final_obs) + replay.add_final(**obs_dict_tp1) + return all_actions + + +def fill_replay( + cfg: DictConfig, + obs_config: ObservationConfig, + rank: int, + replay: ReplayBuffer, + task: str, + num_demos: int, + demo_augmentation: bool, + demo_augmentation_every_n: int, + cameras: List[str], + clip_model=None, + device="cpu", +): + if clip_model is None: + model, _ = load_clip("RN50", jit=False, device=device) + clip_model = build_model(model.state_dict()) + clip_model.to(device) + del model + + logging.debug("Filling %s replay ..." 
% task) + all_actions = [] + for d_idx in range(num_demos): + # load demo from disk + demo = rlbench_utils.get_stored_demos( + amount=1, + image_paths=False, + dataset_root=cfg.rlbench.demo_path, + variation_number=-1, + task_name=task, + obs_config=obs_config, + random_selection=False, + from_episode_number=d_idx, + )[0] + + descs = demo._observations[0].misc["descriptions"] + + # extract keypoints (a.k.a keyframes) + episode_keypoints = demo_loading_utils.keypoint_discovery(demo) + + if rank == 0: + logging.info( + f"Loading Demo({d_idx}) - found {len(episode_keypoints)} keypoints - {task}" + ) + + for i in range(len(demo) - 1): + if not demo_augmentation and i > 0: + break + if i % demo_augmentation_every_n != 0: + continue + + obs = demo[i] + desc = descs[0] + # if our starting point is past one of the keypoints, then remove it + while len(episode_keypoints) > 0 and i >= episode_keypoints[0]: + episode_keypoints = episode_keypoints[1:] + if len(episode_keypoints) == 0: + break + all_actions.extend( + _add_keypoints_to_replay( + cfg, + task, + replay, + obs, + demo, + episode_keypoints, + cameras, + description=desc, + clip_model=clip_model, + device=device, + ) + ) + logging.debug("Replay filled with demos.") + return all_actions + + +def fill_multi_task_replay( + cfg: DictConfig, + obs_config: ObservationConfig, + rank: int, + replay: ReplayBuffer, + tasks: List[str], + num_demos: int, + demo_augmentation: bool, + demo_augmentation_every_n: int, + cameras: List[str], + clip_model=None, +): + manager = Manager() + store = manager.dict() + + # create a MP dict for storing indicies + # TODO(mohit): this shouldn't be initialized here + del replay._task_idxs + task_idxs = manager.dict() + replay._task_idxs = task_idxs + replay._create_storage(store) + replay.add_count = Value("i", 0) + + # fill replay buffer in parallel across tasks + max_parallel_processes = cfg.replay.max_parallel_processes + processes = [] + n = np.arange(len(tasks)) + split_n = 
utils.split_list(n, max_parallel_processes) + for split in split_n: + for e_idx, task_idx in enumerate(split): + task = tasks[int(task_idx)] + model_device = torch.device( + "cuda:%s" % (e_idx % torch.cuda.device_count()) + if torch.cuda.is_available() + else "cpu" + ) + p = Process( + target=fill_replay, + args=( + cfg, + obs_config, + rank, + replay, + task, + num_demos, + demo_augmentation, + demo_augmentation_every_n, + cameras, + clip_model, + model_device, + ), + ) + p.start() + processes.append(p) + + for p in processes: + p.join() + + logging.debug("Replay filled with multi demos.") + + +def create_agent(cfg: DictConfig): + camera_name = cfg.rlbench.cameras + activation = cfg.method.activation + lr = cfg.method.lr + weight_decay = cfg.method.weight_decay + image_resolution = cfg.rlbench.camera_resolution + grad_clip = cfg.method.grad_clip + + siamese_net = SiameseNet( + input_channels=[3, 3], + filters=[16], + kernel_sizes=[5], + strides=[1], + activation=activation, + norm=None, + ) + + actor_net = CNNLangAndFcsNet( + siamese_net=siamese_net, + input_resolution=image_resolution, + filters=[32, 64, 64], + kernel_sizes=[3, 3, 3], + strides=[2, 2, 2], + norm=None, + activation=activation, + fc_layers=[128, 64, 3 + 4 + 1], + low_dim_state_len=LOW_DIM_SIZE, + ) + + bc_agent = BCLangAgent( + actor_network=actor_net, + camera_name=camera_name, + lr=lr, + weight_decay=weight_decay, + grad_clip=grad_clip, + ) + + return PreprocessAgent(pose_agent=bc_agent) diff --git a/external/peract_bimanual/agents/baselines/vit_bc_lang/__init__.py b/external/peract_bimanual/agents/baselines/vit_bc_lang/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d787af6e96872bcd4bba4030b1a9cc5ed623413 --- /dev/null +++ b/external/peract_bimanual/agents/baselines/vit_bc_lang/__init__.py @@ -0,0 +1 @@ +import agents.baselines.vit_bc_lang.launch_utils diff --git a/external/peract_bimanual/agents/baselines/vit_bc_lang/launch_utils.py 
b/external/peract_bimanual/agents/baselines/vit_bc_lang/launch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1089c80d2deede1a95f47dc71691643ac0fe96fe --- /dev/null +++ b/external/peract_bimanual/agents/baselines/vit_bc_lang/launch_utils.py @@ -0,0 +1,372 @@ +# Adapted from ARM +# Source: https://github.com/stepjam/ARM +# License: https://github.com/stepjam/ARM/LICENSE + +import logging +from typing import List + +import numpy as np +from omegaconf import DictConfig +from rlbench.backend.observation import Observation +from rlbench.observation_config import ObservationConfig +import rlbench.utils as rlbench_utils +from rlbench.demo import Demo +from yarr.replay_buffer.prioritized_replay_buffer import ( + PrioritizedReplayBuffer, + ObservationElement, +) +from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer +from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer +from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer + +from helpers import demo_loading_utils, utils +from helpers import observation_utils +from agents.baselines.vit_bc_lang.vit_bc_lang_agent import ViTBCLangAgent +from helpers.custom_rlbench_env import CustomRLBenchEnv +from helpers.network_utils import ViTLangAndFcsNet, ViT +from helpers.preprocess_agent import PreprocessAgent + +import torch +from torch.multiprocessing import Process, Value, Manager +from helpers.clip.core.clip import build_model, load_clip, tokenize + +LOW_DIM_SIZE = 4 + + +def create_replay( + batch_size: int, + timesteps: int, + prioritisation: bool, + task_uniform: bool, + save_dir: str, + cameras: list, + image_size=[128, 128], + replay_size=3e5, +): + lang_feat_dim = 1024 + + # low_dim_state + observation_elements = [] + observation_elements.append( + ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32) + ) + + # rgb, depth, point cloud, intrinsics, extrinsics + for cname in cameras: + observation_elements.append( + 
ObservationElement( + "%s_rgb" % cname, + ( + 3, + *image_size, + ), + np.float32, + ) + ) + observation_elements.append( + ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32) + ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames + observation_elements.append( + ObservationElement( + "%s_camera_extrinsics" % cname, + ( + 4, + 4, + ), + np.float32, + ) + ) + observation_elements.append( + ObservationElement( + "%s_camera_intrinsics" % cname, + ( + 3, + 3, + ), + np.float32, + ) + ) + + observation_elements.extend( + [ + ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32), + ReplayElement("task", (), str), + ReplayElement( + "lang_goal", (1,), object + ), # language goal string for debugging and visualization + ] + ) + + extra_replay_elements = [ + ReplayElement("demo", (), np.bool), + ] + + replay_buffer = TaskUniformReplayBuffer( + save_dir=save_dir, + batch_size=batch_size, + timesteps=timesteps, + replay_capacity=int(replay_size), + action_shape=(8,), + action_dtype=np.float32, + reward_shape=(), + reward_dtype=np.float32, + update_horizon=1, + observation_elements=observation_elements, + extra_replay_elements=extra_replay_elements, + ) + return replay_buffer + + +def _get_action(obs_tp1: Observation): + quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:]) + if quat[-1] < 0: + quat = -quat + return np.concatenate( + [obs_tp1.gripper_pose[:3], quat, [float(obs_tp1.gripper_open)]] + ) + + +def _add_keypoints_to_replay( + cfg: DictConfig, + task: str, + replay: ReplayBuffer, + inital_obs: Observation, + demo: Demo, + episode_keypoints: List[int], + cameras: List[str], + description: str = "", + clip_model=None, + device="cpu", +): + prev_action = None + obs = inital_obs + all_actions = [] + for k, keypoint in enumerate(episode_keypoints): + obs_tp1 = demo[keypoint] + action = _get_action(obs_tp1) + all_actions.append(action) + terminal = k == len(episode_keypoints) - 1 + reward = 
float(terminal) if terminal else 0 + + obs_dict = observation_utils.extract_obs( + obs, + t=k, + prev_action=prev_action, + cameras=cameras, + episode_length=cfg.rlbench.episode_length, + robot_name=cfg.method.robot_name, + ) + del obs_dict["ignore_collisions"] + tokens = tokenize([description]).numpy() + token_tensor = torch.from_numpy(tokens).to(device) + lang_feats, lang_embs = clip_model.encode_text_with_embeddings(token_tensor) + obs_dict["lang_goal_emb"] = lang_feats[0].float().detach().cpu().numpy() + + final_obs = { + "task": task, + "lang_goal": np.array([description], dtype=object), + } + + prev_action = np.copy(action) + others = {"demo": True} + others.update(final_obs) + others.update(obs_dict) + timeout = False + replay.add(action, reward, terminal, timeout, **others) + obs = obs_tp1 # Set the next obs + # Final step + obs_dict_tp1 = observation_utils.extract_obs( + obs_tp1, + t=k + 1, + prev_action=prev_action, + cameras=cameras, + episode_length=cfg.rlbench.episode_length, + robot_name=cfg.method.robot_name, + ) + obs_dict_tp1["lang_goal_emb"] = lang_feats[0].float().detach().cpu().numpy() + # del obs_dict_tp1['lang_goal_tokens'] + del obs_dict_tp1["ignore_collisions"] + # obs_dict_tp1['task'] = task + obs_dict_tp1.update(final_obs) + replay.add_final(**obs_dict_tp1) + return all_actions + + +def fill_replay( + cfg: DictConfig, + obs_config: ObservationConfig, + rank: int, + replay: ReplayBuffer, + task: str, + num_demos: int, + demo_augmentation: bool, + demo_augmentation_every_n: int, + cameras: List[str], + clip_model=None, + device="cpu", +): + if clip_model is None: + model, _ = load_clip("RN50", jit=False, device=device) + clip_model = build_model(model.state_dict()) + clip_model.to(device) + del model + + logging.debug("Filling %s replay ..." 
% task) + all_actions = [] + for d_idx in range(num_demos): + # load demo from disk + demo = rlbench_utils.get_stored_demos( + amount=1, + image_paths=False, + dataset_root=cfg.rlbench.demo_path, + variation_number=-1, + task_name=task, + obs_config=obs_config, + random_selection=False, + from_episode_number=d_idx, + )[0] + + descs = demo._observations[0].misc["descriptions"] + + # extract keypoints (a.k.a keyframes) + episode_keypoints = demo_loading_utils.keypoint_discovery(demo) + + if rank == 0: + logging.info( + f"Loading Demo({d_idx}) - found {len(episode_keypoints)} keypoints - {task}" + ) + + for i in range(len(demo) - 1): + if not demo_augmentation and i > 0: + break + if i % demo_augmentation_every_n != 0: + continue + + obs = demo[i] + desc = descs[0] + # if our starting point is past one of the keypoints, then remove it + while len(episode_keypoints) > 0 and i >= episode_keypoints[0]: + episode_keypoints = episode_keypoints[1:] + if len(episode_keypoints) == 0: + break + all_actions.extend( + _add_keypoints_to_replay( + cfg, + task, + replay, + obs, + demo, + episode_keypoints, + cameras, + description=desc, + clip_model=clip_model, + device=device, + ) + ) + logging.debug("Replay filled with demos.") + return all_actions + + +def fill_multi_task_replay( + cfg: DictConfig, + obs_config: ObservationConfig, + rank: int, + replay: ReplayBuffer, + tasks: List[str], + num_demos: int, + demo_augmentation: bool, + demo_augmentation_every_n: int, + cameras: List[str], + clip_model=None, +): + manager = Manager() + store = manager.dict() + + # create a MP dict for storing indicies + # TODO(mohit): this shouldn't be initialized here + del replay._task_idxs + task_idxs = manager.dict() + replay._task_idxs = task_idxs + replay._create_storage(store) + replay.add_count = Value("i", 0) + + # fill replay buffer in parallel across tasks + max_parallel_processes = cfg.replay.max_parallel_processes + processes = [] + n = np.arange(len(tasks)) + split_n = 
utils.split_list(n, max_parallel_processes) + for split in split_n: + for e_idx, task_idx in enumerate(split): + task = tasks[int(task_idx)] + model_device = torch.device( + "cuda:%s" % (e_idx % torch.cuda.device_count()) + if torch.cuda.is_available() + else "cpu" + ) + p = Process( + target=fill_replay, + args=( + cfg, + obs_config, + rank, + replay, + task, + num_demos, + demo_augmentation, + demo_augmentation_every_n, + cameras, + clip_model, + model_device, + ), + ) + p.start() + processes.append(p) + + for p in processes: + p.join() + + logging.debug("Replay filled with multi demos.") + + +def create_agent(cfg: DictConfig): + camera_name = cfg.rlbench.cameras + activation = cfg.method.activation + lr = cfg.method.lr + weight_decay = cfg.method.weight_decay + image_resolution = cfg.rlbench.camera_resolution + grad_clip = cfg.method.grad_clip + + vit = ViT( + image_size=128, + patch_size=8, + num_classes=16, + dim=64, + depth=6, + heads=8, + mlp_dim=64, + dropout=0.1, + emb_dropout=0.1, + channels=6, + ) + + actor_net = ViTLangAndFcsNet( + vit=vit, + input_resolution=image_resolution, + filters=[64, 96, 128], + kernel_sizes=[1, 1, 1], + strides=[1, 1, 1], + norm=None, + activation=activation, + fc_layers=[128, 64, 3 + 4 + 1], + low_dim_state_len=LOW_DIM_SIZE, + ) + + bc_agent = ViTBCLangAgent( + actor_network=actor_net, + camera_name=camera_name, + lr=lr, + weight_decay=weight_decay, + grad_clip=grad_clip, + ) + + return PreprocessAgent(pose_agent=bc_agent) diff --git a/external/peract_bimanual/agents/baselines/vit_bc_lang/vit_bc_lang_agent.py b/external/peract_bimanual/agents/baselines/vit_bc_lang/vit_bc_lang_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..ba86e49001ef8289d3c1b3c9e13a86e126a95880 --- /dev/null +++ b/external/peract_bimanual/agents/baselines/vit_bc_lang/vit_bc_lang_agent.py @@ -0,0 +1,148 @@ +import copy +import logging +import os +from typing import List + +import torch +import torch.nn as nn +import 
torch.nn.functional as F +from yarr.agents.agent import Agent, Summary, ActResult, ScalarSummary, HistogramSummary + +from helpers import utils +from helpers.utils import stack_on_channel + +from helpers.clip.core.clip import build_model, load_clip + +NAME = "ViTBCLangAgent" +REPLAY_ALPHA = 0.7 +REPLAY_BETA = 1.0 + + +class Actor(nn.Module): + def __init__(self, actor_network: nn.Module): + super(Actor, self).__init__() + self._actor_network = copy.deepcopy(actor_network) + self._actor_network.build() + + def forward(self, observations, robot_state, lang_goal_emb): + mu = self._actor_network(observations, robot_state, lang_goal_emb) + return mu + + +class ViTBCLangAgent(Agent): + def __init__( + self, + actor_network: nn.Module, + camera_name: str, + lr: float = 0.01, + weight_decay: float = 1e-5, + grad_clip: float = 20.0, + ): + self._camera_name = camera_name + self._actor_network = actor_network + self._lr = lr + self._weight_decay = weight_decay + self._grad_clip = grad_clip + + def build(self, training: bool, device: torch.device = None): + if device is None: + device = torch.device("cpu") + self._actor = Actor(self._actor_network).to(device).train(training) + if training: + self._actor_optimizer = torch.optim.Adam( + self._actor.parameters(), lr=self._lr, weight_decay=self._weight_decay + ) + logging.info( + "# Actor Params: %d" + % sum(p.numel() for p in self._actor.parameters() if p.requires_grad) + ) + else: + for p in self._actor.parameters(): + p.requires_grad = False + + model, _ = load_clip("RN50", jit=False) + self._clip_rn50 = build_model(model.state_dict()) + self._clip_rn50 = self._clip_rn50.float().to(device) + self._clip_rn50.eval() + del model + + self._device = device + + def _grad_step(self, loss, opt, model_params=None, clip=None): + opt.zero_grad() + loss.backward() + if clip is not None and model_params is not None: + nn.utils.clip_grad_value_(model_params, clip) + opt.step() + + def update(self, step: int, replay_sample: dict) -> dict: + 
lang_goal_emb = replay_sample["lang_goal_emb"] + robot_state = replay_sample["low_dim_state"] + observations = [ + replay_sample["%s_rgb" % self._camera_name], + replay_sample["%s_point_cloud" % self._camera_name], + ] + mu = self._actor(observations, robot_state, lang_goal_emb) + loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA) + delta = F.mse_loss(mu, replay_sample["action"], reduction="none").mean(1) + loss = (delta * loss_weights).mean() + self._grad_step( + loss, self._actor_optimizer, self._actor.parameters(), self._grad_clip + ) + self._summaries = { + "pi/loss": loss, + "pi/mu": mu.mean(), + } + return {"total_losses": loss} + + def _normalize_quat(self, x): + return x / x.square().sum(dim=1).sqrt().unsqueeze(-1) + + def act(self, step: int, observation: dict, deterministic=False) -> ActResult: + lang_goal_tokens = observation.get("lang_goal_tokens", None).long() + + with torch.no_grad(): + lang_goal_tokens = lang_goal_tokens.to(device=self._device) + lang_goal_emb, _ = self._clip_rn50.encode_text_with_embeddings( + lang_goal_tokens[0] + ) + lang_goal_emb = lang_goal_emb.to(device=self._device) + + observations = [ + observation["%s_rgb" % self._camera_name][0].to(self._device), + observation["%s_point_cloud" % self._camera_name][0].to(self._device), + ] + robot_state = observation["low_dim_state"][0].to(self._device) + + mu = self._actor(observations, robot_state, lang_goal_emb) + mu = torch.cat([mu[:, :3], self._normalize_quat(mu[:, 3:7]), mu[:, 7:]], dim=-1) + ignore_collisions = torch.Tensor([1.0]).to(mu.device) + mu0 = torch.cat([mu[0], ignore_collisions]) + return ActResult(mu0.detach().cpu()) + + def update_summaries(self) -> List[Summary]: + summaries = [] + for n, v in self._summaries.items(): + summaries.append(ScalarSummary("%s/%s" % (NAME, n), v)) + + for tag, param in self._actor.named_parameters(): + summaries.append( + HistogramSummary("%s/gradient/%s" % (NAME, tag), param.grad) + ) + 
summaries.append(HistogramSummary("%s/weight/%s" % (NAME, tag), param.data)) + + return summaries + + def act_summaries(self) -> List[Summary]: + return [] + + def load_weights(self, savedir: str): + self._actor.load_state_dict( + torch.load( + os.path.join(savedir, "bc_actor.pt"), map_location=torch.device("cpu") + ) + ) + print("Loaded weights from %s" % savedir) + + def save_weights(self, savedir: str): + torch.save(self._actor.state_dict(), os.path.join(savedir, "bc_actor.pt")) diff --git a/external/peract_bimanual/agents/bimanual_peract/__init__.py b/external/peract_bimanual/agents/bimanual_peract/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae0978b0b3074a5458e507cd478b24e2fb20fa0 --- /dev/null +++ b/external/peract_bimanual/agents/bimanual_peract/__init__.py @@ -0,0 +1 @@ +import agents.bimanual_peract.launch_utils diff --git a/external/peract_bimanual/agents/bimanual_peract/launch_utils.py b/external/peract_bimanual/agents/bimanual_peract/launch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..caa67f252dd585e15dd994156508c0dfb93e3d55 --- /dev/null +++ b/external/peract_bimanual/agents/bimanual_peract/launch_utils.py @@ -0,0 +1,93 @@ +# Adapted from ARM +# Source: https://github.com/stepjam/ARM +# License: https://github.com/stepjam/ARM/LICENSE + + +from helpers.preprocess_agent import PreprocessAgent + +from agents.bimanual_peract.perceiver_lang_io import PerceiverVoxelLangEncoder +from agents.bimanual_peract.qattention_peract_bc_agent import QAttentionPerActBCAgent +from agents.bimanual_peract.qattention_stack_agent import QAttentionStackAgent + +from omegaconf import DictConfig + + +def create_agent(cfg: DictConfig): + depth_0bounds = cfg.rlbench.scene_bounds + cam_resolution = cfg.rlbench.camera_resolution + + num_rotation_classes = int(360.0 // cfg.method.rotation_resolution) + qattention_agents = [] + for depth, vox_size in enumerate(cfg.method.voxel_sizes): + last = depth == 
len(cfg.method.voxel_sizes) - 1 + perceiver_encoder = PerceiverVoxelLangEncoder( + depth=cfg.method.transformer_depth, + iterations=cfg.method.transformer_iterations, + voxel_size=vox_size, + initial_dim=3 + 3 + 1 + 3, + low_dim_size=cfg.method.low_dim_size, + layer=depth, + num_rotation_classes=num_rotation_classes if last else 0, + num_grip_classes=2 if last else 0, + num_collision_classes=2 if last else 0, + input_axis=3, + num_latents=cfg.method.num_latents, + latent_dim=cfg.method.latent_dim, + cross_heads=cfg.method.cross_heads, + latent_heads=cfg.method.latent_heads, + cross_dim_head=cfg.method.cross_dim_head, + latent_dim_head=cfg.method.latent_dim_head, + weight_tie_layers=False, + activation=cfg.method.activation, + pos_encoding_with_lang=cfg.method.pos_encoding_with_lang, + input_dropout=cfg.method.input_dropout, + attn_dropout=cfg.method.attn_dropout, + decoder_dropout=cfg.method.decoder_dropout, + lang_fusion_type=cfg.method.lang_fusion_type, + voxel_patch_size=cfg.method.voxel_patch_size, + voxel_patch_stride=cfg.method.voxel_patch_stride, + no_skip_connection=cfg.method.no_skip_connection, + no_perceiver=cfg.method.no_perceiver, + no_language=cfg.method.no_language, + final_dim=cfg.method.final_dim, + ) + + qattention_agent = QAttentionPerActBCAgent( + layer=depth, + coordinate_bounds=depth_0bounds, + perceiver_encoder=perceiver_encoder, + camera_names=cfg.rlbench.cameras, + voxel_size=vox_size, + bounds_offset=cfg.method.bounds_offset[depth - 1] if depth > 0 else None, + image_crop_size=cfg.method.image_crop_size, + lr=cfg.method.lr, + training_iterations=cfg.framework.training_iterations, + lr_scheduler=cfg.method.lr_scheduler, + num_warmup_steps=cfg.method.num_warmup_steps, + trans_loss_weight=cfg.method.trans_loss_weight, + rot_loss_weight=cfg.method.rot_loss_weight, + grip_loss_weight=cfg.method.grip_loss_weight, + collision_loss_weight=cfg.method.collision_loss_weight, + include_low_dim_state=True, + image_resolution=cam_resolution, + 
batch_size=cfg.replay.batch_size, + voxel_feature_size=3, + lambda_weight_l2=cfg.method.lambda_weight_l2, + num_rotation_classes=num_rotation_classes, + rotation_resolution=cfg.method.rotation_resolution, + transform_augmentation=cfg.method.transform_augmentation.apply_se3, + transform_augmentation_xyz=cfg.method.transform_augmentation.aug_xyz, + transform_augmentation_rpy=cfg.method.transform_augmentation.aug_rpy, + transform_augmentation_rot_resolution=cfg.method.transform_augmentation.aug_rot_resolution, + optimizer_type=cfg.method.optimizer, + num_devices=cfg.ddp.num_devices, + ) + qattention_agents.append(qattention_agent) + + rotation_agent = QAttentionStackAgent( + qattention_agents=qattention_agents, + rotation_resolution=cfg.method.rotation_resolution, + camera_names=cfg.rlbench.cameras, + ) + preprocess_agent = PreprocessAgent(pose_agent=rotation_agent) + return preprocess_agent diff --git a/external/peract_bimanual/agents/bimanual_peract/perceiver_lang_io.py b/external/peract_bimanual/agents/bimanual_peract/perceiver_lang_io.py new file mode 100644 index 0000000000000000000000000000000000000000..773441d3b68500716577a9db5cf095928bf3e13c --- /dev/null +++ b/external/peract_bimanual/agents/bimanual_peract/perceiver_lang_io.py @@ -0,0 +1,549 @@ +# Perceiver IO implementation adpated for manipulation +# Source: https://github.com/lucidrains/perceiver-pytorch +# License: https://github.com/lucidrains/perceiver-pytorch/blob/main/LICENSE + +import torch +from torch import nn + +from einops import rearrange +from einops import repeat + +from perceiver_pytorch.perceiver_pytorch import cache_fn +from perceiver_pytorch.perceiver_pytorch import PreNorm, FeedForward, Attention + +from helpers.network_utils import ( + DenseBlock, + SpatialSoftmax3D, + Conv3DBlock, + Conv3DUpsampleBlock, +) + + +# PerceiverIO adapted for 6-DoF manipulation +class PerceiverVoxelLangEncoder(nn.Module): + def __init__( + self, + depth, # number of self-attention layers + iterations, # 
number cross-attention iterations (PerceiverIO uses just 1) + voxel_size, # N voxels per side (size: N*N*N) + initial_dim, # 10 dimensions - dimension of the input sequence to be encoded + low_dim_size, # 4 dimensions - proprioception: {gripper_open, left_finger, right_finger, timestep} + layer=0, + num_rotation_classes=72, # 5 degree increments (5*72=360) for each of the 3-axis + num_grip_classes=2, # open or not open + num_collision_classes=2, # collisions allowed or not allowed + input_axis=3, # 3D tensors have 3 axes + num_latents=512, # number of latent vectors + im_channels=64, # intermediate channel size + latent_dim=512, # dimensions of latent vectors + cross_heads=1, # number of cross-attention heads + latent_heads=8, # number of latent heads + cross_dim_head=64, + latent_dim_head=64, + activation="relu", + weight_tie_layers=False, + pos_encoding_with_lang=True, + input_dropout=0.1, + attn_dropout=0.1, + decoder_dropout=0.0, + lang_fusion_type="seq", + voxel_patch_size=9, + voxel_patch_stride=8, + no_skip_connection=False, + no_perceiver=False, + no_language=False, + final_dim=64, + ): + super().__init__() + self.depth = depth + self.layer = layer + self.init_dim = int(initial_dim) + self.iterations = iterations + self.input_axis = input_axis + self.voxel_size = voxel_size + self.low_dim_size = low_dim_size + self.im_channels = im_channels + self.pos_encoding_with_lang = pos_encoding_with_lang + self.lang_fusion_type = lang_fusion_type + self.voxel_patch_size = voxel_patch_size + self.voxel_patch_stride = voxel_patch_stride + self.num_rotation_classes = num_rotation_classes + self.num_grip_classes = num_grip_classes + self.num_collision_classes = num_collision_classes + self.final_dim = final_dim + self.input_dropout = input_dropout + self.attn_dropout = attn_dropout + self.decoder_dropout = decoder_dropout + self.no_skip_connection = no_skip_connection + self.no_perceiver = no_perceiver + self.no_language = no_language + + # patchified input dimensions + 
spatial_size = voxel_size // self.voxel_patch_stride # 100/5 = 20 + + # 64 voxel features + 64 proprio features (+ 64 lang goal features if concattenated) + self.input_dim_before_seq = ( + self.im_channels * 3 + if self.lang_fusion_type == "concat" + else self.im_channels * 2 + ) + + # CLIP language feature dimensions + lang_feat_dim, lang_emb_dim, lang_max_seq_len = 1024, 512, 77 + + # learnable positional encoding + if self.pos_encoding_with_lang: + self.pos_encoding = nn.Parameter( + torch.randn( + 1, lang_max_seq_len + spatial_size**3, self.input_dim_before_seq + ) + ) + else: + # assert self.lang_fusion_type == 'concat', 'Only concat is supported for pos encoding without lang.' + self.pos_encoding = nn.Parameter( + torch.randn( + 1, + spatial_size, + spatial_size, + spatial_size, + self.input_dim_before_seq, + ) + ) + + # voxel input preprocessing 1x1 conv encoder + self.input_preprocess = Conv3DBlock( + self.init_dim, + self.im_channels, + kernel_sizes=1, + strides=1, + norm=None, + activation=activation, + ) + + # patchify conv + self.patchify = Conv3DBlock( + self.input_preprocess.out_channels, + self.im_channels, + kernel_sizes=self.voxel_patch_size, + strides=self.voxel_patch_stride, + norm=None, + activation=activation, + ) + + # language preprocess + if self.lang_fusion_type == "concat": + self.lang_preprocess = nn.Linear(lang_feat_dim, self.im_channels) + elif self.lang_fusion_type == "seq": + self.lang_preprocess = nn.Linear(lang_emb_dim, self.im_channels * 2) + + # proprioception + if self.low_dim_size > 0: + self.proprio_preprocess = DenseBlock( + self.low_dim_size, + self.im_channels, + norm=None, + activation=activation, + ) + + # pooling functions + self.local_maxp = nn.MaxPool3d(3, 2, padding=1) + self.global_maxp = nn.AdaptiveMaxPool3d(1) + + # 1st 3D softmax + self.ss0 = SpatialSoftmax3D( + self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels + ) + flat_size = self.im_channels * 4 + + # latent vectors (that are randomly 
initialized) + self.latents = nn.Parameter(torch.randn(num_latents, latent_dim)) + + # encoder cross attention + self.cross_attend_blocks = nn.ModuleList( + [ + PreNorm( + latent_dim, + Attention( + latent_dim, + self.input_dim_before_seq, + heads=cross_heads, + dim_head=cross_dim_head, + dropout=input_dropout, + ), + context_dim=self.input_dim_before_seq, + ), + PreNorm(latent_dim, FeedForward(latent_dim)), + PreNorm(latent_dim, FeedForward(latent_dim)), + ] + ) + + get_latent_attn = lambda: PreNorm( + latent_dim, + Attention( + latent_dim, + heads=latent_heads, + dim_head=latent_dim_head, + dropout=attn_dropout, + ), + ) + get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim)) + get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff)) + + # self attention layers + self.layers = nn.ModuleList([]) + cache_args = {"_cache": weight_tie_layers} + + for i in range(depth): + self.layers.append( + nn.ModuleList( + [ + get_latent_attn(**cache_args), + get_latent_ff(**cache_args), + get_latent_attn(**cache_args), + get_latent_ff(**cache_args), + ] + ) + ) + + self.combined_latent_attn = get_latent_attn(**cache_args) + self.combined_latent_ff = get_latent_ff(**cache_args) + + # decoder cross attention + self.decoder_cross_attn_right = PreNorm( + self.input_dim_before_seq, + Attention( + self.input_dim_before_seq, + latent_dim, + heads=cross_heads, + dim_head=cross_dim_head, + dropout=decoder_dropout, + ), + context_dim=latent_dim, + ) + + self.decoder_cross_attn_left = PreNorm( + self.input_dim_before_seq, + Attention( + self.input_dim_before_seq, + latent_dim, + heads=cross_heads, + dim_head=cross_dim_head, + dropout=decoder_dropout, + ), + context_dim=latent_dim, + ) + + # upsample conv + self.up0 = Conv3DUpsampleBlock( + self.input_dim_before_seq, + self.final_dim, + kernel_sizes=self.voxel_patch_size, + strides=self.voxel_patch_stride, + norm=None, + activation=activation, + ) + + # 2nd 3D softmax + self.ss1 = SpatialSoftmax3D( + 
def encode_text(self, x):
    """Encode tokenised text with the attached CLIP model.

    The encoder is run under ``torch.no_grad()`` and both outputs are
    detached, so no gradient flows back into the language model.

    NOTE(review): ``self._clip_rn50`` is not assigned in the visible part of
    this class; presumably it is attached elsewhere before this is called.

    Args:
        x: token tensor passed straight to
            ``self._clip_rn50.encode_text_with_embeddings``.

    Returns:
        Tuple ``(text_feat, text_emb)`` exactly as produced by the encoder,
        detached from the autograd graph.
    """
    with torch.no_grad():
        text_feat, text_emb = self._clip_rn50.encode_text_with_embeddings(x)

    # The original also built a padding mask from ``x`` here, but it was
    # never used or returned, so that dead computation has been dropped.
    return text_feat.detach(), text_emb.detach()
[B,10,100,100,100] -> [B,64,100,100,100] + + # aggregated features from 1st softmax and maxpool for MLP decoders + feats = [self.ss0(d0.contiguous()), self.global_maxp(d0).view(ins.shape[0], -1)] + + # patchify input (5x5x5 patches) + ins = self.patchify(d0) # [B,64,100,100,100] -> [B,64,20,20,20] + + b, c, d, h, w, device = *ins.shape, ins.device + axis = [d, h, w] + assert ( + len(axis) == self.input_axis + ), "input must have the same number of axis as input_axis" + + # concat proprio + if self.low_dim_size > 0: + p = self.proprio_preprocess(proprio) # [B,4] -> [B,64] + p = p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w) + ins = torch.cat([ins, p], dim=1) # [B,128,20,20,20] + + # language ablation + if self.no_language: + lang_goal_emb = torch.zeros_like(lang_goal_emb) + lang_token_embs = torch.zeros_like(lang_token_embs) + + # option 1: tile and concat lang goal to input + if self.lang_fusion_type == "concat": + lang_emb = lang_goal_emb + lang_emb = lang_emb.to(dtype=ins.dtype) + l = self.lang_preprocess(lang_emb) + l = l.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w) + ins = torch.cat([ins, l], dim=1) + + # channel last + ins = rearrange(ins, "b d ... -> b ... d") # [B,20,20,20,128] + + # add pos encoding to grid + if not self.pos_encoding_with_lang: + ins = ins + self.pos_encoding + + ######################## NOTE ############################# + # NOTE: If you add positional encodings ^here the lang embs + # won't have positional encodings. I accidently forgot + # to turn this off for all the experiments in the paper. + # So I guess those models were using language embs + # as a bag of words :( But it doesn't matter much for + # RLBench tasks since we don't test for novel instructions + # at test time anyway. The recommend way is to add + # positional encodings to the final input sequence + # fed into the Perceiver Transformer, as done below + # (and also in the Colab tutorial). 
+ ########################################################### + + # concat to channels of and flatten axis + queries_orig_shape = ins.shape + + # rearrange input to be channel last + ins = rearrange(ins, "b ... d -> b (...) d") # [B,8000,128] + ins_wo_prev_layers = ins + + # option 2: add lang token embs as a sequence + if self.lang_fusion_type == "seq": + l = self.lang_preprocess(lang_token_embs) # [B,77,1024] -> [B,77,128] + ins = torch.cat((l, ins), dim=1) # [B,8077,128] + + # add pos encoding to language + flattened grid (the recommended way) + if self.pos_encoding_with_lang: + ins = ins + self.pos_encoding + + # batchify latents + x = repeat(self.latents, "n d -> b n d", b=b) + + cross_attn, cross_ff_right, cross_ff_left = self.cross_attend_blocks + + for it in range(self.iterations): + # encoder cross attention + x = cross_attn(x, context=ins, mask=mask) + x + + # x.size() = [1, num_latents, latent_dim] + x_right, x_left = x.chunk(2, dim=1) + + x_right = cross_ff_right(x_right) + x_right + x_left = cross_ff_left(x_left) + x_left + + # self-attention layers + for ( + self_attn_right, + self_ff_right, + self_attn_left, + self_ff_left, + ) in self.layers: + x_right = self_attn_right(x_right) + x_right + x_right = self_ff_right(x_right) + x_right + + x_left = self_attn_left(x_left) + x_left + x_left = self_ff_left(x_left) + x_left + + x = torch.concat([x_right, x_left], dim=1) + x = self.combined_latent_attn(x) + x + x = self.combined_latent_ff(x) + x + + x_right, x_left = x.chunk(2, dim=1) + + # decoder cross attention + latents_right = self.decoder_cross_attn_right(ins, context=x_right) + latents_left = self.decoder_cross_attn_left(ins, context=x_left) + + # crop out the language part of the output sequence + if self.lang_fusion_type == "seq": + latents_right = latents_right[:, l.shape[1] :] + latents_left = latents_left[:, l.shape[1] :] + + # reshape back to voxel grid + latents_right = latents_right.view( + b, *queries_orig_shape[1:-1], 
latents_right.shape[-1] + ) # [B,20,20,20,64] + latents_right = rearrange( + latents_right, "b ... d -> b d ..." + ) # [B,64,20,20,20] + + # reshape back to voxel grid + latents_left = latents_left.view( + b, *queries_orig_shape[1:-1], latents_left.shape[-1] + ) # [B,20,20,20,64] + latents_left = rearrange(latents_left, "b ... d -> b d ...") # [B,64,20,20,20] + + # aggregated features from 2nd softmax and maxpool for MLP decoders + + feats_right = feats.copy() + feats_left = feats + + feats_right.extend( + [ + self.ss1(latents_right.contiguous()), + self.global_maxp(latents_right).view(b, -1), + ] + ) + feats_left.extend( + [ + self.ss1(latents_left.contiguous()), + self.global_maxp(latents_left).view(b, -1), + ] + ) + + # upsample + u0_right = self.up0(latents_right) + u0_left = self.up0(latents_left) + + # ablations + if self.no_skip_connection: + u_right = self.final(u0_right) + u_left = self.final(u0_left) + elif self.no_perceiver: + u_right = self.final(d0) + u_left = self.final(d0) + else: + u_right = self.final(torch.cat([d0, u0_right], dim=1)) + u_left = self.final(torch.cat([d0, u0_left], dim=1)) + + # translation decoder + right_trans = self.right_trans_decoder(u_right) + left_trans = self.left_trans_decoder(u_left) + + # rotation, gripper, and collision MLPs + rot_and_grip_out = None + if self.num_rotation_classes > 0: + feats_right.extend( + [ + self.ss_final(u_right.contiguous()), + self.global_maxp(u_right).view(b, -1), + ] + ) + + right_dense0 = self.right_dense0(torch.cat(feats_right, dim=1)) + right_dense1 = self.right_dense1(right_dense0) # [B,72*3+2+2] + + right_rot_and_grip_collision_out = self.right_rot_grip_collision_ff( + right_dense1 + ) + right_rot_and_grip_out = right_rot_and_grip_collision_out[ + :, : -self.num_collision_classes + ] + right_collision_out = right_rot_and_grip_collision_out[ + :, -self.num_collision_classes : + ] + + feats_left.extend( + [ + self.ss_final(u_left.contiguous()), + self.global_maxp(u_left).view(b, -1), + ] + ) 
class QFunction(nn.Module):
    """Voxelise multi-camera observations and run the Q-network on them.

    The module owns the perceiver encoder (optionally wrapped in
    ``DistributedDataParallel``) and a voxeliser, and offers helpers that
    turn the network's dense outputs into discrete action indices.
    """

    def __init__(
        self,
        perceiver_encoder: nn.Module,
        voxelizer: "VoxelGrid",
        bounds_offset: float,
        rotation_resolution: float,
        device,
        training,
    ):
        """Store the configuration and move the encoder onto ``device``.

        Args:
            perceiver_encoder: network invoked by :meth:`forward`.
            voxelizer: object providing ``coords_to_bounding_voxel_grid``.
            bounds_offset: stored on the instance; not read in this class.
            rotation_resolution: degrees per rotation bin; ``360 //
                rotation_resolution`` is the number of bins per axis.
            device: device the encoder is moved to (also the DDP device id).
            training: when truthy the encoder is wrapped in DDP.
        """
        super().__init__()
        self._rotation_resolution = rotation_resolution
        self._voxelizer = voxelizer
        self._bounds_offset = bounds_offset
        self._qnet = perceiver_encoder.to(device)

        # distributed training
        if training:
            self._qnet = DDP(self._qnet, device_ids=[device])

    def _argmax_3d(self, tensor_orig):
        """Return the (d, h, w) index of the maximum of each volume.

        Args:
            tensor_orig: tensor of shape ``(b, c, d, h, w)``; ``c`` is
                expected to be one.

        Returns:
            Integer tensor of shape ``(b, 3 * c)`` holding the depth, height
            and width index of the per-channel maximum.
        """
        b, c, d, h, w = tensor_orig.shape  # c will be one
        idxs = tensor_orig.reshape(b, c, -1).argmax(-1)
        # Row-major decomposition of the flat index: flat = z*h*w + y*w + x.
        # (The previous formula mixed up the divisors and was only correct
        # for cubic grids with d == h == w; results there are unchanged.)
        indices = torch.cat(
            [idxs // (h * w), (idxs // w) % h, idxs % w],
            1,
        )
        return indices

    def choose_highest_action(self, q_trans, q_rot_grip, q_collision):
        """Pick the arg-max action from the network outputs.

        Args:
            q_trans: translation volume, ``(b, 1, d, h, w)``.
            q_rot_grip: ``(b, 3 * bins + 2)`` logits (three rotation axes
                followed by two gripper logits), or ``None``.
            q_collision: collision logits; only the last two columns are
                used. Ignored when ``q_rot_grip`` is ``None``.

        Returns:
            ``(coords, rot_and_grip_indicies, ignore_collision)``; the last
            two are ``None`` when ``q_rot_grip`` is ``None``.
        """
        coords = self._argmax_3d(q_trans)
        rot_and_grip_indicies = None
        ignore_collision = None
        if q_rot_grip is not None:
            # split the rotation logits into one row per axis: (b, 3, bins)
            q_rot = torch.stack(
                torch.split(
                    q_rot_grip[:, :-2], int(360 // self._rotation_resolution), dim=1
                ),
                dim=1,
            )
            rot_and_grip_indicies = torch.cat(
                [
                    q_rot[:, 0:1].argmax(-1),
                    q_rot[:, 1:2].argmax(-1),
                    q_rot[:, 2:3].argmax(-1),
                    q_rot_grip[:, -2:].argmax(-1, keepdim=True),
                ],
                -1,
            )
            ignore_collision = q_collision[:, -2:].argmax(-1, keepdim=True)
        return coords, rot_and_grip_indicies, ignore_collision

    def forward(
        self,
        rgb_pcd,
        proprio,
        pcd,
        lang_goal_emb,
        lang_token_embs,
        bounds=None,
        prev_bounds=None,
        prev_layer_voxel_grid=None,
    ):
        """Voxelise the observation and evaluate the Q-network.

        Args:
            rgb_pcd: list of ``[rgb, pcd]`` pairs, one per camera; the rgb
                tensors are channel-first (they are permuted with
                ``(0, 2, 3, 1)`` before flattening).
            proprio: proprioception input forwarded to the network.
            pcd: list of channel-first point-cloud tensors with 3 channels.
            lang_goal_emb: sentence-level language embedding.
            lang_token_embs: token-level language embeddings.
            bounds: coordinate bounds handed to the voxeliser and network.
            prev_bounds: forwarded to the network unchanged.
            prev_layer_voxel_grid: forwarded to the network unchanged.

        Returns:
            ``(split_pred, voxel_grid)`` where ``split_pred`` is whatever the
            network returns and ``voxel_grid`` is the detached, channel-first
            voxel grid that was fed to it.
        """
        # rgb_pcd will be list of list (list of [rgb, pcd])
        b = rgb_pcd[0][0].shape[0]
        pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(b, -1, 3) for p in pcd], 1)

        # flatten RGBs and Pointclouds
        rgb = [rp[0] for rp in rgb_pcd]
        feat_size = rgb[0].shape[1]
        flat_imag_features = torch.cat(
            [p.permute(0, 2, 3, 1).reshape(b, -1, feat_size) for p in rgb], 1
        )

        # construct voxel grid
        voxel_grid = self._voxelizer.coords_to_bounding_voxel_grid(
            pcd_flat, coord_features=flat_imag_features, coord_bounds=bounds
        )

        # swap to channels first
        voxel_grid = voxel_grid.permute(0, 4, 1, 2, 3).detach()

        # batch bounds if necessary; ``bounds`` defaults to None, so guard
        # the shape access instead of failing with an AttributeError
        if bounds is not None and bounds.shape[0] != b:
            bounds = bounds.repeat(b, 1)

        # forward pass
        split_pred = self._qnet(
            voxel_grid,
            proprio,
            lang_goal_emb,
            lang_token_embs,
            prev_layer_voxel_grid,
            bounds,
            prev_bounds,
        )

        return split_pred, voxel_grid
self._transform_augmentation = transform_augmentation + self._transform_augmentation_xyz = torch.from_numpy( + np.array(transform_augmentation_xyz) + ) + self._transform_augmentation_rpy = transform_augmentation_rpy + self._transform_augmentation_rot_resolution = ( + transform_augmentation_rot_resolution + ) + self._optimizer_type = optimizer_type + self._num_devices = num_devices + self._num_rotation_classes = num_rotation_classes + self._rotation_resolution = rotation_resolution + + self._cross_entropy_loss = nn.CrossEntropyLoss(reduction="none") + self._name = NAME + "_layer" + str(self._layer) + + def build(self, training: bool, device: torch.device = None): + self._training = training + + if device is None: + device = torch.device("cpu") + + self._device = device + + self._voxelizer = VoxelGrid( + coord_bounds=self._coordinate_bounds, + voxel_size=self._voxel_size, + device=device, + batch_size=self._batch_size if training else 1, + feature_size=self._voxel_feature_size, + max_num_coords=np.prod(self._image_resolution) * self._num_cameras, + ) + + self._q = ( + QFunction( + self._perceiver_encoder, + self._voxelizer, + self._bounds_offset, + self._rotation_resolution, + device, + training, + ) + .to(device) + .train(training) + ) + + grid_for_crop = ( + torch.arange(0, self._image_crop_size, device=device) + .unsqueeze(0) + .repeat(self._image_crop_size, 1) + .unsqueeze(-1) + ) + self._grid_for_crop = torch.cat( + [grid_for_crop.transpose(1, 0), grid_for_crop], dim=2 + ).unsqueeze(0) + + self._coordinate_bounds = torch.tensor( + self._coordinate_bounds, device=device + ).unsqueeze(0) + + if self._training: + # optimizer + if self._optimizer_type == "lamb": + self._optimizer = Lamb( + self._q.parameters(), + lr=self._lr, + weight_decay=self._lambda_weight_l2, + betas=(0.9, 0.999), + adam=False, + ) + elif self._optimizer_type == "adam": + self._optimizer = torch.optim.Adam( + self._q.parameters(), + lr=self._lr, + weight_decay=self._lambda_weight_l2, + ) + 
else: + raise Exception("Unknown optimizer type") + + # learning rate scheduler + if self._lr_scheduler: + self._scheduler = ( + transformers.get_cosine_with_hard_restarts_schedule_with_warmup( + self._optimizer, + num_warmup_steps=self._num_warmup_steps, + num_training_steps=self._training_iterations, + num_cycles=self._training_iterations // 10000, + ) + ) + + # one-hot zero tensors + self._action_trans_one_hot_zeros = torch.zeros( + ( + self._batch_size, + 1, + self._voxel_size, + self._voxel_size, + self._voxel_size, + ), + dtype=int, + device=device, + ) + self._action_rot_x_one_hot_zeros = torch.zeros( + (self._batch_size, self._num_rotation_classes), dtype=int, device=device + ) + self._action_rot_y_one_hot_zeros = torch.zeros( + (self._batch_size, self._num_rotation_classes), dtype=int, device=device + ) + self._action_rot_z_one_hot_zeros = torch.zeros( + (self._batch_size, self._num_rotation_classes), dtype=int, device=device + ) + self._action_grip_one_hot_zeros = torch.zeros( + (self._batch_size, 2), dtype=int, device=device + ) + self._action_ignore_collisions_one_hot_zeros = torch.zeros( + (self._batch_size, 2), dtype=int, device=device + ) + + # print total params + logging.info( + "# Q Params: %d" + % sum( + p.numel() + for name, p in self._q.named_parameters() + if p.requires_grad and "clip" not in name + ) + ) + else: + for param in self._q.parameters(): + param.requires_grad = False + + # load CLIP for encoding language goals during evaluation + model, _ = load_clip("RN50", jit=False) + self._clip_rn50 = build_model(model.state_dict()) + self._clip_rn50 = self._clip_rn50.float().to(device) + self._clip_rn50.eval() + del model + + self._voxelizer.to(device) + self._q.to(device) + + def _extract_crop(self, pixel_action, observation): + # Pixel action will now be (B, 2) + # observation = stack_on_channel(observation) + h = observation.shape[-1] + top_left_corner = torch.clamp( + pixel_action - self._image_crop_size // 2, 0, h - self._image_crop_size 
+ ) + grid = self._grid_for_crop + top_left_corner.unsqueeze(1) + grid = ((grid / float(h)) * 2.0) - 1.0 # between -1 and 1 + # Used for cropping the images across a batch + # swap fro y x, to x, y + grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1) + crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True) + return crop + + def _preprocess_inputs(self, replay_sample): + obs = [] + pcds = [] + self._crop_summary = [] + for n in self._camera_names: + rgb = replay_sample["%s_rgb" % n] + pcd = replay_sample["%s_point_cloud" % n] + + obs.append([rgb, pcd]) + pcds.append(pcd) + return obs, pcds + + def _act_preprocess_inputs(self, observation): + obs, pcds = [], [] + for n in self._camera_names: + rgb = observation["%s_rgb" % n] + pcd = observation["%s_point_cloud" % n] + + obs.append([rgb, pcd]) + pcds.append(pcd) + return obs, pcds + + def _get_value_from_voxel_index(self, q, voxel_idx): + b, c, d, h, w = q.shape + q_trans_flat = q.view(b, c, d * h * w) + flat_indicies = ( + voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2] + )[:, None].int() + highest_idxs = flat_indicies.unsqueeze(-1).repeat(1, c, 1) + chosen_voxel_values = q_trans_flat.gather(2, highest_idxs)[ + ..., 0 + ] # (B, trans + rot + grip) + return chosen_voxel_values + + def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx): + q_rot = torch.stack( + torch.split( + rot_grip_q[:, :-2], int(360 // self._rotation_resolution), dim=1 + ), + dim=1, + ) # B, 3, 72 + q_grip = rot_grip_q[:, -2:] + rot_and_grip_values = torch.cat( + [ + q_rot[:, 0].gather(1, rot_and_grip_idx[:, 0:1]), + q_rot[:, 1].gather(1, rot_and_grip_idx[:, 1:2]), + q_rot[:, 2].gather(1, rot_and_grip_idx[:, 2:3]), + q_grip.gather(1, rot_and_grip_idx[:, 3:4]), + ], + -1, + ) + return rot_and_grip_values + + def _celoss(self, pred, labels): + return self._cross_entropy_loss(pred, labels.argmax(-1)) + + def _softmax_q_trans(self, q): + q_shape = q.shape + return 
F.softmax(q.reshape(q_shape[0], -1), dim=1).reshape(q_shape) + + def _softmax_q_rot_grip(self, q_rot_grip): + q_rot_x_flat = q_rot_grip[ + :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes + ] + q_rot_y_flat = q_rot_grip[ + :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes + ] + q_rot_z_flat = q_rot_grip[ + :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes + ] + q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :] + + q_rot_x_flat_softmax = F.softmax(q_rot_x_flat, dim=1) + q_rot_y_flat_softmax = F.softmax(q_rot_y_flat, dim=1) + q_rot_z_flat_softmax = F.softmax(q_rot_z_flat, dim=1) + q_grip_flat_softmax = F.softmax(q_grip_flat, dim=1) + + return torch.cat( + [ + q_rot_x_flat_softmax, + q_rot_y_flat_softmax, + q_rot_z_flat_softmax, + q_grip_flat_softmax, + ], + dim=1, + ) + + def _softmax_ignore_collision(self, q_collision): + q_collision_softmax = F.softmax(q_collision, dim=1) + return q_collision_softmax + + def update(self, step: int, replay_sample: dict) -> dict: + right_action_trans = replay_sample["right_trans_action_indicies"][ + :, self._layer * 3 : self._layer * 3 + 3 + ].int() + right_action_rot_grip = replay_sample["right_rot_grip_action_indicies"].int() + right_action_gripper_pose = replay_sample["right_gripper_pose"] + right_action_ignore_collisions = replay_sample["right_ignore_collisions"].int() + + left_action_trans = replay_sample["left_trans_action_indicies"][ + :, self._layer * 3 : self._layer * 3 + 3 + ].int() + left_action_rot_grip = replay_sample["left_rot_grip_action_indicies"].int() + left_action_gripper_pose = replay_sample["left_gripper_pose"] + left_action_ignore_collisions = replay_sample["left_ignore_collisions"].int() + + lang_goal_emb = replay_sample["lang_goal_emb"].float() + lang_token_embs = replay_sample["lang_token_embs"].float() + prev_layer_voxel_grid = replay_sample.get("prev_layer_voxel_grid", None) + prev_layer_bounds = replay_sample.get("prev_layer_bounds", None) + 
device = self._device + + bounds = self._coordinate_bounds.to(device) + if self._layer > 0: + right_cp = replay_sample[ + "right_attention_coordinate_layer_%d" % (self._layer - 1) + ] + + left_cp = replay_sample[ + "left_attention_coordinate_layer_%d" % (self._layer - 1) + ] + + right_bounds = torch.cat( + [right_cp - self._bounds_offset, right_cp + self._bounds_offset], dim=1 + ) + left_bounds = torch.cat( + [left_cp - self._bounds_offset, left_cp + self._bounds_offset], dim=1 + ) + + else: + right_bounds = bounds + left_bounds = bounds + + right_proprio = None + left_proprio = None + if self._include_low_dim_state: + right_proprio = replay_sample["right_low_dim_state"] + left_proprio = replay_sample["left_low_dim_state"] + + # ..TODO:: + # Can we add the coordinates of both robots? + # + + obs, pcd = self._preprocess_inputs(replay_sample) + + # batch size + bs = pcd[0].shape[0] + + # We can move the point cloud w.r.t to the other robot's cooridinate system + # similar to apply_se3_augmentation + # + + # SE(3) augmentation of point clouds and actions + if self._transform_augmentation: + from voxel import augmentation + + ( + right_action_trans, + right_action_rot_grip, + left_action_trans, + left_action_rot_grip, + pcd, + ) = augmentation.bimanual_apply_se3_augmentation( + pcd, + right_action_gripper_pose, + right_action_trans, + right_action_rot_grip, + left_action_gripper_pose, + left_action_trans, + left_action_rot_grip, + bounds, + self._layer, + self._transform_augmentation_xyz, + self._transform_augmentation_rpy, + self._transform_augmentation_rot_resolution, + self._voxel_size, + self._rotation_resolution, + self._device, + ) + else: + right_action_trans = right_action_trans.int() + left_action_trans = left_action_trans.int() + + proprio = torch.cat((right_proprio, left_proprio), dim=1) + + right_action = ( + right_action_trans, + right_action_rot_grip, + right_action_ignore_collisions, + ) + left_action = ( + left_action_trans, + left_action_rot_grip, + 
left_action_ignore_collisions, + ) + # forward pass + q, voxel_grid = self._q( + obs, + proprio, + pcd, + lang_goal_emb, + lang_token_embs, + bounds, + prev_layer_bounds, + prev_layer_voxel_grid, + ) + + ( + right_q_trans, + right_q_rot_grip, + right_q_collision, + left_q_trans, + left_q_rot_grip, + left_q_collision, + ) = q + + # argmax to choose best action + ( + right_coords, + right_rot_and_grip_indicies, + right_ignore_collision_indicies, + ) = self._q.choose_highest_action( + right_q_trans, right_q_rot_grip, right_q_collision + ) + + ( + left_coords, + left_rot_and_grip_indicies, + left_ignore_collision_indicies, + ) = self._q.choose_highest_action( + left_q_trans, left_q_rot_grip, left_q_collision + ) + + ( + right_q_trans_loss, + right_q_rot_loss, + right_q_grip_loss, + right_q_collision_loss, + ) = (0.0, 0.0, 0.0, 0.0) + left_q_trans_loss, left_q_rot_loss, left_q_grip_loss, left_q_collision_loss = ( + 0.0, + 0.0, + 0.0, + 0.0, + ) + + # translation one-hot + right_action_trans_one_hot = self._action_trans_one_hot_zeros.clone().detach() + left_action_trans_one_hot = self._action_trans_one_hot_zeros.clone().detach() + for b in range(bs): + right_gt_coord = right_action_trans[b, :].int() + right_action_trans_one_hot[ + b, :, right_gt_coord[0], right_gt_coord[1], right_gt_coord[2] + ] = 1 + left_gt_coord = left_action_trans[b, :].int() + left_action_trans_one_hot[ + b, :, left_gt_coord[0], left_gt_coord[1], left_gt_coord[2] + ] = 1 + + # translation loss + right_q_trans_flat = right_q_trans.view(bs, -1) + right_action_trans_one_hot_flat = right_action_trans_one_hot.view(bs, -1) + right_q_trans_loss = self._celoss( + right_q_trans_flat, right_action_trans_one_hot_flat + ) + left_q_trans_flat = left_q_trans.view(bs, -1) + left_action_trans_one_hot_flat = left_action_trans_one_hot.view(bs, -1) + left_q_trans_loss = self._celoss( + left_q_trans_flat, left_action_trans_one_hot_flat + ) + + q_trans_loss = right_q_trans_loss + left_q_trans_loss + + with_rot_and_grip 
= ( + len(right_rot_and_grip_indicies) > 0 and len(left_rot_and_grip_indicies) > 0 + ) + if with_rot_and_grip: + # rotation, gripper, and collision one-hots + right_action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone() + right_action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone() + right_action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone() + right_action_grip_one_hot = self._action_grip_one_hot_zeros.clone() + right_action_ignore_collisions_one_hot = ( + self._action_ignore_collisions_one_hot_zeros.clone() + ) + + left_action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone() + left_action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone() + left_action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone() + left_action_grip_one_hot = self._action_grip_one_hot_zeros.clone() + left_action_ignore_collisions_one_hot = ( + self._action_ignore_collisions_one_hot_zeros.clone() + ) + + for b in range(bs): + right_gt_rot_grip = right_action_rot_grip[b, :].int() + right_action_rot_x_one_hot[b, right_gt_rot_grip[0]] = 1 + right_action_rot_y_one_hot[b, right_gt_rot_grip[1]] = 1 + right_action_rot_z_one_hot[b, right_gt_rot_grip[2]] = 1 + right_action_grip_one_hot[b, right_gt_rot_grip[3]] = 1 + + right_gt_ignore_collisions = right_action_ignore_collisions[b, :].int() + right_action_ignore_collisions_one_hot[ + b, right_gt_ignore_collisions[0] + ] = 1 + + left_gt_rot_grip = left_action_rot_grip[b, :].int() + left_action_rot_x_one_hot[b, left_gt_rot_grip[0]] = 1 + left_action_rot_y_one_hot[b, left_gt_rot_grip[1]] = 1 + left_action_rot_z_one_hot[b, left_gt_rot_grip[2]] = 1 + left_action_grip_one_hot[b, left_gt_rot_grip[3]] = 1 + + left_gt_ignore_collisions = left_action_ignore_collisions[b, :].int() + left_action_ignore_collisions_one_hot[ + b, left_gt_ignore_collisions[0] + ] = 1 + + # flatten predictions + right_q_rot_x_flat = right_q_rot_grip[ + :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes + ] + 
right_q_rot_y_flat = right_q_rot_grip[ + :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes + ] + right_q_rot_z_flat = right_q_rot_grip[ + :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes + ] + right_q_grip_flat = right_q_rot_grip[:, 3 * self._num_rotation_classes :] + right_q_ignore_collisions_flat = right_q_collision + + left_q_rot_x_flat = left_q_rot_grip[ + :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes + ] + left_q_rot_y_flat = left_q_rot_grip[ + :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes + ] + left_q_rot_z_flat = left_q_rot_grip[ + :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes + ] + left_q_grip_flat = left_q_rot_grip[:, 3 * self._num_rotation_classes :] + left_q_ignore_collisions_flat = left_q_collision + + # rotation loss + right_q_rot_loss += self._celoss( + right_q_rot_x_flat, right_action_rot_x_one_hot + ) + right_q_rot_loss += self._celoss( + right_q_rot_y_flat, right_action_rot_y_one_hot + ) + right_q_rot_loss += self._celoss( + right_q_rot_z_flat, right_action_rot_z_one_hot + ) + + left_q_rot_loss += self._celoss( + left_q_rot_x_flat, left_action_rot_x_one_hot + ) + left_q_rot_loss += self._celoss( + left_q_rot_y_flat, left_action_rot_y_one_hot + ) + left_q_rot_loss += self._celoss( + left_q_rot_z_flat, left_action_rot_z_one_hot + ) + + # gripper loss + right_q_grip_loss += self._celoss( + right_q_grip_flat, right_action_grip_one_hot + ) + left_q_grip_loss += self._celoss(left_q_grip_flat, left_action_grip_one_hot) + + # collision loss + right_q_collision_loss += self._celoss( + right_q_ignore_collisions_flat, right_action_ignore_collisions_one_hot + ) + left_q_collision_loss += self._celoss( + left_q_ignore_collisions_flat, left_action_ignore_collisions_one_hot + ) + + q_trans_loss = right_q_trans_loss + left_q_trans_loss + q_rot_loss = right_q_rot_loss + left_q_rot_loss + q_grip_loss = right_q_grip_loss + left_q_grip_loss + q_collision_loss = 
right_q_collision_loss + left_q_collision_loss + + combined_losses = ( + (q_trans_loss * self._trans_loss_weight) + + (q_rot_loss * self._rot_loss_weight) + + (q_grip_loss * self._grip_loss_weight) + + (q_collision_loss * self._collision_loss_weight) + ) + total_loss = combined_losses.mean() + + self._optimizer.zero_grad() + total_loss.backward() + self._optimizer.step() + + self._summaries = { + "losses/total_loss": total_loss, + "losses/trans_loss": q_trans_loss.mean(), + "losses/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0, + "losses/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0, + "losses/right/trans_loss": q_trans_loss.mean(), + "losses/right/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0, + "losses/right/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0, + "losses/right/collision_loss": q_collision_loss.mean() + if with_rot_and_grip + else 0.0, + "losses/left/trans_loss": q_trans_loss.mean(), + "losses/left/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0, + "losses/left/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0, + "losses/left/collision_loss": q_collision_loss.mean() + if with_rot_and_grip + else 0.0, + "losses/collision_loss": q_collision_loss.mean() + if with_rot_and_grip + else 0.0, + } + + if self._lr_scheduler: + self._scheduler.step() + self._summaries["learning_rate"] = self._scheduler.get_last_lr()[0] + + self._vis_voxel_grid = voxel_grid[0] + self._right_vis_translation_qvalue = self._softmax_q_trans(right_q_trans[0]) + self._right_vis_max_coordinate = right_coords[0] + self._right_vis_gt_coordinate = right_action_trans[0] + + self._left_vis_translation_qvalue = self._softmax_q_trans(left_q_trans[0]) + self._left_vis_max_coordinate = left_coords[0] + self._left_vis_gt_coordinate = left_action_trans[0] + + # Note: PerAct doesn't use multi-layer voxel grids like C2FARM + # stack prev_layer_voxel_grid(s) from previous layers into a list + if prev_layer_voxel_grid is None: 
+ prev_layer_voxel_grid = [voxel_grid] + else: + prev_layer_voxel_grid = prev_layer_voxel_grid + [voxel_grid] + + # stack prev_layer_bound(s) from previous layers into a list + if prev_layer_bounds is None: + prev_layer_bounds = [self._coordinate_bounds.repeat(bs, 1)] + else: + prev_layer_bounds = prev_layer_bounds + [bounds] + + return { + "total_loss": total_loss, + "prev_layer_voxel_grid": prev_layer_voxel_grid, + "prev_layer_bounds": prev_layer_bounds, + } + + def act(self, step: int, observation: dict, deterministic=False) -> ActResult: + deterministic = True + bounds = self._coordinate_bounds + prev_layer_voxel_grid = observation.get("prev_layer_voxel_grid", None) + prev_layer_bounds = observation.get("prev_layer_bounds", None) + lang_goal_tokens = observation.get("lang_goal_tokens", None).long() + + # extract CLIP language embs + with torch.no_grad(): + lang_goal_tokens = lang_goal_tokens.to(device=self._device) + ( + lang_goal_emb, + lang_token_embs, + ) = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0]) + + # voxelization resolution + res = (bounds[:, 3:] - bounds[:, :3]) / self._voxel_size + max_rot_index = int(360 // self._rotation_resolution) + right_proprio = None + left_proprio = None + + if self._include_low_dim_state: + right_proprio = observation["right_low_dim_state"] + left_proprio = observation["left_low_dim_state"] + right_proprio = right_proprio[0].to(self._device) + left_proprio = left_proprio[0].to(self._device) + + obs, pcd = self._act_preprocess_inputs(observation) + + # correct batch size and device + obs = [[o[0][0].to(self._device), o[1][0].to(self._device)] for o in obs] + + pcd = [p[0].to(self._device) for p in pcd] + lang_goal_emb = lang_goal_emb.to(self._device) + lang_token_embs = lang_token_embs.to(self._device) + bounds = torch.as_tensor(bounds, device=self._device) + prev_layer_voxel_grid = ( + prev_layer_voxel_grid.to(self._device) + if prev_layer_voxel_grid is not None + else None + ) + prev_layer_bounds = ( + 
prev_layer_bounds.to(self._device) + if prev_layer_bounds is not None + else None + ) + + proprio = torch.cat((right_proprio, left_proprio), dim=1) + + # inference + ( + right_q_trans, + right_q_rot_grip, + right_q_ignore_collisions, + left_q_trans, + left_q_rot_grip, + left_q_ignore_collisions, + ), vox_grid = self._q( + obs, + proprio, + pcd, + lang_goal_emb, + lang_token_embs, + bounds, + prev_layer_bounds, + prev_layer_voxel_grid, + ) + + # softmax Q predictions + right_q_trans = self._softmax_q_trans(right_q_trans) + left_q_trans = self._softmax_q_trans(left_q_trans) + + if right_q_rot_grip is not None: + right_q_rot_grip = self._softmax_q_rot_grip(right_q_rot_grip) + + if left_q_rot_grip is not None: + left_q_rot_grip = self._softmax_q_rot_grip(left_q_rot_grip) + + if right_q_ignore_collisions is not None: + right_q_ignore_collisions = self._softmax_ignore_collision( + right_q_ignore_collisions + ) + + if left_q_ignore_collisions is not None: + left_q_ignore_collisions = self._softmax_ignore_collision( + left_q_ignore_collisions + ) + + # argmax Q predictions + ( + right_coords, + right_rot_and_grip_indicies, + right_ignore_collisions, + ) = self._q.choose_highest_action( + right_q_trans, right_q_rot_grip, right_q_ignore_collisions + ) + ( + left_coords, + left_rot_and_grip_indicies, + left_ignore_collisions, + ) = self._q.choose_highest_action( + left_q_trans, left_q_rot_grip, left_q_ignore_collisions + ) + + if right_q_rot_grip is not None: + right_rot_grip_action = right_rot_and_grip_indicies + if right_q_ignore_collisions is not None: + right_ignore_collisions_action = right_ignore_collisions.int() + + if left_q_rot_grip is not None: + left_rot_grip_action = left_rot_and_grip_indicies + if left_q_ignore_collisions is not None: + left_ignore_collisions_action = left_ignore_collisions.int() + + right_coords = right_coords.int() + left_coords = left_coords.int() + + right_attention_coordinate = bounds[:, :3] + res * right_coords + res / 2 + 
left_attention_coordinate = bounds[:, :3] + res * left_coords + res / 2 + + # stack prev_layer_voxel_grid(s) into a list + # NOTE: PerAct doesn't used multi-layer voxel grids like C2FARM + if prev_layer_voxel_grid is None: + prev_layer_voxel_grid = [vox_grid] + else: + prev_layer_voxel_grid = prev_layer_voxel_grid + [vox_grid] + + if prev_layer_bounds is None: + prev_layer_bounds = [bounds] + else: + prev_layer_bounds = prev_layer_bounds + [bounds] + + observation_elements = { + "right_attention_coordinate": right_attention_coordinate, + "left_attention_coordinate": left_attention_coordinate, + "prev_layer_voxel_grid": prev_layer_voxel_grid, + "prev_layer_bounds": prev_layer_bounds, + } + info = { + "voxel_grid_depth%d" % self._layer: vox_grid, + "right_q_depth%d" % self._layer: right_q_trans, + "right_voxel_idx_depth%d" % self._layer: right_coords, + "left_q_depth%d" % self._layer: left_q_trans, + "left_voxel_idx_depth%d" % self._layer: left_coords, + } + self._act_voxel_grid = vox_grid[0] + self._right_act_max_coordinate = right_coords[0] + self._right_act_qvalues = right_q_trans[0].detach() + self._left_act_max_coordinate = left_coords[0] + self._left_act_qvalues = left_q_trans[0].detach() + + action = ( + right_coords, + right_rot_grip_action, + right_ignore_collisions, + left_coords, + left_rot_grip_action, + left_ignore_collisions, + ) + + return ActResult(action, observation_elements=observation_elements, info=info) + + def update_summaries(self) -> List[Summary]: + voxel_grid = self._vis_voxel_grid.detach().cpu().numpy() + summaries = [] + summaries.append( + ImageSummary( + "%s/right_update_qattention" % self._name, + transforms.ToTensor()( + visualise_voxel( + voxel_grid, + self._right_vis_translation_qvalue.detach().cpu().numpy(), + self._right_vis_max_coordinate.detach().cpu().numpy(), + self._right_vis_gt_coordinate.detach().cpu().numpy(), + ) + ), + ) + ) + summaries.append( + ImageSummary( + "%s/left_update_qattention" % self._name, + 
transforms.ToTensor()( + visualise_voxel( + voxel_grid, + self._left_vis_translation_qvalue.detach().cpu().numpy(), + self._left_vis_max_coordinate.detach().cpu().numpy(), + self._left_vis_gt_coordinate.detach().cpu().numpy(), + ) + ), + ) + ) + for n, v in self._summaries.items(): + summaries.append(ScalarSummary("%s/%s" % (self._name, n), v)) + + for name, crop in self._crop_summary: + crops = (torch.cat(torch.split(crop, 3, dim=1), dim=3) + 1.0) / 2.0 + summaries.extend([ImageSummary("%s/crops/%s" % (self._name, name), crops)]) + + for tag, param in self._q.named_parameters(): + # assert not torch.isnan(param.grad.abs() <= 1.0).all() + summaries.append( + HistogramSummary("%s/gradient/%s" % (self._name, tag), param.grad) + ) + summaries.append( + HistogramSummary("%s/weight/%s" % (self._name, tag), param.data) + ) + + return summaries + + def act_summaries(self) -> List[Summary]: + voxel_grid = self._act_voxel_grid.cpu().numpy() + right_q_attention = self._right_act_qvalues.cpu().numpy() + right_highlight_coordinate = self._right_act_max_coordinate.cpu().numpy() + right_visualization = visualise_voxel( + voxel_grid, right_q_attention, right_highlight_coordinate + ) + + left_q_attention = self._left_act_qvalues.cpu().numpy() + left_highlight_coordinate = self._left_act_max_coordinate.cpu().numpy() + left_visualization = visualise_voxel( + voxel_grid, left_q_attention, left_highlight_coordinate + ) + + return [ + ImageSummary( + f"{self._name}/right_act_Qattention", + transforms.ToTensor()(right_visualization), + ), + ImageSummary( + f"{self._name}/left_act_Qattention", + transforms.ToTensor()(left_visualization), + ), + ] + + def load_weights(self, savedir: str): + device = ( + self._device + if not self._training + else torch.device("cuda:%d" % self._device) + ) + weight_file = os.path.join(savedir, "%s.pt" % self._name) + state_dict = torch.load(weight_file, map_location=device) + + # load only keys that are in the current model + merged_state_dict = 
class QAttentionStackAgent(Agent):
    """Runs a stack of bimanual Q-attention agents as one agent.

    Each layer's ``act`` output is written back into the observation so the
    next layer can read it; the per-layer results are then merged into a
    single continuous action for the right and left arm.
    """

    def __init__(
        self,
        qattention_agents: List[QAttentionPerActBCAgent],
        rotation_resolution: float,
        camera_names: List[str],
        rotation_prediction_depth: int = 0,
    ):
        """Store the layer agents and the settings used to decode actions.

        Args:
            qattention_agents: Layer agents, evaluated in list order.
            rotation_resolution: Passed to ``discrete_euler_to_quaternion``
                when decoding rotation bins.
            camera_names: Cameras for which pixel coordinates of the attention
                point are computed in ``act``.
            rotation_prediction_depth: Stored on the instance; not read by any
                method of this class.
        """
        super().__init__()
        self._qattention_agents = qattention_agents
        self._rotation_resolution = rotation_resolution
        self._camera_names = camera_names
        self._rotation_prediction_depth = rotation_prediction_depth

    def build(self, training: bool, device=None) -> None:
        """Build every layer agent; this agent's own device defaults to CPU."""
        self._device = device
        if self._device is None:
            self._device = torch.device("cpu")
        for qa in self._qattention_agents:
            # The caller's ``device`` (possibly None) is forwarded unchanged.
            qa.build(training, device)

    def update(self, step: int, replay_sample: dict) -> dict:
        """Update each layer in order and return the summed loss.

        Every layer's result dict is merged into ``replay_sample`` (mutating
        it) so later layers can see what earlier layers produced.

        Returns:
            ``{"total_losses": <sum of each layer's "total_loss">}``.
        """
        total_losses = 0.0
        for qa in self._qattention_agents:
            update_dict = qa.update(step, replay_sample)
            replay_sample.update(update_dict)
            total_losses += update_dict["total_loss"]
        return {
            "total_losses": total_losses,
        }

    def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
        """Query every layer and assemble the bimanual continuous action.

        ``observation`` is mutated: after each layer its attention
        coordinates, previous-layer voxel grid/bounds and per-camera pixel
        coordinates are written back for the next layer.

        Returns:
            ActResult whose action concatenates, for the right arm and then
            the left arm: the last layer's attention coordinate, the
            quaternion decoded from the last rotation bins, the gripper bin
            and the ignore-collisions flag(s).
        """
        observation_elements = {}
        right_translation_results = []
        right_rot_grip_results = []
        right_ignore_collisions_results = []
        left_translation_results = []
        left_rot_grip_results = []
        left_ignore_collisions_results = []

        infos = {}
        for depth, qagent in enumerate(self._qattention_agents):
            act_results = qagent.act(step, observation, deterministic)
            layer_elements = act_results.observation_elements

            right_attention_coordinate = (
                layer_elements["right_attention_coordinate"].cpu().numpy()
            )
            left_attention_coordinate = (
                layer_elements["left_attention_coordinate"].cpu().numpy()
            )
            observation_elements[
                "right_attention_coordinate_layer_%d" % depth
            ] = right_attention_coordinate[0]
            observation_elements[
                "left_attention_coordinate_layer_%d" % depth
            ] = left_attention_coordinate[0]

            (
                right_translation_idxs,
                right_rot_grip_idxs,
                right_ignore_collisions_idxs,
                left_translation_idxs,
                left_rot_grip_idxs,
                left_ignore_collisions_idxs,
            ) = act_results.action

            right_translation_results.append(right_translation_idxs)
            if right_rot_grip_idxs is not None:
                right_rot_grip_results.append(right_rot_grip_idxs)
            if right_ignore_collisions_idxs is not None:
                right_ignore_collisions_results.append(right_ignore_collisions_idxs)

            left_translation_results.append(left_translation_idxs)
            if left_rot_grip_idxs is not None:
                left_rot_grip_results.append(left_rot_grip_idxs)
            if left_ignore_collisions_idxs is not None:
                left_ignore_collisions_results.append(left_ignore_collisions_idxs)

            # Feed this layer's outputs to the next layer.
            observation["right_attention_coordinate"] = layer_elements[
                "right_attention_coordinate"
            ]
            observation["left_attention_coordinate"] = layer_elements[
                "left_attention_coordinate"
            ]
            observation["prev_layer_voxel_grid"] = layer_elements[
                "prev_layer_voxel_grid"
            ]
            observation["prev_layer_bounds"] = layer_elements["prev_layer_bounds"]

            for n in self._camera_names:
                extrinsics = observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy()
                intrinsics = observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy()

                px, py = utils.point_to_pixel_index(
                    right_attention_coordinate[0], extrinsics, intrinsics
                )
                observation[f"right_{n}_pixel_coord"] = torch.tensor(
                    [[[py, px]]], dtype=torch.float32, device=self._device
                )
                observation_elements[f"right_{n}_pixel_coord"] = [py, px]

                px, py = utils.point_to_pixel_index(
                    left_attention_coordinate[0], extrinsics, intrinsics
                )
                observation[f"left_{n}_pixel_coord"] = torch.tensor(
                    [[[py, px]]], dtype=torch.float32, device=self._device
                )
                observation_elements[f"left_{n}_pixel_coord"] = [py, px]

            infos.update(act_results.info)

        right_rgai = torch.cat(right_rot_grip_results, 1)[0].cpu().numpy()
        # ..todo:: utils.correct_rotation_instability does nothing so we can ignore it
        right_ignore_collisions = (
            torch.cat(right_ignore_collisions_results, 1)[0].cpu().numpy()
        )
        right_trans_action_indicies = (
            torch.cat(right_translation_results, 1)[0].cpu().numpy()
        )

        observation_elements[
            "right_trans_action_indicies"
        ] = right_trans_action_indicies[:3]
        observation_elements["right_rot_grip_action_indicies"] = right_rgai[:4]

        left_rgai = torch.cat(left_rot_grip_results, 1)[0].cpu().numpy()
        left_ignore_collisions = (
            torch.cat(left_ignore_collisions_results, 1)[0].cpu().numpy()
        )
        left_trans_action_indicies = (
            torch.cat(left_translation_results, 1)[0].cpu().numpy()
        )

        # The left arrays are built from left-arm results only, so they are
        # sliced exactly like the right-arm ones. The previous ``[3:]`` /
        # ``[4:]`` slices dropped the first layer's indices, which leaves an
        # empty array for a single-layer stack.
        observation_elements["left_trans_action_indicies"] = left_trans_action_indicies[
            :3
        ]
        observation_elements["left_rot_grip_action_indicies"] = left_rgai[:4]

        # The trailing four rot/grip entries are read as three rotation bins
        # followed by one gripper bin.
        continuous_action = np.concatenate(
            [
                right_attention_coordinate[0],
                utils.discrete_euler_to_quaternion(
                    right_rgai[-4:-1], self._rotation_resolution
                ),
                right_rgai[-1:],
                right_ignore_collisions,
                left_attention_coordinate[0],
                utils.discrete_euler_to_quaternion(
                    left_rgai[-4:-1], self._rotation_resolution
                ),
                left_rgai[-1:],
                left_ignore_collisions,
            ]
        )
        return ActResult(
            continuous_action, observation_elements=observation_elements, info=infos
        )

    def update_summaries(self) -> List[Summary]:
        """Collect the training summaries of every layer, in layer order."""
        summaries = []
        for qa in self._qattention_agents:
            summaries.extend(qa.update_summaries())
        return summaries

    def act_summaries(self) -> List[Summary]:
        """Collect the acting summaries of every layer, in layer order."""
        summaries = []
        for qa in self._qattention_agents:
            summaries.extend(qa.act_summaries())
        return summaries

    def load_weights(self, savedir: str):
        """Ask every layer to load its weights from ``savedir``."""
        for qa in self._qattention_agents:
            qa.load_weights(savedir)

    def save_weights(self, savedir: str):
        """Ask every layer to save its weights into ``savedir``."""
        for qa in self._qattention_agents:
            qa.save_weights(savedir)
def create_replay(
    batch_size: int,
    timesteps: int,
    prioritisation: bool,
    task_uniform: bool,
    save_dir: str,
    cameras: list,
    voxel_sizes,
    image_size=(128, 128),
    replay_size=3e5,
):
    """Create the replay buffer used to store demonstration keyframes.

    Args:
        batch_size: Batch size handed to the buffer.
        timesteps: Number of timesteps handed to the buffer.
        prioritisation: Not read by this function.
        task_uniform: Not read by this function; a ``TaskUniformReplayBuffer``
            is always created.
        save_dir: Directory handed to the buffer as ``save_dir``.
        cameras: Camera names; five per-camera elements are registered for
            each (rgb, point cloud, extrinsics, intrinsics, pixel coord).
        voxel_sizes: One entry per voxelization layer; only its length is
            used, to size the translation indices and to add one attention
            coordinate element per layer.
        image_size: ``(height, width)`` of the stored rgb / point-cloud
            images. Any two-element sequence works.
        replay_size: Buffer capacity; truncated with ``int()``.

    Returns:
        The configured ``TaskUniformReplayBuffer``.
    """
    trans_indicies_size = 3 * len(voxel_sizes)
    rot_and_grip_indicies_size = 3 + 1
    gripper_pose_size = 7
    ignore_collisions_size = 1
    max_token_seq_len = 77
    lang_feat_dim = 1024
    lang_emb_dim = 512

    # low_dim_state
    observation_elements = [
        ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
    ]

    # rgb, depth, point cloud, intrinsics, extrinsics
    for cname in cameras:
        observation_elements.append(
            ObservationElement("%s_rgb" % cname, (3, *image_size), np.float32)
        )
        # see pyrep/objects/vision_sensor.py on how pointclouds are extracted
        # from depth frames
        observation_elements.append(
            ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32)
        )
        observation_elements.append(
            ObservationElement("%s_camera_extrinsics" % cname, (4, 4), np.float32)
        )
        observation_elements.append(
            ObservationElement("%s_camera_intrinsics" % cname, (3, 3), np.float32)
        )
        observation_elements.append(
            ObservationElement("%s_pixel_coord" % cname, (2,), np.int32)
        )

    # discretized translation, discretized rotation, discrete ignore collision,
    # 6-DoF gripper pose, and pre-trained language embeddings
    observation_elements.extend(
        [
            ReplayElement("trans_action_indicies", (trans_indicies_size,), np.int32),
            ReplayElement(
                "rot_grip_action_indicies", (rot_and_grip_indicies_size,), np.int32
            ),
            ReplayElement("ignore_collisions", (ignore_collisions_size,), np.int32),
            ReplayElement("gripper_pose", (gripper_pose_size,), np.float32),
            ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
            # extracted from CLIP's language encoder
            ReplayElement(
                "lang_token_embs", (max_token_seq_len, lang_emb_dim), np.float32
            ),
            ReplayElement("task", (), str),
            # language goal string for debugging and visualization
            ReplayElement("lang_goal", (1,), object),
        ]
    )

    for depth in range(len(voxel_sizes)):
        observation_elements.append(
            ReplayElement("attention_coordinate_layer_%d" % depth, (3,), np.float32)
        )

    extra_replay_elements = [
        # ``np.bool`` was only an alias of the builtin and no longer exists in
        # NumPy >= 1.24; the builtin gives the same dtype on every version.
        ReplayElement("demo", (), bool),
    ]

    replay_buffer = TaskUniformReplayBuffer(
        save_dir=save_dir,
        batch_size=batch_size,
        timesteps=timesteps,
        replay_capacity=int(replay_size),
        action_shape=(8,),
        action_dtype=np.float32,
        reward_shape=(),
        reward_dtype=np.float32,
        update_horizon=1,
        observation_elements=observation_elements,
        extra_replay_elements=extra_replay_elements,
    )
    return replay_buffer
def _add_keypoints_to_replay(
    cfg: DictConfig,
    task: str,
    replay: ReplayBuffer,
    inital_obs: Observation,
    demo: Demo,
    episode_keypoints: List[int],
    cameras: List[str],
    rlbench_scene_bounds: List[float],
    voxel_sizes: List[int],
    bounds_offset: List[float],
    rotation_resolution: int,
    crop_augmentation: bool,
    description: str = "",
    clip_model=None,
    device="cpu",
):
    """Add one transition per keypoint of a demo, then the final observation.

    Starting from ``inital_obs``, each keypoint contributes a transition whose
    action is derived from the keypoint observation (``_get_action``); the
    observation after the last keypoint is stored with ``replay.add_final``.
    Only the last transition is terminal and carries ``REWARD_SCALE`` reward.

    Args:
        cfg: Config; reads ``cfg.rlbench.episode_length`` and
            ``cfg.method.robot_name``.
        task: Task name stored with every transition.
        replay: Buffer receiving the transitions.
        inital_obs: Observation the first transition starts from.
        demo: Demo indexed by the keypoint indices.
        episode_keypoints: Indices into ``demo``; nothing is added when empty.
        cameras: Camera names for observation extraction / pixel coordinates.
        rlbench_scene_bounds, voxel_sizes, bounds_offset, rotation_resolution,
        crop_augmentation: Forwarded to ``_get_action``.
        description: Language goal, encoded with ``clip_model``.
        clip_model: Object exposing ``encode_text_with_embeddings``; required
            whenever there is at least one keypoint.
        device: Device the token tensor is moved to before encoding.
    """
    # Nothing to store; the code below needs at least one keypoint.
    if len(episode_keypoints) == 0:
        return

    # The description is the same for every keypoint, so encode it once
    # instead of re-running the text encoder per transition.
    with torch.no_grad():
        tokens = tokenize([description]).numpy()
        token_tensor = torch.from_numpy(tokens).to(device)
        sentence_emb, token_embs = clip_model.encode_text_with_embeddings(
            token_tensor
        )
    lang_goal_emb = sentence_emb[0].float().detach().cpu().numpy()
    lang_token_embs = token_embs[0].float().detach().cpu().numpy()

    prev_action = None
    obs = inital_obs
    last_index = len(episode_keypoints) - 1
    for k, keypoint in enumerate(episode_keypoints):
        obs_tp1 = demo[keypoint]
        obs_tm1 = demo[max(0, keypoint - 1)]
        (
            trans_indicies,
            rot_grip_indicies,
            _ignore_collisions,
            action,
            attention_coordinates,
        ) = _get_action(
            obs_tp1,
            obs_tm1,
            rlbench_scene_bounds,
            voxel_sizes,
            bounds_offset,
            rotation_resolution,
            crop_augmentation,
        )

        terminal = k == last_index
        reward = REWARD_SCALE if terminal else 0

        obs_dict = observation_utils.extract_obs(
            obs,
            t=k,
            prev_action=prev_action,
            cameras=cameras,
            episode_length=cfg.rlbench.episode_length,
            robot_name=cfg.method.robot_name,
        )
        obs_dict["lang_goal_emb"] = lang_goal_emb
        obs_dict["lang_token_embs"] = lang_token_embs

        prev_action = np.copy(action)

        others = {"demo": True}
        final_obs = {
            "trans_action_indicies": trans_indicies,
            "rot_grip_action_indicies": rot_grip_indicies,
            "gripper_pose": obs_tp1.gripper_pose,
            "task": task,
            "lang_goal": np.array([description], dtype=object),
        }

        for depth in range(len(voxel_sizes)):
            final_obs["attention_coordinate_layer_%d" % depth] = attention_coordinates[
                depth
            ]
        for name in cameras:
            px, py = utils.point_to_pixel_index(
                obs_tp1.gripper_pose[:3],
                obs_tp1.misc["%s_camera_extrinsics" % name],
                obs_tp1.misc["%s_camera_intrinsics" % name],
            )
            final_obs["%s_pixel_coord" % name] = [py, px]

        others.update(final_obs)
        others.update(obs_dict)

        timeout = False
        replay.add(action, reward, terminal, timeout, **others)
        obs = obs_tp1

    # final step: the observation reached after the last keypoint, labelled
    # with the last transition's targets
    obs_dict_tp1 = observation_utils.extract_obs(
        obs_tp1,
        t=k + 1,
        prev_action=prev_action,
        cameras=cameras,
        episode_length=cfg.rlbench.episode_length,
        robot_name=cfg.method.robot_name,
    )
    obs_dict_tp1["lang_goal_emb"] = lang_goal_emb
    obs_dict_tp1["lang_token_embs"] = lang_token_embs

    obs_dict_tp1.pop("wrist_world_to_cam", None)
    obs_dict_tp1.update(final_obs)
    replay.add_final(**obs_dict_tp1)
% task) + for d_idx in range(num_demos): + # load demo from disk + demo = rlbench_utils.get_stored_demos( + amount=1, + image_paths=False, + dataset_root=cfg.rlbench.demo_path, + variation_number=-1, + task_name=task, + obs_config=obs_config, + random_selection=False, + from_episode_number=d_idx, + )[0] + + descs = demo._observations[0].misc["descriptions"] + + # extract keypoints (a.k.a keyframes) + episode_keypoints = demo_loading_utils.keypoint_discovery( + demo, method=keypoint_method + ) + + if rank == 0: + logging.info( + f"Loading Demo({d_idx}) - found {len(episode_keypoints)} keypoints - {task}" + ) + + for i in range(len(demo) - 1): + if not demo_augmentation and i > 0: + break + if i % demo_augmentation_every_n != 0: + continue + + obs = demo[i] + desc = descs[0] + # if our starting point is past one of the keypoints, then remove it + while len(episode_keypoints) > 0 and i >= episode_keypoints[0]: + episode_keypoints = episode_keypoints[1:] + if len(episode_keypoints) == 0: + break + _add_keypoints_to_replay( + cfg, + task, + replay, + obs, + demo, + episode_keypoints, + cameras, + rlbench_scene_bounds, + voxel_sizes, + bounds_offset, + rotation_resolution, + crop_augmentation, + description=desc, + clip_model=clip_model, + device=device, + ) + logging.debug("Replay %s filled with demos." 
def create_agent(cfg: DictConfig):
    """Assemble the C2FARM-LingUNet behaviour-cloning agent from ``cfg``.

    One ``QattentionLingU3DNet`` / ``QAttentionLingUNetBCAgent`` pair is
    created per entry of ``cfg.method.voxel_sizes``. Only the last layer gets
    a dense output head (three rotation axes plus four extra outputs). The
    layer agents are wrapped in a ``QAttentionStackAgent`` and then in a
    ``PreprocessAgent``, which is returned.
    """
    LATENT_SIZE = 64
    scene_bounds = cfg.rlbench.scene_bounds
    camera_resolution = cfg.rlbench.camera_resolution
    rotation_classes = int(360.0 // cfg.method.rotation_resolution)

    layer_agents = []
    for layer, voxel_size in enumerate(cfg.method.voxel_sizes):
        is_last_layer = layer == len(cfg.method.voxel_sizes) - 1
        dense_outputs = ((rotation_classes * 3) + 4) if is_last_layer else 0
        norm = None if "None" in cfg.method.norm else cfg.method.norm
        layer_offset = cfg.method.bounds_offset[layer - 1] if layer > 0 else None

        network = QattentionLingU3DNet(
            in_channels=3 + 3 + 1 + 3,
            out_channels=1,
            voxel_size=voxel_size,
            out_dense=dense_outputs,
            kernels=LATENT_SIZE,
            norm=norm,
            dense_feats=128,
            activation=cfg.method.activation,
            low_dim_size=4,
            include_prev_layer=cfg.method.include_prev_layer and layer > 0,
            depth=layer,
        )

        layer_agents.append(
            QAttentionLingUNetBCAgent(
                layer=layer,
                coordinate_bounds=scene_bounds,
                unet3d=network,
                camera_names=cfg.rlbench.cameras,
                batch_size=cfg.replay.batch_size,
                voxel_size=voxel_size,
                bounds_offset=layer_offset,
                voxel_feature_size=3,
                image_crop_size=cfg.method.image_crop_size,
                lr=cfg.method.lr,
                training_iterations=cfg.framework.training_iterations,
                lr_scheduler=cfg.method.lr_scheduler,
                num_warmup_steps=cfg.method.num_warmup_steps,
                trans_loss_weight=cfg.method.trans_loss_weight,
                rot_loss_weight=cfg.method.rot_loss_weight,
                grip_loss_weight=cfg.method.grip_loss_weight,
                collision_loss_weight=cfg.method.collision_loss_weight,
                include_low_dim_state=True,
                image_resolution=camera_resolution,
                lambda_weight_l2=cfg.method.lambda_weight_l2,
                num_rotation_classes=rotation_classes,
                rotation_resolution=cfg.method.rotation_resolution,
                transform_augmentation=cfg.method.transform_augmentation.apply_se3,
                transform_augmentation_xyz=cfg.method.transform_augmentation.aug_xyz,
                transform_augmentation_rpy=cfg.method.transform_augmentation.aug_rpy,
                transform_augmentation_rot_resolution=cfg.method.transform_augmentation.aug_rot_resolution,
                num_devices=cfg.ddp.num_devices,
            )
        )

    stack_agent = QAttentionStackAgent(
        qattention_agents=layer_agents,
        rotation_resolution=cfg.method.rotation_resolution,
        camera_names=cfg.rlbench.cameras,
    )
    return PreprocessAgent(pose_agent=stack_agent)
a/external/peract_bimanual/agents/c2farm_lingunet_bc/networks.py b/external/peract_bimanual/agents/c2farm_lingunet_bc/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..dd286f9d74ac85e4975259a72781e1f49bf4010e --- /dev/null +++ b/external/peract_bimanual/agents/c2farm_lingunet_bc/networks.py @@ -0,0 +1,301 @@ +import torch +import torch.nn as nn + +from helpers.network_utils import ( + Conv3DInceptionBlock, + DenseBlock, + SpatialSoftmax3D, + Conv3DInceptionBlockUpsampleBlock, + Conv3DBlock, +) + + +class QattentionLingU3DNet(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + out_dense: int, + voxel_size: int, + low_dim_size: int, + kernels: int, + norm: str = None, + activation: str = "relu", + dense_feats: int = 32, + include_prev_layer=False, + depth=0, + lingunet_dropout=0.0, + ): + super(QattentionLingU3DNet, self).__init__() + self._in_channels = in_channels + self._out_channels = out_channels + self._norm = norm + self._activation = activation + self._kernels = kernels + self._low_dim_size = low_dim_size + self._build_calls = 0 + self._voxel_size = voxel_size + self._dense_feats = dense_feats + self._out_dense = out_dense + self._include_prev_layer = include_prev_layer + self._depth = depth + + self._lingunet_dropout = lingunet_dropout + self._clip_lang_feat_dim = 1024 + + if self._voxel_size < 16: + raise Exception( + "Voxel size for C2FARM_LINGUNET_BC should be at least 16 or higher" + ) + + def build(self): + use_residual = False + self._build_calls += 1 + if self._build_calls != 1: + raise RuntimeError("Build needs to be called once.") + + spatial_size = self._voxel_size + self._input_preprocess = Conv3DInceptionBlock( + self._in_channels, + self._kernels, + norm=self._norm, + activation=self._activation, + ) + + d0_ins = self._input_preprocess.out_channels + if self._include_prev_layer: + PREV_VOXEL_CHANNELS = 0 + d0_ins += self._input_preprocess.out_channels * self._depth + + if 
self._low_dim_size > 0: + self._proprio_preprocess = DenseBlock( + self._low_dim_size, self._kernels, None, self._activation + ) + d0_ins += self._kernels + + self._down0 = Conv3DInceptionBlock( + d0_ins, + self._kernels, + norm=self._norm, + activation=self._activation, + residual=use_residual, + ) + self._ss0 = SpatialSoftmax3D( + spatial_size, spatial_size, spatial_size, self._down0.out_channels + ) + spatial_size //= 2 + self._down1 = Conv3DInceptionBlock( + self._down0.out_channels, + self._kernels * 2, + norm=self._norm, + activation=self._activation, + residual=use_residual, + ) + self._ss1 = SpatialSoftmax3D( + spatial_size, spatial_size, spatial_size, self._down1.out_channels + ) + spatial_size //= 2 + + flat_size = self._down0.out_channels * 4 + self._down1.out_channels * 4 + + k1 = self._down1.out_channels + if self._voxel_size > 8: + k1 += self._kernels + self._down2 = Conv3DInceptionBlock( + self._down1.out_channels, + self._kernels * 4, + norm=self._norm, + activation=self._activation, + residual=use_residual, + ) + self._lang_proj2 = DenseBlock( + self._clip_lang_feat_dim, self._down2.out_channels, None, None + ) + self._dropout2 = nn.Dropout(self._lingunet_dropout) + flat_size += self._down2.out_channels * 4 + self._ss2 = SpatialSoftmax3D( + spatial_size, spatial_size, spatial_size, self._down2.out_channels + ) + spatial_size //= 2 + k2 = self._down2.out_channels + if self._voxel_size > 16: + k2 *= 2 + self._down3 = Conv3DInceptionBlock( + self._down2.out_channels, + self._kernels, + norm=self._norm, + activation=self._activation, + residual=use_residual, + ) + self._lang_proj3 = DenseBlock( + self._clip_lang_feat_dim, self._down3.out_channels, None, None + ) + self._dropout3 = nn.Dropout(self._lingunet_dropout) + flat_size += self._down3.out_channels * 4 + self._ss3 = SpatialSoftmax3D( + spatial_size, spatial_size, spatial_size, self._down3.out_channels + ) + self._up3 = Conv3DInceptionBlockUpsampleBlock( + self._kernels, + self._kernels * 4, + 2, 
+ norm=self._norm, + activation=self._activation, + residual=use_residual, + ) + self._up2 = Conv3DInceptionBlockUpsampleBlock( + k2, + self._kernels, + 2, + norm=self._norm, + activation=self._activation, + residual=use_residual, + ) + + self._up1 = Conv3DInceptionBlockUpsampleBlock( + k1, + self._kernels, + 2, + norm=self._norm, + activation=self._activation, + residual=use_residual, + ) + + self._global_maxp = nn.AdaptiveMaxPool3d(1) + self._local_maxp = nn.MaxPool3d(3, 2, padding=1) + self._final = Conv3DBlock( + self._kernels * 2, + self._kernels, + kernel_sizes=3, + strides=1, + norm=self._norm, + activation=self._activation, + ) + self._final2 = Conv3DBlock( + self._kernels, + self._out_channels, + kernel_sizes=3, + strides=1, + norm=None, + activation=None, + ) + + self._ss_final = SpatialSoftmax3D( + self._voxel_size, self._voxel_size, self._voxel_size, self._kernels + ) + flat_size += self._kernels * 4 + + if self._out_dense > 0: + self._dense0 = DenseBlock( + flat_size, self._dense_feats, None, self._activation + ) + self._dense1 = DenseBlock( + self._dense_feats, self._dense_feats, None, self._activation + ) + self._dense2 = DenseBlock(self._dense_feats, self._out_dense, None, None) + + def _proj_feature(self, x, spatial_size, proj_fn): + x = proj_fn(x) + x = x.unsqueeze(2).unsqueeze(3).unsqueeze(4) + x = x.repeat(1, 1, spatial_size, spatial_size, spatial_size) + return x + + def forward( + self, + ins, + proprio, + lang_goal_embs, + lang_token_embs, + bounds, + prev_bounds, + prev_layer_voxel_grid, + ): + b, _, d, h, w = ins.shape + x = self._input_preprocess(ins) + + if self._include_prev_layer: + for voxel_grid in prev_layer_voxel_grid: + y = self._input_preprocess(voxel_grid) + x = torch.cat([x, y], dim=1) + + if self._low_dim_size > 0: + p = self._proprio_preprocess(proprio) + p = p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w) + x = torch.cat([x, p], dim=1) + + l_feat = lang_goal_embs + l_feat = l_feat.to(dtype=x.dtype) + + d0 = 
self._down0(x) + # l0 = self._proj_feature(l_feat, d0.shape[-1], self._lang_proj0) + # d0 = self._dropout0(d0 * l0) + ss0 = self._ss0(d0) + maxp0 = self._global_maxp(d0).view(b, -1) + + d1 = u = self._down1(self._local_maxp(d0)) + # l1 = self._proj_feature(l_feat, d1.shape[-1], self._lang_proj1) + # d1 = self._dropout1(d1 * l1) + ss1 = self._ss1(d1) + maxp1 = self._global_maxp(d1).view(b, -1) + + feats = [ss0, maxp0, ss1, maxp1] + + if self._voxel_size > 8: + d2 = u = self._down2(self._local_maxp(d1)) + l2 = self._proj_feature(l_feat, d2.shape[-1], self._lang_proj2) + d2 = self._dropout2(d2 * l2) + feats.extend([self._ss2(d2), self._global_maxp(d2).view(b, -1)]) + if self._voxel_size > 16: + d3 = self._down3(self._local_maxp(d2)) + l3 = self._proj_feature(l_feat, d3.shape[-1], self._lang_proj3) + d3 = self._dropout3(d3 * l3) + feats.extend([self._ss3(d3), self._global_maxp(d3).view(b, -1)]) + u3 = self._up3(d3) + u = torch.cat([d2, u3], dim=1) + u2 = self._up2(u) + u = torch.cat([d1, u2], dim=1) + + u1 = self._up1(u) + f1 = self._final(torch.cat([d0, u1], dim=1)) + trans = self._final2(f1) + + feats.extend([self._ss_final(f1), self._global_maxp(f1).view(b, -1)]) + + self.latent_dict = { + "d0": d0.mean(-1).mean(-1).mean(-1), + "d1": d1.mean(-1).mean(-1).mean(-1), + "u1": u1.mean(-1).mean(-1).mean(-1), + "trans_out": trans, + } + + rot_and_grip_out, collision_out = None, None + if self._out_dense > 0: + dense0 = self._dense0(torch.cat(feats, 1)) + dense1 = self._dense1(dense0) + rot_and_grip_collision_out = self._dense2(dense1) + rot_and_grip_out = rot_and_grip_collision_out[:, :-2] + collision_out = rot_and_grip_collision_out[:, -2:] + self.latent_dict.update( + { + "dense0": dense0, + "dense1": dense1, + "dense2": rot_and_grip_collision_out, + } + ) + + if self._voxel_size > 8: + self.latent_dict.update( + { + "d2": d2.mean(-1).mean(-1).mean(-1), + "u2": u2.mean(-1).mean(-1).mean(-1), + } + ) + if self._voxel_size > 16: + self.latent_dict.update( + { + "d3": 
class QFunction(nn.Module):
    """Voxelise camera inputs and run the Q-attention network on the result.

    The wrapped network is moved to ``device`` and, when ``training`` is true,
    wrapped in ``DistributedDataParallel``.
    """

    def __init__(
        self,
        unet_3d: nn.Module,
        voxelizer: "VoxelGrid",
        bounds_offset: float,
        rotation_resolution: float,
        device,
        training,
    ):
        super().__init__()
        self._rotation_resolution = rotation_resolution
        self._voxelizer = voxelizer
        self._bounds_offset = bounds_offset
        self._qnet = unet_3d.to(device)

        # distributed training
        if training:
            self._qnet = DDP(self._qnet, device_ids=[device])

    def _argmax_3d(self, tensor_orig):
        """Return the (d, h, w) index of the maximum of each (b, c) volume.

        The flat argmax is unravelled with the row-major strides of a
        ``(d, h, w)`` volume, i.e. ``h * w`` for the first axis and ``w`` for
        the second, so non-cubic volumes are handled as well.
        """
        b, c, d, h, w = tensor_orig.shape  # c will be one
        idxs = tensor_orig.reshape(b, c, -1).argmax(-1)
        indices = torch.cat([idxs // (h * w), (idxs // w) % h, idxs % w], 1)
        return indices

    def choose_highest_action(self, q_trans, q_rot_grip, q_collision):
        """Greedily pick the translation, rotation/gripper and collision indices.

        Returns:
            ``(coords, rot_and_grip_indicies, ignore_collision)``; the last two
            are ``None`` when ``q_rot_grip`` is ``None``.
        """
        coords = self._argmax_3d(q_trans)
        rot_and_grip_indicies = None
        ignore_collision = None
        if q_rot_grip is not None:
            # Everything except the last two columns is three equally sized
            # rotation heads; the last two columns are the gripper logits.
            q_rot = torch.stack(
                torch.split(
                    q_rot_grip[:, :-2], int(360 // self._rotation_resolution), dim=1
                ),
                dim=1,
            )
            rot_and_grip_indicies = torch.cat(
                [
                    q_rot[:, 0:1].argmax(-1),
                    q_rot[:, 1:2].argmax(-1),
                    q_rot[:, 2:3].argmax(-1),
                    q_rot_grip[:, -2:].argmax(-1, keepdim=True),
                ],
                -1,
            )
            ignore_collision = q_collision[:, -2:].argmax(-1, keepdim=True)
        return coords, rot_and_grip_indicies, ignore_collision

    def forward(
        self,
        rgb_pcd,
        proprio,
        pcd,
        lang_goal_emb,
        lang_token_embs,
        bounds=None,
        prev_bounds=None,
        prev_layer_voxel_grid=None,
    ):
        """Voxelise the observation and evaluate the network.

        Args:
            rgb_pcd: List with one ``[rgb, pcd]`` pair per camera.
            proprio: Low-dimensional state, forwarded to the network.
            pcd: List of per-camera point clouds (channels-first images).
            lang_goal_emb, lang_token_embs: Language features, forwarded.
            bounds: Voxelisation bounds; repeated to the batch size when its
                first dimension differs from it. ``None`` is passed through
                to the voxelizer unchanged.
            prev_bounds, prev_layer_voxel_grid: State of earlier layers,
                forwarded to the network.

        Returns:
            ``(q_trans, q_rot_and_grip, q_ignore_collisions, voxel_grid)``.
        """
        # rgb_pcd will be list of list (list of [rgb, pcd])
        b = rgb_pcd[0][0].shape[0]
        pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(b, -1, 3) for p in pcd], 1)

        # flatten RGBs and Pointclouds
        rgb = [rp[0] for rp in rgb_pcd]
        feat_size = rgb[0].shape[1]
        flat_imag_features = torch.cat(
            [p.permute(0, 2, 3, 1).reshape(b, -1, feat_size) for p in rgb], 1
        )

        # construct voxel grid
        voxel_grid = self._voxelizer.coords_to_bounding_voxel_grid(
            pcd_flat, coord_features=flat_imag_features, coord_bounds=bounds
        )

        # swap to channels first
        voxel_grid = voxel_grid.permute(0, 4, 1, 2, 3).detach()

        # batch bounds if necessary (bounds is optional, so guard the lookup)
        if bounds is not None and bounds.shape[0] != b:
            bounds = bounds.repeat(b, 1)

        # forward pass; the trailing arguments are passed by keyword so they
        # cannot be bound to the wrong parameter of the network's forward()
        q_trans, q_rot_and_grip, q_ignore_collisions = self._qnet(
            voxel_grid,
            proprio,
            lang_goal_emb,
            lang_token_embs,
            bounds=bounds,
            prev_bounds=prev_bounds,
            prev_layer_voxel_grid=prev_layer_voxel_grid,
        )

        return q_trans, q_rot_and_grip, q_ignore_collisions, voxel_grid
image_crop_size: int, + num_rotation_classes: int, + rotation_resolution: float, + lr: float = 0.0001, + lr_scheduler: bool = False, + training_iterations: int = 100000, + num_warmup_steps: int = 20000, + trans_loss_weight: float = 1.0, + rot_loss_weight: float = 1.0, + grip_loss_weight: float = 1.0, + collision_loss_weight: float = 1.0, + include_low_dim_state: bool = False, + image_resolution: list = None, + lambda_weight_l2: float = 0.0, + transform_augmentation: bool = True, + transform_augmentation_xyz: list = [0.0, 0.0, 0.0], + transform_augmentation_rpy: list = [0.0, 0.0, 180.0], + transform_augmentation_rot_resolution: int = 5, + num_devices: int = 1, + ): + self._layer = layer + self._coordinate_bounds = coordinate_bounds + self._unet3d = unet3d + self._voxel_feature_size = voxel_feature_size + self._bounds_offset = bounds_offset + self._image_crop_size = image_crop_size + self._lr = lr + self._lr_scheduler = lr_scheduler + self._training_iterations = training_iterations + self._num_warmup_steps = num_warmup_steps + self._trans_loss_weight = trans_loss_weight + self._rot_loss_weight = rot_loss_weight + self._grip_loss_weight = grip_loss_weight + self._collision_loss_weight = collision_loss_weight + self._include_low_dim_state = include_low_dim_state + self._image_resolution = image_resolution or [128, 128] + self._voxel_size = voxel_size + self._camera_names = camera_names + self._num_cameras = len(camera_names) + self._batch_size = batch_size + self._lambda_weight_l2 = lambda_weight_l2 + self._transform_augmentation = transform_augmentation + self._transform_augmentation_xyz = torch.from_numpy( + np.array(transform_augmentation_xyz) + ) + self._transform_augmentation_rpy = transform_augmentation_rpy + self._transform_augmentation_rot_resolution = ( + transform_augmentation_rot_resolution + ) + self._num_devices = num_devices + self._num_rotation_classes = num_rotation_classes + self._rotation_resolution = rotation_resolution + + self._cross_entropy_loss = 
nn.CrossEntropyLoss(reduction="none") + self._name = NAME + "_layer" + str(self._layer) + + def build(self, training: bool, device: torch.device = None): + self._training = training + self._device = device + + if device is None: + device = torch.device("cpu") + + self._voxelizer = VoxelGrid( + coord_bounds=self._coordinate_bounds, + voxel_size=self._voxel_size, + device=device, + batch_size=self._batch_size if training else 1, + feature_size=self._voxel_feature_size, + max_num_coords=np.prod(self._image_resolution) * self._num_cameras, + ) + + self._unet3d.build() + + self._q = ( + QFunction( + self._unet3d, + self._voxelizer, + self._bounds_offset, + self._rotation_resolution, + device, + training, + ) + .to(device) + .train(training) + ) + + grid_for_crop = ( + torch.arange(0, self._image_crop_size, device=device) + .unsqueeze(0) + .repeat(self._image_crop_size, 1) + .unsqueeze(-1) + ) + self._grid_for_crop = torch.cat( + [grid_for_crop.transpose(1, 0), grid_for_crop], dim=2 + ).unsqueeze(0) + + self._coordinate_bounds = torch.tensor( + self._coordinate_bounds, device=device + ).unsqueeze(0) + + if self._training: + # optimizer + self._optimizer = torch.optim.Adam( + self._q.parameters(), + lr=self._lr, + weight_decay=self._lambda_weight_l2, + ) + + # learning rate scheduler + if self._lr_scheduler: + self._scheduler = ( + transformers.get_cosine_with_hard_restarts_schedule_with_warmup( + self._optimizer, + num_warmup_steps=self._num_warmup_steps, + num_training_steps=self._training_iterations, + num_cycles=self._training_iterations // 10000, + ) + ) + + # one-hot zero tensors + self._action_trans_one_hot_zeros = torch.zeros( + ( + self._batch_size, + 1, + self._voxel_size, + self._voxel_size, + self._voxel_size, + ), + dtype=int, + device=device, + ) + self._action_rot_x_one_hot_zeros = torch.zeros( + (self._batch_size, self._num_rotation_classes), dtype=int, device=device + ) + self._action_rot_y_one_hot_zeros = torch.zeros( + (self._batch_size, 
self._num_rotation_classes), dtype=int, device=device + ) + self._action_rot_z_one_hot_zeros = torch.zeros( + (self._batch_size, self._num_rotation_classes), dtype=int, device=device + ) + self._action_grip_one_hot_zeros = torch.zeros( + (self._batch_size, 2), dtype=int, device=device + ) + self._action_ignore_collisions_one_hot_zeros = torch.zeros( + (self._batch_size, 2), dtype=int, device=device + ) + + # print total params + logging.info( + "# Q Params: %d" + % sum( + p.numel() + for name, p in self._q.named_parameters() + if p.requires_grad and "clip" not in name + ) + ) + else: + for param in self._q.parameters(): + param.requires_grad = False + + # load CLIP for encoding language goals during evaluation + model, _ = load_clip("RN50", jit=False) + self._clip_rn50 = build_model(model.state_dict()) + self._clip_rn50 = self._clip_rn50.float().to(device) + self._clip_rn50.eval() + del model + + self._voxelizer.to(device) + self._q.to(device) + + def _extract_crop(self, pixel_action, observation): + # Pixel action will now be (B, 2) + # observation = stack_on_channel(observation) + h = observation.shape[-1] + top_left_corner = torch.clamp( + pixel_action - self._image_crop_size // 2, 0, h - self._image_crop_size + ) + grid = self._grid_for_crop + top_left_corner.unsqueeze(1).unsqueeze(1) + grid = ((grid / float(h)) * 2.0) - 1.0 # between -1 and 1 + # Used for cropping the images across a batch + # swap fro y x, to x, y + grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1) + crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True) + return crop + + def _preprocess_inputs(self, replay_sample): + obs, pcds = [], [] + self._crop_summary = [] + for n in self._camera_names: + if self._layer > 0: + pc_t = replay_sample["%s_pixel_coord" % n] + rgb = self._extract_crop(pc_t, replay_sample["%s_rgb" % n]) + pcd = self._extract_crop(pc_t, replay_sample["%s_point_cloud" % n]) + self._crop_summary.append((n, rgb)) + else: + rgb = 
replay_sample["%s_rgb" % n] + pcd = replay_sample["%s_point_cloud" % n] + + obs.append([rgb, pcd]) + pcds.append(pcd) + return obs, pcds + + def _act_preprocess_inputs(self, observation): + obs, pcds = [], [] + for n in self._camera_names: + if self._layer > 0: + pc_t = observation["%s_pixel_coord" % n][0] + rgb = self._extract_crop(pc_t, observation["%s_rgb" % n][0]) + pcd = self._extract_crop(pc_t, observation["%s_point_cloud" % n][0]) + else: + rgb = observation["%s_rgb" % n][0] + pcd = observation["%s_point_cloud" % n][0] + + obs.append([rgb, pcd]) + pcds.append(pcd) + return obs, pcds + + def _get_value_from_voxel_index(self, q, voxel_idx): + b, c, d, h, w = q.shape + q_trans_flat = q.view(b, c, d * h * w) + flat_indicies = ( + voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2] + )[:, None].int() + highest_idxs = flat_indicies.unsqueeze(-1).repeat(1, c, 1) + chosen_voxel_values = q_trans_flat.gather(2, highest_idxs)[ + ..., 0 + ] # (B, trans + rot + grip) + return chosen_voxel_values + + def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx): + q_rot = torch.stack( + torch.split( + rot_grip_q[:, :-2], int(360 // self._rotation_resolution), dim=1 + ), + dim=1, + ) # B, 3, 72 + q_grip = rot_grip_q[:, -2:] + rot_and_grip_values = torch.cat( + [ + q_rot[:, 0].gather(1, rot_and_grip_idx[:, 0:1]), + q_rot[:, 1].gather(1, rot_and_grip_idx[:, 1:2]), + q_rot[:, 2].gather(1, rot_and_grip_idx[:, 2:3]), + q_grip.gather(1, rot_and_grip_idx[:, 3:4]), + ], + -1, + ) + return rot_and_grip_values + + def _celoss(self, pred, labels): + return self._cross_entropy_loss(pred, labels.argmax(-1)) + + def _softmax_q_trans(self, q): + q_shape = q.shape + return F.softmax(q.reshape(q_shape[0], -1), dim=1).reshape(q_shape) + + def _softmax_q_rot_grip(self, q_rot_grip): + q_rot_x_flat = q_rot_grip[ + :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes + ] + q_rot_y_flat = q_rot_grip[ + :, 1 * self._num_rotation_classes : 2 * 
self._num_rotation_classes + ] + q_rot_z_flat = q_rot_grip[ + :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes + ] + q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :] + + q_rot_x_flat_softmax = F.softmax(q_rot_x_flat, dim=1) + q_rot_y_flat_softmax = F.softmax(q_rot_y_flat, dim=1) + q_rot_z_flat_softmax = F.softmax(q_rot_z_flat, dim=1) + q_grip_flat_softmax = F.softmax(q_grip_flat, dim=1) + + return torch.cat( + [ + q_rot_x_flat_softmax, + q_rot_y_flat_softmax, + q_rot_z_flat_softmax, + q_grip_flat_softmax, + ], + dim=1, + ) + + def _softmax_ignore_collision(self, q_collision): + q_collision_softmax = F.softmax(q_collision, dim=1) + return q_collision_softmax + + def update(self, step: int, replay_sample: dict) -> dict: + action_trans = replay_sample["trans_action_indicies"][ + :, self._layer * 3 : self._layer * 3 + 3 + ].int() + action_rot_grip = replay_sample["rot_grip_action_indicies"].int() + action_gripper_pose = replay_sample["gripper_pose"] + action_ignore_collisions = replay_sample["ignore_collisions"].int() + lang_goal_emb = replay_sample["lang_goal_emb"].float() + lang_token_embs = replay_sample["lang_token_embs"].float() + prev_layer_voxel_grid = replay_sample.get("prev_layer_voxel_grid", None) + prev_layer_bounds = replay_sample.get("prev_layer_bounds", None) + device = self._device + + bounds = bounds_tp1 = self._coordinate_bounds + if self._layer > 0: + cp = replay_sample["attention_coordinate_layer_%d" % (self._layer - 1)] + bounds = torch.cat( + [cp - self._bounds_offset, cp + self._bounds_offset], dim=1 + ) + + proprio = None + if self._include_low_dim_state: + proprio = replay_sample["low_dim_state"] + + obs, pcd = self._preprocess_inputs(replay_sample) + + # batch size + bs = pcd[0].shape[0] + + # SE(3) augmentation of point clouds and actions + if self._transform_augmentation: + action_trans, action_rot_grip, pcd = apply_se3_augmentation( + pcd, + action_gripper_pose, + action_trans, + action_rot_grip, + bounds, + 
self._layer, + self._transform_augmentation_xyz, + self._transform_augmentation_rpy, + self._transform_augmentation_rot_resolution, + self._voxel_size, + self._rotation_resolution, + self._device, + ) + + # forward pass + q_trans, q_rot_grip, q_collision, voxel_grid = self._q( + obs, + proprio, + pcd, + lang_goal_emb, + lang_token_embs, + bounds, + prev_layer_bounds, + prev_layer_voxel_grid, + ) + + # argmax to choose best action + ( + coords, + rot_and_grip_indicies, + ignore_collision_indicies, + ) = self._q.choose_highest_action(q_trans, q_rot_grip, q_collision) + + q_trans_loss, q_rot_loss, q_grip_loss, q_collision_loss = 0.0, 0.0, 0.0, 0.0 + + # translation one-hot + action_trans_one_hot = self._action_trans_one_hot_zeros.clone() + for b in range(bs): + gt_coord = action_trans[b, :].int() + action_trans_one_hot[b, :, gt_coord[0], gt_coord[1], gt_coord[2]] = 1 + + # translation loss + q_trans_flat = q_trans.view(bs, -1) + action_trans_one_hot_flat = action_trans_one_hot.view(bs, -1) + q_trans_loss = self._celoss(q_trans_flat, action_trans_one_hot_flat) + + with_rot_and_grip = rot_and_grip_indicies is not None + if with_rot_and_grip: + # rotation, gripper, and collision one-hots + action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone() + action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone() + action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone() + action_grip_one_hot = self._action_grip_one_hot_zeros.clone() + action_ignore_collisions_one_hot = ( + self._action_ignore_collisions_one_hot_zeros.clone() + ) + + for b in range(bs): + gt_rot_grip = action_rot_grip[b, :].int() + action_rot_x_one_hot[b, gt_rot_grip[0]] = 1 + action_rot_y_one_hot[b, gt_rot_grip[1]] = 1 + action_rot_z_one_hot[b, gt_rot_grip[2]] = 1 + action_grip_one_hot[b, gt_rot_grip[3]] = 1 + + gt_ignore_collisions = action_ignore_collisions[b, :].int() + action_ignore_collisions_one_hot[b, gt_ignore_collisions[0]] = 1 + + # flatten predictions + q_rot_x_flat = q_rot_grip[ + 
:, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes + ] + q_rot_y_flat = q_rot_grip[ + :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes + ] + q_rot_z_flat = q_rot_grip[ + :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes + ] + q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :] + q_ignore_collisions_flat = q_collision + + # rotation loss + q_rot_loss += self._celoss(q_rot_x_flat, action_rot_x_one_hot) + q_rot_loss += self._celoss(q_rot_y_flat, action_rot_y_one_hot) + q_rot_loss += self._celoss(q_rot_z_flat, action_rot_z_one_hot) + + # gripper loss + q_grip_loss += self._celoss(q_grip_flat, action_grip_one_hot) + + # collision loss + q_collision_loss += self._celoss( + q_ignore_collisions_flat, action_ignore_collisions_one_hot + ) + + combined_losses = ( + (q_trans_loss * self._trans_loss_weight) + + (q_rot_loss * self._rot_loss_weight) + + (q_grip_loss * self._grip_loss_weight) + + (q_collision_loss * self._collision_loss_weight) + ) + total_loss = combined_losses.mean() + + self._optimizer.zero_grad() + total_loss.backward() + self._optimizer.step() + + self._summaries = { + "losses/total_loss": total_loss, + "losses/trans_loss": q_trans_loss.mean(), + "losses/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0, + "losses/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0, + "losses/collision_loss": q_collision_loss.mean() + if with_rot_and_grip + else 0.0, + } + + if self._lr_scheduler: + self._scheduler.step() + self._summaries["learning_rate"] = self._scheduler.get_last_lr()[0] + + self._vis_voxel_grid = voxel_grid[0] + self._vis_translation_qvalue = self._softmax_q_trans(q_trans[0]) + self._vis_max_coordinate = coords[0] + self._vis_gt_coordinate = action_trans[0] + + # Note: PerAct doesn't use multi-layer voxel grids like C2FARM + # stack prev_layer_voxel_grid(s) from previous layers into a list + if prev_layer_voxel_grid is None: + prev_layer_voxel_grid = [voxel_grid] + else: + 
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
    """Pick the greedy action of this layer for a single observation.

    The language tokens are encoded with ``self._clip_rn50`` (which ``build``
    only creates when ``training`` is false), the observation is voxelised
    inside the current bounds and the argmax of every softmaxed Q output is
    returned.

    Args:
        step: Unused here; part of the ``Agent`` interface.
        observation: Must contain ``lang_goal_tokens`` and the per-camera
            entries read by ``_act_preprocess_inputs``; for layers above 0
            also ``attention_coordinate``. ``low_dim_state`` is read when
            the agent was configured with ``include_low_dim_state``.
        deterministic: Accepted for interface compatibility only; selection
            is always the argmax.

    Returns:
        ``ActResult`` whose action is
        ``(coords, rot_grip_action, ignore_collisions_action)`` and whose
        observation elements carry the attention coordinate together with the
        voxel grids and bounds accumulated for the following layers.

    Raises:
        ValueError: If ``lang_goal_tokens`` is missing from ``observation``.
    """
    prev_layer_voxel_grid = observation.get("prev_layer_voxel_grid", None)
    prev_layer_bounds = observation.get("prev_layer_bounds", None)

    lang_goal_tokens = observation.get("lang_goal_tokens", None)
    if lang_goal_tokens is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError("observation is missing 'lang_goal_tokens'")

    # extract CLIP language embs
    with torch.no_grad():
        lang_goal_tokens = lang_goal_tokens.long().to(device=self._device)
        (
            lang_goal_emb,
            lang_token_embs,
        ) = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0])

    bounds = self._coordinate_bounds
    if self._layer > 0:
        # Finer layers look at a box around the previous layer's attention.
        cp = observation["attention_coordinate"]
        bounds = torch.cat(
            [cp - self._bounds_offset, cp + self._bounds_offset], dim=1
        )
    # Move the bounds first so that everything derived from them (including
    # the voxel resolution) lives on the same device.
    bounds = torch.as_tensor(bounds, device=self._device)

    # voxelization resolution
    res = (bounds[:, 3:] - bounds[:, :3]) / self._voxel_size

    # The low-dimensional state is optional; only index it when it is used.
    proprio = None
    if self._include_low_dim_state:
        proprio = observation["low_dim_state"][0].to(self._device)

    obs, pcd = self._act_preprocess_inputs(observation)

    # correct batch size and device
    obs = [[o[0].to(self._device), o[1].to(self._device)] for o in obs]
    pcd = [p.to(self._device) for p in pcd]
    lang_goal_emb = lang_goal_emb.to(self._device)
    lang_token_embs = lang_token_embs.to(self._device)
    if prev_layer_voxel_grid is not None:
        prev_layer_voxel_grid = [
            pvg.to(self._device) for pvg in prev_layer_voxel_grid
        ]
    if prev_layer_bounds is not None:
        prev_layer_bounds = [pb.to(self._device) for pb in prev_layer_bounds]

    # inference
    q_trans, q_rot_grip, q_ignore_collisions, vox_grid = self._q(
        obs,
        proprio,
        pcd,
        lang_goal_emb,
        lang_token_embs,
        bounds,
        prev_layer_bounds,
        prev_layer_voxel_grid,
    )

    # softmax Q predictions
    q_trans = self._softmax_q_trans(q_trans)
    q_rot_grip = (
        self._softmax_q_rot_grip(q_rot_grip) if q_rot_grip is not None else None
    )
    q_ignore_collisions = (
        self._softmax_ignore_collision(q_ignore_collisions)
        if q_ignore_collisions is not None
        else None
    )

    # argmax Q predictions
    (
        coords,
        rot_and_grip_indicies,
        ignore_collisions,
    ) = self._q.choose_highest_action(q_trans, q_rot_grip, q_ignore_collisions)

    rot_grip_action = rot_and_grip_indicies if q_rot_grip is not None else None
    ignore_collisions_action = (
        ignore_collisions.int() if ignore_collisions is not None else None
    )

    coords = coords.int()
    # Centre of the selected voxel in world coordinates.
    attention_coordinate = bounds[:, :3] + res * coords + res / 2

    # stack prev_layer_voxel_grid(s) into a list
    # NOTE: PerAct doesn't use multi-layer voxel grids like C2FARM
    if prev_layer_voxel_grid is None:
        prev_layer_voxel_grid = [vox_grid]
    else:
        prev_layer_voxel_grid = prev_layer_voxel_grid + [vox_grid]

    if prev_layer_bounds is None:
        prev_layer_bounds = [bounds]
    else:
        prev_layer_bounds = prev_layer_bounds + [bounds]

    observation_elements = {
        "attention_coordinate": attention_coordinate,
        "prev_layer_voxel_grid": prev_layer_voxel_grid,
        "prev_layer_bounds": prev_layer_bounds,
    }
    info = {
        "voxel_grid_depth%d" % self._layer: vox_grid,
        "q_depth%d" % self._layer: q_trans,
        "voxel_idx_depth%d" % self._layer: coords,
    }
    # Kept for act_summaries().
    self._act_voxel_grid = vox_grid[0]
    self._act_max_coordinate = coords[0]
    self._act_qvalues = q_trans[0].detach()
    return ActResult(
        (coords, rot_grip_action, ignore_collisions_action),
        observation_elements=observation_elements,
        info=info,
    )
def load_weights(self, savedir: str):
    """Load ``<savedir>/<name>.pt`` into the Q-function.

    Only checkpoint entries whose key exists in the current model are
    copied. Entries containing ``_voxelizer`` are ignored, and outside of
    training the ``_qnet.module`` prefix in checkpoint keys is rewritten to
    ``_qnet`` before the lookup.

    Args:
        savedir: Directory that contains the checkpoint file.
    """
    device = self._device
    # The original formatting ("cuda:%d") implies that training hands this
    # agent an integer GPU rank; only convert in that case so an actual
    # torch.device no longer raises a TypeError.
    if self._training and isinstance(device, int):
        device = torch.device("cuda:%d" % device)

    state_dict = torch.load(
        os.path.join(savedir, "%s.pt" % self._name), map_location=device
    )

    # load only keys that are in the current model
    # NOTE(review): the voxelizer filter is assumed to guard the whole loop
    # body; the source formatting was lost, so confirm the nesting.
    merged_state_dict = self._q.state_dict()
    for k, v in state_dict.items():
        if "_voxelizer" in k:
            continue
        if not self._training:
            k = k.replace("_qnet.module", "_qnet")

        if k in merged_state_dict:
            merged_state_dict[k] = v
        else:
            # The key exists in the checkpoint but not in this model.
            logging.warning("checkpoint key %s not found in model; skipping", k)
    self._q.load_state_dict(merged_state_dict)
    print("loaded weights from %s" % savedir)
class QAttentionStackAgent(Agent):
    """Run a list of Q-attention agents one after another.

    Every layer's ``act``/``update`` result is written back into the shared
    observation / replay sample so that the next layer can read it.
    """

    def __init__(
        self,
        qattention_agents: List[QAttentionLingUNetBCAgent],
        rotation_resolution: float,
        camera_names: List[str],
        rotation_prediction_depth: int = 0,
    ):
        super(QAttentionStackAgent, self).__init__()
        self._qattention_agents = qattention_agents
        self._rotation_resolution = rotation_resolution
        self._camera_names = camera_names
        self._rotation_prediction_depth = rotation_prediction_depth

    def build(self, training: bool, device=None) -> None:
        """Build every layer on ``device`` (CPU when ``device`` is ``None``)."""
        self._device = device
        if self._device is None:
            self._device = torch.device("cpu")
        for qa in self._qattention_agents:
            # Hand over the resolved device so that the layers never end up
            # with a ``None`` device of their own.
            qa.build(training, self._device)

    def update(self, step: int, replay_sample: dict) -> dict:
        """Update every layer in order and return the summed loss.

        ``replay_sample`` is updated in place with each layer's result dict,
        which is how later layers receive the earlier layers' outputs.
        """
        total_losses = 0.0
        for qa in self._qattention_agents:
            update_dict = qa.update(step, replay_sample)
            replay_sample.update(update_dict)
            total_losses += update_dict["total_loss"]
        return {
            "total_losses": total_losses,
        }

    def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
        """Query every layer and assemble one continuous action.

        ``observation`` is updated in place after each layer with the
        attention coordinate, the accumulated voxel grids and bounds, and the
        per-camera pixel coordinate of the attention point.

        Returns:
            ``ActResult`` whose action concatenates the last attention
            coordinate, the output of ``utils.discrete_euler_to_quaternion``
            for the last three rotation indices, the gripper index and the
            ignore-collisions flag.

        Raises:
            RuntimeError: If no layer produced rotation/gripper or collision
                predictions.
        """
        observation_elements = {}
        translation_results, rot_grip_results, ignore_collisions_results = [], [], []
        infos = {}
        for depth, qagent in enumerate(self._qattention_agents):
            act_results = qagent.act(step, observation, deterministic)
            attention_coordinate = (
                act_results.observation_elements["attention_coordinate"].cpu().numpy()
            )
            observation_elements[
                "attention_coordinate_layer_%d" % depth
            ] = attention_coordinate[0]

            translation_idxs, rot_grip_idxs, ignore_collisions_idxs = act_results.action
            translation_results.append(translation_idxs)
            if rot_grip_idxs is not None:
                rot_grip_results.append(rot_grip_idxs)
            if ignore_collisions_idxs is not None:
                ignore_collisions_results.append(ignore_collisions_idxs)

            # Feed this layer's outputs to the next layer.
            observation["attention_coordinate"] = act_results.observation_elements[
                "attention_coordinate"
            ]
            observation["prev_layer_voxel_grid"] = act_results.observation_elements[
                "prev_layer_voxel_grid"
            ]
            observation["prev_layer_bounds"] = act_results.observation_elements[
                "prev_layer_bounds"
            ]

            for n in self._camera_names:
                px, py = utils.point_to_pixel_index(
                    attention_coordinate[0],
                    observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy(),
                    observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy(),
                )
                pc_t = torch.tensor(
                    [[[py, px]]], dtype=torch.float32, device=self._device
                )
                observation["%s_pixel_coord" % n] = pc_t
                observation_elements["%s_pixel_coord" % n] = [py, px]

            infos.update(act_results.info)

        if not rot_grip_results or not ignore_collisions_results:
            # torch.cat on an empty list fails with an unhelpful message.
            raise RuntimeError(
                "no Q-attention layer produced rotation/gripper and "
                "collision predictions"
            )

        rgai = torch.cat(rot_grip_results, 1)[0].cpu().numpy()
        # Take the last predicted flag (mirrors the rgai[-4:] slicing below)
        # and convert it via .item(); calling float() on an array that is not
        # 0-dimensional is deprecated in NumPy.
        ignore_collisions = float(
            torch.cat(ignore_collisions_results, 1)[0, -1].item()
        )
        observation_elements["trans_action_indicies"] = (
            torch.cat(translation_results, 1)[0].cpu().numpy()
        )
        observation_elements["rot_grip_action_indicies"] = rgai
        continuous_action = np.concatenate(
            [
                act_results.observation_elements["attention_coordinate"]
                .cpu()
                .numpy()[0],
                utils.discrete_euler_to_quaternion(
                    rgai[-4:-1], self._rotation_resolution
                ),
                rgai[-1:],
                [ignore_collisions],
            ]
        )
        return ActResult(
            continuous_action, observation_elements=observation_elements, info=infos
        )

    def update_summaries(self) -> List[Summary]:
        """Collect the training summaries of all layers."""
        summaries = []
        for qa in self._qattention_agents:
            summaries.extend(qa.update_summaries())
        return summaries

    def act_summaries(self) -> List[Summary]:
        """Collect the acting summaries of all layers."""
        s = []
        for qa in self._qattention_agents:
            s.extend(qa.act_summaries())
        return s

    def load_weights(self, savedir: str):
        """Load the weights of every layer from ``savedir``."""
        for qa in self._qattention_agents:
            qa.load_weights(savedir)

    def save_weights(self, savedir: str):
        """Save the weights of every layer into ``savedir``."""
        for qa in self._qattention_agents:
            qa.save_weights(savedir)
perceiver_encoder = PerceiverVoxelLangEncoder( + depth=cfg.method.transformer_depth, + iterations=cfg.method.transformer_iterations, + voxel_size=vox_size, + initial_dim=3 + 3 + 1 + 3, + low_dim_size=cfg.method.low_dim_size, + layer=depth, + num_rotation_classes=num_rotation_classes if last else 0, + num_grip_classes=2 if last else 0, + num_collision_classes=2 if last else 0, + input_axis=3, + num_latents=cfg.method.num_latents, + latent_dim=cfg.method.latent_dim, + cross_heads=cfg.method.cross_heads, + latent_heads=cfg.method.latent_heads, + cross_dim_head=cfg.method.cross_dim_head, + latent_dim_head=cfg.method.latent_dim_head, + weight_tie_layers=False, + activation=cfg.method.activation, + pos_encoding_with_lang=cfg.method.pos_encoding_with_lang, + input_dropout=cfg.method.input_dropout, + attn_dropout=cfg.method.attn_dropout, + decoder_dropout=cfg.method.decoder_dropout, + lang_fusion_type=cfg.method.lang_fusion_type, + voxel_patch_size=cfg.method.voxel_patch_size, + voxel_patch_stride=cfg.method.voxel_patch_stride, + no_skip_connection=cfg.method.no_skip_connection, + no_perceiver=cfg.method.no_perceiver, + no_language=cfg.method.no_language, + final_dim=cfg.method.final_dim, + ) + + qattention_agent = QAttentionPerActBCAgent( + layer=depth, + coordinate_bounds=depth_0bounds, + perceiver_encoder=perceiver_encoder, + camera_names=cfg.rlbench.cameras, + voxel_size=vox_size, + bounds_offset=cfg.method.bounds_offset[depth - 1] if depth > 0 else None, + image_crop_size=cfg.method.image_crop_size, + lr=cfg.method.lr, + training_iterations=cfg.framework.training_iterations, + lr_scheduler=cfg.method.lr_scheduler, + num_warmup_steps=cfg.method.num_warmup_steps, + trans_loss_weight=cfg.method.trans_loss_weight, + rot_loss_weight=cfg.method.rot_loss_weight, + grip_loss_weight=cfg.method.grip_loss_weight, + collision_loss_weight=cfg.method.collision_loss_weight, + include_low_dim_state=True, + image_resolution=cam_resolution, + batch_size=cfg.replay.batch_size, + 
# PerceiverIO adapted for 6-DoF manipulation
class PerceiverVoxelLangEncoder(nn.Module):
    """Perceiver-style encoder/decoder over a voxel grid plus language input.

    The voxel grid is encoded with a 1x1 conv, patchified with a strided
    conv, optionally fused with proprioception and language features,
    processed by cross-/self-attention over learned latents, decoded back to
    the patch grid, upsampled, and turned into a per-voxel translation map.
    When ``num_rotation_classes > 0`` an MLP head additionally predicts
    rotation, gripper and collision logits.
    """

    def __init__(
        self,
        depth,  # number of self-attention layers
        iterations,  # number cross-attention iterations (PerceiverIO uses just 1)
        voxel_size,  # N voxels per side (size: N*N*N)
        initial_dim,  # channels of the input voxel grid (e.g. 10)
        low_dim_size,  # size of the proprioception vector (e.g. 4); 0 disables it
        layer=0,
        num_rotation_classes=72,  # e.g. 5 degree increments (5*72=360) per axis
        num_grip_classes=2,  # open or not open
        num_collision_classes=2,  # collisions allowed or not allowed
        input_axis=3,  # 3D tensors have 3 axes
        num_latents=512,  # number of latent vectors
        im_channels=64,  # intermediate channel size
        latent_dim=512,  # dimensions of latent vectors
        cross_heads=1,  # number of cross-attention heads
        latent_heads=8,  # number of latent heads
        cross_dim_head=64,
        latent_dim_head=64,
        activation="relu",
        weight_tie_layers=False,
        pos_encoding_with_lang=True,
        input_dropout=0.1,
        attn_dropout=0.1,
        decoder_dropout=0.0,
        lang_fusion_type="seq",
        voxel_patch_size=9,
        voxel_patch_stride=8,
        no_skip_connection=False,
        no_perceiver=False,
        no_language=False,
        final_dim=64,
    ):
        super().__init__()
        self.depth = depth
        self.layer = layer
        self.init_dim = int(initial_dim)
        self.iterations = iterations
        self.input_axis = input_axis
        self.voxel_size = voxel_size
        self.low_dim_size = low_dim_size
        self.im_channels = im_channels
        self.pos_encoding_with_lang = pos_encoding_with_lang
        self.lang_fusion_type = lang_fusion_type
        self.voxel_patch_size = voxel_patch_size
        self.voxel_patch_stride = voxel_patch_stride
        self.num_rotation_classes = num_rotation_classes
        self.num_grip_classes = num_grip_classes
        self.num_collision_classes = num_collision_classes
        self.final_dim = final_dim
        self.input_dropout = input_dropout
        self.attn_dropout = attn_dropout
        self.decoder_dropout = decoder_dropout
        self.no_skip_connection = no_skip_connection
        self.no_perceiver = no_perceiver
        self.no_language = no_language

        # patchified input dimensions (e.g. 100 // 5 = 20)
        spatial_size = voxel_size // self.voxel_patch_stride

        # voxel features + proprio features (+ lang goal features if concatenated)
        self.input_dim_before_seq = (
            self.im_channels * 3
            if self.lang_fusion_type == "concat"
            else self.im_channels * 2
        )

        # CLIP language feature dimensions
        lang_feat_dim, lang_emb_dim, lang_max_seq_len = 1024, 512, 77

        # learnable positional encoding
        if self.pos_encoding_with_lang:
            # one encoding per language token and per flattened grid cell
            self.pos_encoding = nn.Parameter(
                torch.randn(
                    1, lang_max_seq_len + spatial_size**3, self.input_dim_before_seq
                )
            )
        else:
            # grid-only encoding, added before the grid is flattened
            self.pos_encoding = nn.Parameter(
                torch.randn(
                    1,
                    spatial_size,
                    spatial_size,
                    spatial_size,
                    self.input_dim_before_seq,
                )
            )

        # voxel input preprocessing 1x1 conv encoder
        self.input_preprocess = Conv3DBlock(
            self.init_dim,
            self.im_channels,
            kernel_sizes=1,
            strides=1,
            norm=None,
            activation=activation,
        )

        # patchify conv
        self.patchify = Conv3DBlock(
            self.input_preprocess.out_channels,
            self.im_channels,
            kernel_sizes=self.voxel_patch_size,
            strides=self.voxel_patch_stride,
            norm=None,
            activation=activation,
        )

        # language preprocess (no layer is created for other fusion types)
        if self.lang_fusion_type == "concat":
            self.lang_preprocess = nn.Linear(lang_feat_dim, self.im_channels)
        elif self.lang_fusion_type == "seq":
            self.lang_preprocess = nn.Linear(lang_emb_dim, self.im_channels * 2)

        # proprioception
        if self.low_dim_size > 0:
            self.proprio_preprocess = DenseBlock(
                self.low_dim_size,
                self.im_channels,
                norm=None,
                activation=activation,
            )

        # pooling functions
        self.local_maxp = nn.MaxPool3d(3, 2, padding=1)
        self.global_maxp = nn.AdaptiveMaxPool3d(1)

        # 1st 3D softmax
        self.ss0 = SpatialSoftmax3D(
            self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
        )
        # running size of the concatenated feature vector fed to dense0
        flat_size = self.im_channels * 4

        # latent vectors (that are randomly initialized)
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

        # encoder cross attention
        self.cross_attend_blocks = nn.ModuleList(
            [
                PreNorm(
                    latent_dim,
                    Attention(
                        latent_dim,
                        self.input_dim_before_seq,
                        heads=cross_heads,
                        dim_head=cross_dim_head,
                        dropout=input_dropout,
                    ),
                    context_dim=self.input_dim_before_seq,
                ),
                PreNorm(latent_dim, FeedForward(latent_dim)),
            ]
        )

        get_latent_attn = lambda: PreNorm(
            latent_dim,
            Attention(
                latent_dim,
                heads=latent_heads,
                dim_head=latent_dim_head,
                dropout=attn_dropout,
            ),
        )
        get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim))
        get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))

        # self attention layers
        self.layers = nn.ModuleList([])
        cache_args = {"_cache": weight_tie_layers}

        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [get_latent_attn(**cache_args), get_latent_ff(**cache_args)]
                )
            )

        # decoder cross attention
        self.decoder_cross_attn = PreNorm(
            self.input_dim_before_seq,
            Attention(
                self.input_dim_before_seq,
                latent_dim,
                heads=cross_heads,
                dim_head=cross_dim_head,
                dropout=decoder_dropout,
            ),
            context_dim=latent_dim,
        )

        # upsample conv
        self.up0 = Conv3DUpsampleBlock(
            self.input_dim_before_seq,
            self.final_dim,
            kernel_sizes=self.voxel_patch_size,
            strides=self.voxel_patch_stride,
            norm=None,
            activation=activation,
        )

        # 2nd 3D softmax
        self.ss1 = SpatialSoftmax3D(
            spatial_size, spatial_size, spatial_size, self.input_dim_before_seq
        )

        flat_size += self.input_dim_before_seq * 4

        # final conv over the (optionally skip-connected) feature volume
        # NOTE(review): the ablation branches and trans_decoder below appear
        # to rely on final_dim == im_channels - confirm before changing either.
        self.final = Conv3DBlock(
            self.im_channels
            if (self.no_perceiver or self.no_skip_connection)
            else self.im_channels * 2,
            self.im_channels,
            kernel_sizes=3,
            strides=1,
            norm=None,
            activation=activation,
        )

        self.trans_decoder = Conv3DBlock(
            self.final_dim,
            1,
            kernel_sizes=3,
            strides=1,
            norm=None,
            activation=None,
        )

        # rotation, gripper, and collision MLP layers
        if self.num_rotation_classes > 0:
            self.ss_final = SpatialSoftmax3D(
                self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
            )

            flat_size += self.im_channels * 4

            self.dense0 = DenseBlock(flat_size, 256, None, activation)
            self.dense1 = DenseBlock(256, self.final_dim, None, activation)

            self.rot_grip_collision_ff = DenseBlock(
                self.final_dim,
                self.num_rotation_classes * 3
                + self.num_grip_classes
                + self.num_collision_classes,
                None,
                None,
            )

    def encode_text(self, x):
        """Encode tokens ``x`` with ``self._clip_rn50`` without gradients.

        NOTE(review): ``_clip_rn50`` is not created in this class's
        ``__init__``; it must be attached externally before calling this.

        Returns:
            ``(text_feat, text_emb)``, both detached.
        """
        with torch.no_grad():
            text_feat, text_emb = self._clip_rn50.encode_text_with_embeddings(x)

        text_feat = text_feat.detach()
        text_emb = text_emb.detach()
        return text_feat, text_emb

    def forward(
        self,
        ins,
        proprio,
        lang_goal_emb,
        lang_token_embs,
        prev_layer_voxel_grid,
        bounds,
        prev_layer_bounds,
        mask=None,
    ):
        """Run the encoder on one batch.

        Args:
            ins: voxel grid, channels first (``[B, C, D, H, W]``).
            proprio: proprioception vector; read only if ``low_dim_size > 0``.
            lang_goal_emb: sentence-level language feature ("concat" fusion).
            lang_token_embs: token-level language features ("seq" fusion).
            prev_layer_voxel_grid, bounds, prev_layer_bounds: accepted for
                interface compatibility; not read by this method.
            mask: optional mask forwarded to the encoder cross-attention.

        Returns:
            ``(trans, rot_and_grip_out, collision_out)``. ``trans`` is the
            single-channel translation volume. The other two are ``None``
            when ``num_rotation_classes == 0``.
        """
        # preprocess input
        d0 = self.input_preprocess(ins)  # e.g. [B,10,100,100,100] -> [B,64,100,100,100]

        # aggregated features from 1st softmax and maxpool for MLP decoders
        feats = [self.ss0(d0.contiguous()), self.global_maxp(d0).view(ins.shape[0], -1)]

        # patchify input
        ins = self.patchify(d0)  # e.g. [B,64,100,100,100] -> [B,64,20,20,20]

        b, c, d, h, w = ins.shape
        axis = [d, h, w]
        assert (
            len(axis) == self.input_axis
        ), "input must have the same number of axis as input_axis"

        # concat proprio
        if self.low_dim_size > 0:
            p = self.proprio_preprocess(proprio)  # [B,low_dim_size] -> [B,im_channels]
            p = p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
            ins = torch.cat([ins, p], dim=1)

        # language ablation
        if self.no_language:
            lang_goal_emb = torch.zeros_like(lang_goal_emb)
            lang_token_embs = torch.zeros_like(lang_token_embs)

        # option 1: tile and concat lang goal to input
        if self.lang_fusion_type == "concat":
            lang_emb = lang_goal_emb.to(dtype=ins.dtype)
            lang_tiled = self.lang_preprocess(lang_emb)
            lang_tiled = (
                lang_tiled.unsqueeze(-1)
                .unsqueeze(-1)
                .unsqueeze(-1)
                .repeat(1, 1, d, h, w)
            )
            ins = torch.cat([ins, lang_tiled], dim=1)

        # channel last
        ins = rearrange(ins, "b d ... -> b ... d")

        # add pos encoding to grid
        if not self.pos_encoding_with_lang:
            ins = ins + self.pos_encoding

        ######################## NOTE #############################
        # NOTE: If you add positional encodings ^here the lang embs
        # won't have positional encodings. I accidentally forgot
        # to turn this off for all the experiments in the paper.
        # So I guess those models were using language embs
        # as a bag of words :( But it doesn't matter much for
        # RLBench tasks since we don't test for novel instructions
        # at test time anyway. The recommended way is to add
        # positional encodings to the final input sequence
        # fed into the Perceiver Transformer, as done below
        # (and also in the Colab tutorial).
        ###########################################################

        # remember the grid shape so the sequence can be folded back later
        queries_orig_shape = ins.shape

        # flatten the spatial axes into one sequence axis
        ins = rearrange(ins, "b ... d -> b (...) d")

        # option 2: add lang token embs as a sequence
        lang_seq_len = 0
        if self.lang_fusion_type == "seq":
            # [B,77,512] -> [B,77,2*im_channels]
            lang_seq = self.lang_preprocess(lang_token_embs)
            lang_seq_len = lang_seq.shape[1]
            ins = torch.cat((lang_seq, ins), dim=1)

        # add pos encoding to language + flattened grid (the recommended way)
        if self.pos_encoding_with_lang:
            ins = ins + self.pos_encoding

        # batchify latents
        x = repeat(self.latents, "n d -> b n d", b=b)

        cross_attn, cross_ff = self.cross_attend_blocks

        for _ in range(self.iterations):
            # encoder cross attention
            x = cross_attn(x, context=ins, mask=mask) + x
            x = cross_ff(x) + x

            # self-attention layers
            for self_attn, self_ff in self.layers:
                x = self_attn(x) + x
                x = self_ff(x) + x

        # decoder cross attention
        latents = self.decoder_cross_attn(ins, context=x)

        # crop out the language part of the output sequence
        if self.lang_fusion_type == "seq":
            latents = latents[:, lang_seq_len:]

        # reshape back to voxel grid (channels last, then channels first)
        latents = latents.view(b, *queries_orig_shape[1:-1], latents.shape[-1])
        latents = rearrange(latents, "b ... d -> b d ...")

        # aggregated features from 2nd softmax and maxpool for MLP decoders
        feats.extend(
            [self.ss1(latents.contiguous()), self.global_maxp(latents).view(b, -1)]
        )

        # upsample
        u0 = self.up0(latents)

        # ablations
        if self.no_skip_connection:
            u = self.final(u0)
        elif self.no_perceiver:
            u = self.final(d0)
        else:
            u = self.final(torch.cat([d0, u0], dim=1))

        # translation decoder
        trans = self.trans_decoder(u)

        # rotation, gripper, and collision MLPs
        # Both default to None so layers built with num_rotation_classes == 0
        # return cleanly instead of hitting an unbound local.
        rot_and_grip_out = None
        collision_out = None
        if self.num_rotation_classes > 0:
            feats.extend(
                [self.ss_final(u.contiguous()), self.global_maxp(u).view(b, -1)]
            )

            dense0 = self.dense0(torch.cat(feats, dim=1))
            dense1 = self.dense1(dense0)  # [B,final_dim]

            # [B, 3*num_rotation_classes + num_grip_classes + num_collision_classes]
            rot_and_grip_collision_out = self.rot_grip_collision_ff(dense1)
            # Split at an explicit index: a negative-zero slice would return
            # an empty tensor when num_collision_classes == 0.
            split = rot_and_grip_collision_out.shape[1] - self.num_collision_classes
            rot_and_grip_out = rot_and_grip_collision_out[:, :split]
            collision_out = rot_and_grip_collision_out[:, split:]

        return trans, rot_and_grip_out, collision_out
class QFunction(nn.Module):
    """Voxelise point clouds and score actions with a Perceiver encoder."""

    def __init__(
        self,
        perceiver_encoder: nn.Module,
        voxelizer: "VoxelGrid",
        bounds_offset: float,
        rotation_resolution: float,
        device,
        training,
    ):
        """Move the encoder to ``device`` and wrap it in DDP when training.

        Args:
            perceiver_encoder: network called in ``forward``.
            voxelizer: object providing ``coords_to_bounding_voxel_grid``.
            bounds_offset: stored; not read inside this class.
            rotation_resolution: degrees per rotation bin; each rotation axis
                has ``int(360 // rotation_resolution)`` bins.
            device: device for the encoder (also used as the DDP device id).
            training: when truthy the encoder is wrapped in
                ``DistributedDataParallel``.
        """
        super(QFunction, self).__init__()
        self._rotation_resolution = rotation_resolution
        self._voxelizer = voxelizer
        self._bounds_offset = bounds_offset
        self._qnet = perceiver_encoder.to(device)

        # distributed training
        if training:
            self._qnet = DDP(self._qnet, device_ids=[device])

    def _argmax_3d(self, tensor_orig):
        """Return the ``(d, h, w)`` index of the max of each ``[D, H, W]`` volume.

        Args:
            tensor_orig: tensor of shape ``[B, C, D, H, W]``; ``C`` is
                expected to be 1.

        Returns:
            Integer tensor of shape ``[B, 3 * C]``.
        """
        b, c, d, h, w = tensor_orig.shape  # c will be one
        idxs = tensor_orig.reshape(b, c, -1).argmax(-1)
        # Row-major decode: flat = i * (h * w) + j * w + k. The previous
        # formula was only correct for cubic grids (d == h == w).
        indices = torch.cat([idxs // (h * w), (idxs // w) % h, idxs % w], 1)
        return indices

    def choose_highest_action(self, q_trans, q_rot_grip, q_collision):
        """Pick the arg-max translation, rotation/gripper and collision action.

        Args:
            q_trans: ``[B, 1, D, H, W]`` translation scores.
            q_rot_grip: ``[B, 3 * num_rot_bins + 2]`` scores, or ``None``.
            q_collision: collision scores; only read when ``q_rot_grip`` is
                not ``None``.

        Returns:
            ``(coords, rot_and_grip_indicies, ignore_collision)``. The last
            two are ``None`` when ``q_rot_grip`` is ``None``.
        """
        coords = self._argmax_3d(q_trans)
        rot_and_grip_indicies = None
        ignore_collision = None
        if q_rot_grip is not None:
            # [B, 3 * bins] -> [B, 3, bins]; the trailing two columns are the
            # gripper scores.
            q_rot = torch.stack(
                torch.split(
                    q_rot_grip[:, :-2], int(360 // self._rotation_resolution), dim=1
                ),
                dim=1,
            )
            rot_and_grip_indicies = torch.cat(
                [
                    q_rot[:, 0:1].argmax(-1),
                    q_rot[:, 1:2].argmax(-1),
                    q_rot[:, 2:3].argmax(-1),
                    q_rot_grip[:, -2:].argmax(-1, keepdim=True),
                ],
                -1,
            )
            ignore_collision = q_collision[:, -2:].argmax(-1, keepdim=True)
        return coords, rot_and_grip_indicies, ignore_collision

    def forward(
        self,
        rgb_pcd,
        proprio,
        pcd,
        lang_goal_emb,
        lang_token_embs,
        bounds=None,
        prev_bounds=None,
        prev_layer_voxel_grid=None,
    ):
        """Build the voxel grid and run the encoder.

        Args:
            rgb_pcd: list of ``[rgb, pcd]`` pairs, one per camera.
            proprio: proprioception input forwarded to the encoder.
            pcd: list of per-camera point clouds, channels first.
            lang_goal_emb, lang_token_embs: language inputs forwarded to the
                encoder.
            bounds: coordinate bounds for voxelisation; tiled to the batch
                size when its first dimension differs from it.
            prev_bounds, prev_layer_voxel_grid: forwarded to the encoder.

        Returns:
            ``(q_trans, q_rot_and_grip, q_ignore_collisions, voxel_grid)``.
        """
        # rgb_pcd will be list of list (list of [rgb, pcd])
        b = rgb_pcd[0][0].shape[0]
        pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(b, -1, 3) for p in pcd], 1)

        # flatten RGBs and Pointclouds
        rgb = [rp[0] for rp in rgb_pcd]
        feat_size = rgb[0].shape[1]
        flat_imag_features = torch.cat(
            [p.permute(0, 2, 3, 1).reshape(b, -1, feat_size) for p in rgb], 1
        )

        # construct voxel grid
        voxel_grid = self._voxelizer.coords_to_bounding_voxel_grid(
            pcd_flat, coord_features=flat_imag_features, coord_bounds=bounds
        )

        # swap to channels first
        voxel_grid = voxel_grid.permute(0, 4, 1, 2, 3).detach()

        # batch bounds if necessary (bounds defaults to None, so guard it)
        if bounds is not None and bounds.shape[0] != b:
            bounds = bounds.repeat(b, 1)

        # forward pass
        q_trans, q_rot_and_grip, q_ignore_collisions = self._qnet(
            voxel_grid,
            proprio,
            lang_goal_emb,
            lang_token_embs,
            prev_layer_voxel_grid,
            bounds,
            prev_bounds,
        )

        return q_trans, q_rot_and_grip, q_ignore_collisions, voxel_grid
self._num_warmup_steps = num_warmup_steps + self._trans_loss_weight = trans_loss_weight + self._rot_loss_weight = rot_loss_weight + self._grip_loss_weight = grip_loss_weight + self._collision_loss_weight = collision_loss_weight + self._include_low_dim_state = include_low_dim_state + self._image_resolution = image_resolution or [128, 128] + self._voxel_size = voxel_size + self._camera_names = camera_names + self._num_cameras = len(camera_names) + self._batch_size = batch_size + self._lambda_weight_l2 = lambda_weight_l2 + self._transform_augmentation = transform_augmentation + self._transform_augmentation_xyz = torch.from_numpy( + np.array(transform_augmentation_xyz) + ) + self._transform_augmentation_rpy = transform_augmentation_rpy + self._transform_augmentation_rot_resolution = ( + transform_augmentation_rot_resolution + ) + self._optimizer_type = optimizer_type + self._num_devices = num_devices + self._num_rotation_classes = num_rotation_classes + self._rotation_resolution = rotation_resolution + + self._cross_entropy_loss = nn.CrossEntropyLoss(reduction="none") + checkpoint_name_prefix = checkpoint_name_prefix or "QAttentionAgent" + self._name = f"{checkpoint_name_prefix}_layer_{self._layer}" + + def build(self, training: bool, device: torch.device = None): + self._training = training + + if device is None: + device = torch.device("cpu") + + self._device = device + + self._voxelizer = VoxelGrid( + coord_bounds=self._coordinate_bounds, + voxel_size=self._voxel_size, + device=device, + batch_size=self._batch_size if training else 1, + feature_size=self._voxel_feature_size, + max_num_coords=np.prod(self._image_resolution) * self._num_cameras, + ) + + self._q = ( + QFunction( + self._perceiver_encoder, + self._voxelizer, + self._bounds_offset, + self._rotation_resolution, + device, + training, + ) + .to(device) + .train(training) + ) + + grid_for_crop = ( + torch.arange(0, self._image_crop_size, device=device) + .unsqueeze(0) + .repeat(self._image_crop_size, 1) + 
.unsqueeze(-1) + ) + self._grid_for_crop = torch.cat( + [grid_for_crop.transpose(1, 0), grid_for_crop], dim=2 + ).unsqueeze(0) + + self._coordinate_bounds = torch.tensor( + self._coordinate_bounds, device=device + ).unsqueeze(0) + + if self._training: + # optimizer + if self._optimizer_type == "lamb": + self._optimizer = Lamb( + self._q.parameters(), + lr=self._lr, + weight_decay=self._lambda_weight_l2, + betas=(0.9, 0.999), + adam=False, + ) + elif self._optimizer_type == "adam": + self._optimizer = torch.optim.Adam( + self._q.parameters(), + lr=self._lr, + weight_decay=self._lambda_weight_l2, + ) + else: + raise Exception("Unknown optimizer type") + + # learning rate scheduler + if self._lr_scheduler: + self._scheduler = ( + transformers.get_cosine_with_hard_restarts_schedule_with_warmup( + self._optimizer, + num_warmup_steps=self._num_warmup_steps, + num_training_steps=self._training_iterations, + num_cycles=self._training_iterations // 10000, + ) + ) + + # one-hot zero tensors + self._action_trans_one_hot_zeros = torch.zeros( + ( + self._batch_size, + 1, + self._voxel_size, + self._voxel_size, + self._voxel_size, + ), + dtype=int, + device=device, + ) + self._action_rot_x_one_hot_zeros = torch.zeros( + (self._batch_size, self._num_rotation_classes), dtype=int, device=device + ) + self._action_rot_y_one_hot_zeros = torch.zeros( + (self._batch_size, self._num_rotation_classes), dtype=int, device=device + ) + self._action_rot_z_one_hot_zeros = torch.zeros( + (self._batch_size, self._num_rotation_classes), dtype=int, device=device + ) + self._action_grip_one_hot_zeros = torch.zeros( + (self._batch_size, 2), dtype=int, device=device + ) + self._action_ignore_collisions_one_hot_zeros = torch.zeros( + (self._batch_size, 2), dtype=int, device=device + ) + + # print total params + logging.info( + "# Q Params: %d" + % sum( + p.numel() + for name, p in self._q.named_parameters() + if p.requires_grad and "clip" not in name + ) + ) + else: + for param in 
self._q.parameters(): + param.requires_grad = False + + # load CLIP for encoding language goals during evaluation + model, _ = load_clip("RN50", jit=False) + self._clip_rn50 = build_model(model.state_dict()) + self._clip_rn50 = self._clip_rn50.float().to(device) + self._clip_rn50.eval() + del model + + self._voxelizer.to(device) + self._q.to(device) + + def _extract_crop(self, pixel_action, observation): + # Pixel action will now be (B, 2) + # observation = stack_on_channel(observation) + h = observation.shape[-1] + top_left_corner = torch.clamp( + pixel_action - self._image_crop_size // 2, 0, h - self._image_crop_size + ) + grid = self._grid_for_crop + top_left_corner.unsqueeze(1) + grid = ((grid / float(h)) * 2.0) - 1.0 # between -1 and 1 + # Used for cropping the images across a batch + # swap fro y x, to x, y + grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1) + crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True) + return crop + + def _preprocess_inputs(self, replay_sample): + obs = [] + pcds = [] + self._crop_summary = [] + for n in self._camera_names: + rgb = replay_sample["%s_rgb" % n] + pcd = replay_sample["%s_point_cloud" % n] + + obs.append([rgb, pcd]) + pcds.append(pcd) + return obs, pcds + + def _act_preprocess_inputs(self, observation): + obs, pcds = [], [] + for n in self._camera_names: + rgb = observation["%s_rgb" % n] + pcd = observation["%s_point_cloud" % n] + + obs.append([rgb, pcd]) + pcds.append(pcd) + return obs, pcds + + def _get_value_from_voxel_index(self, q, voxel_idx): + b, c, d, h, w = q.shape + q_trans_flat = q.view(b, c, d * h * w) + flat_indicies = ( + voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2] + )[:, None].int() + highest_idxs = flat_indicies.unsqueeze(-1).repeat(1, c, 1) + chosen_voxel_values = q_trans_flat.gather(2, highest_idxs)[ + ..., 0 + ] # (B, trans + rot + grip) + return chosen_voxel_values + + def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx): + 
q_rot = torch.stack( + torch.split( + rot_grip_q[:, :-2], int(360 // self._rotation_resolution), dim=1 + ), + dim=1, + ) # B, 3, 72 + q_grip = rot_grip_q[:, -2:] + rot_and_grip_values = torch.cat( + [ + q_rot[:, 0].gather(1, rot_and_grip_idx[:, 0:1]), + q_rot[:, 1].gather(1, rot_and_grip_idx[:, 1:2]), + q_rot[:, 2].gather(1, rot_and_grip_idx[:, 2:3]), + q_grip.gather(1, rot_and_grip_idx[:, 3:4]), + ], + -1, + ) + return rot_and_grip_values + + def _celoss(self, pred, labels): + return self._cross_entropy_loss(pred, labels.argmax(-1)) + + def _softmax_q_trans(self, q): + q_shape = q.shape + return F.softmax(q.reshape(q_shape[0], -1), dim=1).reshape(q_shape) + + def _softmax_q_rot_grip(self, q_rot_grip): + q_rot_x_flat = q_rot_grip[ + :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes + ] + q_rot_y_flat = q_rot_grip[ + :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes + ] + q_rot_z_flat = q_rot_grip[ + :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes + ] + q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :] + + q_rot_x_flat_softmax = F.softmax(q_rot_x_flat, dim=1) + q_rot_y_flat_softmax = F.softmax(q_rot_y_flat, dim=1) + q_rot_z_flat_softmax = F.softmax(q_rot_z_flat, dim=1) + q_grip_flat_softmax = F.softmax(q_grip_flat, dim=1) + + return torch.cat( + [ + q_rot_x_flat_softmax, + q_rot_y_flat_softmax, + q_rot_z_flat_softmax, + q_grip_flat_softmax, + ], + dim=1, + ) + + def _softmax_ignore_collision(self, q_collision): + q_collision_softmax = F.softmax(q_collision, dim=1) + return q_collision_softmax + + def update(self, step: int, replay_sample: dict) -> dict: + action_trans = replay_sample["trans_action_indicies"][ + :, self._layer * 3 : self._layer * 3 + 3 + ].int() + action_rot_grip = replay_sample["rot_grip_action_indicies"].int() + action_gripper_pose = replay_sample["gripper_pose"] + action_ignore_collisions = replay_sample["ignore_collisions"].int() + lang_goal_emb = 
replay_sample["lang_goal_emb"].float() + lang_token_embs = replay_sample["lang_token_embs"].float() + prev_layer_voxel_grid = replay_sample.get("prev_layer_voxel_grid", None) + prev_layer_bounds = replay_sample.get("prev_layer_bounds", None) + device = self._device + + bounds = self._coordinate_bounds.to(device) + if self._layer > 0: + cp = replay_sample["attention_coordinate_layer_%d" % (self._layer - 1)] + bounds = torch.cat( + [cp - self._bounds_offset, cp + self._bounds_offset], dim=1 + ) + + proprio = None + if self._include_low_dim_state: + proprio = replay_sample["low_dim_state"] + + obs, pcd = self._preprocess_inputs(replay_sample) + + # batch size + bs = pcd[0].shape[0] + + # SE(3) augmentation of point clouds and actions + if self._transform_augmentation: + action_trans, action_rot_grip, pcd = apply_se3_augmentation( + pcd, + action_gripper_pose, + action_trans, + action_rot_grip, + bounds, + self._layer, + self._transform_augmentation_xyz, + self._transform_augmentation_rpy, + self._transform_augmentation_rot_resolution, + self._voxel_size, + self._rotation_resolution, + self._device, + ) + + # forward pass + q_trans, q_rot_grip, q_collision, voxel_grid = self._q( + obs, + proprio, + pcd, + lang_goal_emb, + lang_token_embs, + bounds, + prev_layer_bounds, + prev_layer_voxel_grid, + ) + + # argmax to choose best action + ( + coords, + rot_and_grip_indicies, + ignore_collision_indicies, + ) = self._q.choose_highest_action(q_trans, q_rot_grip, q_collision) + + q_trans_loss, q_rot_loss, q_grip_loss, q_collision_loss = 0.0, 0.0, 0.0, 0.0 + + # translation one-hot + action_trans_one_hot = self._action_trans_one_hot_zeros.clone() + for b in range(bs): + gt_coord = action_trans[b, :].int() + action_trans_one_hot[b, :, gt_coord[0], gt_coord[1], gt_coord[2]] = 1 + + # translation loss + q_trans_flat = q_trans.view(bs, -1) + action_trans_one_hot_flat = action_trans_one_hot.view(bs, -1) + q_trans_loss = self._celoss(q_trans_flat, action_trans_one_hot_flat) + + 
with_rot_and_grip = rot_and_grip_indicies is not None + if with_rot_and_grip: + # rotation, gripper, and collision one-hots + action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone() + action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone() + action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone() + action_grip_one_hot = self._action_grip_one_hot_zeros.clone() + action_ignore_collisions_one_hot = ( + self._action_ignore_collisions_one_hot_zeros.clone() + ) + + for b in range(bs): + gt_rot_grip = action_rot_grip[b, :].int() + action_rot_x_one_hot[b, gt_rot_grip[0]] = 1 + action_rot_y_one_hot[b, gt_rot_grip[1]] = 1 + action_rot_z_one_hot[b, gt_rot_grip[2]] = 1 + action_grip_one_hot[b, gt_rot_grip[3]] = 1 + + gt_ignore_collisions = action_ignore_collisions[b, :].int() + action_ignore_collisions_one_hot[b, gt_ignore_collisions[0]] = 1 + + # flatten predictions + q_rot_x_flat = q_rot_grip[ + :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes + ] + q_rot_y_flat = q_rot_grip[ + :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes + ] + q_rot_z_flat = q_rot_grip[ + :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes + ] + q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :] + q_ignore_collisions_flat = q_collision + + # rotation loss + q_rot_loss += self._celoss(q_rot_x_flat, action_rot_x_one_hot) + q_rot_loss += self._celoss(q_rot_y_flat, action_rot_y_one_hot) + q_rot_loss += self._celoss(q_rot_z_flat, action_rot_z_one_hot) + + # gripper loss + q_grip_loss += self._celoss(q_grip_flat, action_grip_one_hot) + + # collision loss + q_collision_loss += self._celoss( + q_ignore_collisions_flat, action_ignore_collisions_one_hot + ) + + combined_losses = ( + (q_trans_loss * self._trans_loss_weight) + + (q_rot_loss * self._rot_loss_weight) + + (q_grip_loss * self._grip_loss_weight) + + (q_collision_loss * self._collision_loss_weight) + ) + total_loss = combined_losses.mean() + + 
self._optimizer.zero_grad() + total_loss.backward() + self._optimizer.step() + + self._summaries = { + "losses/total_loss": total_loss, + "losses/trans_loss": q_trans_loss.mean(), + "losses/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0, + "losses/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0, + "losses/collision_loss": q_collision_loss.mean() + if with_rot_and_grip + else 0.0, + } + + if self._lr_scheduler: + self._scheduler.step() + self._summaries["learning_rate"] = self._scheduler.get_last_lr()[0] + + self._vis_voxel_grid = voxel_grid[0] + self._vis_translation_qvalue = self._softmax_q_trans(q_trans[0]) + self._vis_max_coordinate = coords[0] + self._vis_gt_coordinate = action_trans[0] + + # Note: PerAct doesn't use multi-layer voxel grids like C2FARM + # stack prev_layer_voxel_grid(s) from previous layers into a list + if prev_layer_voxel_grid is None: + prev_layer_voxel_grid = [voxel_grid] + else: + prev_layer_voxel_grid = prev_layer_voxel_grid + [voxel_grid] + + # stack prev_layer_bound(s) from previous layers into a list + if prev_layer_bounds is None: + prev_layer_bounds = [self._coordinate_bounds.repeat(bs, 1)] + else: + prev_layer_bounds = prev_layer_bounds + [bounds] + + return { + "total_loss": total_loss, + "prev_layer_voxel_grid": prev_layer_voxel_grid, + "prev_layer_bounds": prev_layer_bounds, + } + + def act(self, step: int, observation: dict, deterministic=False) -> ActResult: + deterministic = True + bounds = self._coordinate_bounds + prev_layer_voxel_grid = observation.get("prev_layer_voxel_grid", None) + prev_layer_bounds = observation.get("prev_layer_bounds", None) + lang_goal_tokens = observation.get("lang_goal_tokens", None).long() + + # extract CLIP language embs + with torch.no_grad(): + lang_goal_tokens = lang_goal_tokens.to(device=self._device) + ( + lang_goal_emb, + lang_token_embs, + ) = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0]) + + # voxelization resolution + res = (bounds[:, 3:] - 
bounds[:, :3]) / self._voxel_size + max_rot_index = int(360 // self._rotation_resolution) + proprio = None + + if self._include_low_dim_state: + proprio = observation["low_dim_state"] + proprio = proprio[0].to(self._device) + + obs, pcd = self._act_preprocess_inputs(observation) + + # correct batch size and device + obs = [[o[0][0].to(self._device), o[1][0].to(self._device)] for o in obs] + pcd = [p[0].to(self._device) for p in pcd] + lang_goal_emb = lang_goal_emb.to(self._device) + lang_token_embs = lang_token_embs.to(self._device) + bounds = torch.as_tensor(bounds, device=self._device) + prev_layer_voxel_grid = ( + prev_layer_voxel_grid.to(self._device) + if prev_layer_voxel_grid is not None + else None + ) + prev_layer_bounds = ( + prev_layer_bounds.to(self._device) + if prev_layer_bounds is not None + else None + ) + + # inference + q_trans, q_rot_grip, q_ignore_collisions, vox_grid = self._q( + obs, + proprio, + pcd, + lang_goal_emb, + lang_token_embs, + bounds, + prev_layer_bounds, + prev_layer_voxel_grid, + ) + + # softmax Q predictions + q_trans = self._softmax_q_trans(q_trans) + q_rot_grip = ( + self._softmax_q_rot_grip(q_rot_grip) + if q_rot_grip is not None + else q_rot_grip + ) + q_ignore_collisions = ( + self._softmax_ignore_collision(q_ignore_collisions) + if q_ignore_collisions is not None + else q_ignore_collisions + ) + + # argmax Q predictions + ( + coords, + rot_and_grip_indicies, + ignore_collisions, + ) = self._q.choose_highest_action(q_trans, q_rot_grip, q_ignore_collisions) + + rot_grip_action = rot_and_grip_indicies if q_rot_grip is not None else None + ignore_collisions_action = ( + ignore_collisions.int() if ignore_collisions is not None else None + ) + + coords = coords.int() + attention_coordinate = bounds[:, :3] + res * coords + res / 2 + + # stack prev_layer_voxel_grid(s) into a list + # NOTE: PerAct doesn't used multi-layer voxel grids like C2FARM + if prev_layer_voxel_grid is None: + prev_layer_voxel_grid = [vox_grid] + else: + 
prev_layer_voxel_grid = prev_layer_voxel_grid + [vox_grid] + + if prev_layer_bounds is None: + prev_layer_bounds = [bounds] + else: + prev_layer_bounds = prev_layer_bounds + [bounds] + + observation_elements = { + "attention_coordinate": attention_coordinate, + "prev_layer_voxel_grid": prev_layer_voxel_grid, + "prev_layer_bounds": prev_layer_bounds, + } + info = { + "voxel_grid_depth%d" % self._layer: vox_grid, + "q_depth%d" % self._layer: q_trans, + "voxel_idx_depth%d" % self._layer: coords, + } + self._act_voxel_grid = vox_grid[0] + self._act_max_coordinate = coords[0] + self._act_qvalues = q_trans[0].detach() + return ActResult( + (coords, rot_grip_action, ignore_collisions_action), + observation_elements=observation_elements, + info=info, + ) + + def update_summaries(self) -> List[Summary]: + summaries = [ + ImageSummary( + "%s/update_qattention" % self._name, + transforms.ToTensor()( + visualise_voxel( + self._vis_voxel_grid.detach().cpu().numpy(), + self._vis_translation_qvalue.detach().cpu().numpy(), + self._vis_max_coordinate.detach().cpu().numpy(), + self._vis_gt_coordinate.detach().cpu().numpy(), + ) + ), + ) + ] + + for n, v in self._summaries.items(): + summaries.append(ScalarSummary("%s/%s" % (self._name, n), v)) + + for name, crop in self._crop_summary: + crops = (torch.cat(torch.split(crop, 3, dim=1), dim=3) + 1.0) / 2.0 + summaries.extend([ImageSummary("%s/crops/%s" % (self._name, name), crops)]) + + for tag, param in self._q.named_parameters(): + # assert not torch.isnan(param.grad.abs() <= 1.0).all() + summaries.append( + HistogramSummary("%s/gradient/%s" % (self._name, tag), param.grad) + ) + summaries.append( + HistogramSummary("%s/weight/%s" % (self._name, tag), param.data) + ) + + return summaries + + def act_summaries(self) -> List[Summary]: + return [ + ImageSummary( + "%s/act_Qattention" % self._name, + transforms.ToTensor()( + visualise_voxel( + self._act_voxel_grid.cpu().numpy(), + self._act_qvalues.cpu().numpy(), + 
def load_weights(self, savedir: str):
    """Load ``<savedir>/<self._name>.pt`` into ``self._q``.

    Only checkpoint entries whose key exists in the current model are used.
    Other keys are reported with a warning unless they belong to the
    voxelizer. Outside training, ``_qnet.module`` prefixes are rewritten to
    ``_qnet`` and the voxelizer buffers are cut down to a batch of one before
    loading.

    While training, ``self._device`` is formatted with ``%d`` and is therefore
    expected to be an integer CUDA index; otherwise it is used as-is.
    """
    device = (
        self._device
        if not self._training
        else torch.device("cuda:%d" % self._device)
    )
    weight_file = os.path.join(savedir, "%s.pt" % self._name)
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(weight_file, map_location=device)

    # load only keys that are in the current model
    merged_state_dict = self._q.state_dict()
    for k, v in state_dict.items():
        if not self._training:
            k = k.replace("_qnet.module", "_qnet")
        if k in merged_state_dict:
            merged_state_dict[k] = v
        elif "_voxelizer" not in k:
            # The key is present in the checkpoint but unknown to the model.
            logging.warning(
                "checkpoint key %s not found in current model; skipping", k
            )

    if not self._training:
        # The checkpoint buffers carry the training batch size; keep one
        # batch element of each so they fit the model built for acting.
        b = merged_state_dict["_voxelizer._ones_max_coords"].shape[0]
        flat_shape = merged_state_dict["_voxelizer._flat_output"].shape[0]
        merged_state_dict["_voxelizer._ones_max_coords"] = merged_state_dict[
            "_voxelizer._ones_max_coords"
        ][0:1]
        merged_state_dict["_voxelizer._flat_output"] = merged_state_dict[
            "_voxelizer._flat_output"
        ][0 : flat_shape // b]
        for key in ("_voxelizer._tiled_batch_indices", "_voxelizer._index_grid"):
            merged_state_dict[key] = merged_state_dict[key][0:1]

    self._q.load_state_dict(merged_state_dict)
    print("loaded weights from %s" % weight_file)


def save_weights(self, savedir: str):
    """Write ``self._q.state_dict()`` to ``<savedir>/<self._name>.pt``.

    The directory is created first if it does not exist yet.
    """
    os.makedirs(savedir, exist_ok=True)
    torch.save(self._q.state_dict(), os.path.join(savedir, "%s.pt" % self._name))
class QAttentionStackAgent(Agent):
    """Chains several Q-attention agents and merges their outputs into one action.

    Each wrapped agent is run in order. What it returns (attention coordinate,
    previous-layer voxel grid and bounds, per-camera pixel coordinate) is
    written back into the observation before the next agent runs.
    """

    def __init__(
        self,
        qattention_agents: List[QAttentionPerActBCAgent],
        rotation_resolution: float,
        camera_names: List[str],
        rotation_prediction_depth: int = 0,
    ):
        super(QAttentionStackAgent, self).__init__()
        self._qattention_agents = qattention_agents
        self._rotation_resolution = rotation_resolution
        self._camera_names = camera_names
        self._rotation_prediction_depth = rotation_prediction_depth

    def build(self, training: bool, device=None) -> None:
        """Build every wrapped agent; this wrapper falls back to the CPU device."""
        self._device = device
        if self._device is None:
            self._device = torch.device("cpu")
        for qa in self._qattention_agents:
            # The wrapped agents receive the caller's ``device`` unchanged.
            qa.build(training, device)

    def update(self, step: int, replay_sample: dict) -> dict:
        """Update the wrapped agents in order and sum their ``total_loss``.

        Each agent's result dict is merged into ``replay_sample`` so that the
        following agent can read it.
        """
        total_losses = 0.0
        for qa in self._qattention_agents:
            update_dict = qa.update(step, replay_sample)
            replay_sample.update(update_dict)
            total_losses += update_dict["total_loss"]
        return {
            "total_losses": total_losses,
        }

    def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
        """Run every wrapped agent and assemble the continuous action.

        The action is the concatenation of the last agent's attention
        coordinate, the quaternion that ``utils.discrete_euler_to_quaternion``
        returns for the predicted rotation bins, the gripper index and the
        ignore-collisions flag.
        """
        observation_elements = {}
        translation_results, rot_grip_results, ignore_collisions_results = [], [], []
        infos = {}
        for depth, qagent in enumerate(self._qattention_agents):
            act_results = qagent.act(step, observation, deterministic)
            attention_coordinate = (
                act_results.observation_elements["attention_coordinate"].cpu().numpy()
            )
            observation_elements[
                "attention_coordinate_layer_%d" % depth
            ] = attention_coordinate[0]

            translation_idxs, rot_grip_idxs, ignore_collisions_idxs = act_results.action
            translation_results.append(translation_idxs)
            if rot_grip_idxs is not None:
                rot_grip_results.append(rot_grip_idxs)
            if ignore_collisions_idxs is not None:
                ignore_collisions_results.append(ignore_collisions_idxs)

            # Feed this layer's outputs to the next agent in the stack.
            for key in (
                "attention_coordinate",
                "prev_layer_voxel_grid",
                "prev_layer_bounds",
            ):
                observation[key] = act_results.observation_elements[key]

            for n in self._camera_names:
                px, py = utils.point_to_pixel_index(
                    attention_coordinate[0],
                    observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy(),
                    observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy(),
                )
                pc_t = torch.tensor(
                    [[[py, px]]], dtype=torch.float32, device=self._device
                )
                observation["%s_pixel_coord" % n] = pc_t
                observation_elements["%s_pixel_coord" % n] = [py, px]

            infos.update(act_results.info)

        rgai = torch.cat(rot_grip_results, 1)[0].cpu().numpy()
        # .item() instead of float(ndarray): converting a 1-D size-1 array with
        # float() has been deprecated since NumPy 1.25.
        ignore_collisions = float(
            torch.cat(ignore_collisions_results, 1)[0].cpu().numpy().item()
        )
        observation_elements["trans_action_indicies"] = (
            torch.cat(translation_results, 1)[0].cpu().numpy()
        )
        observation_elements["rot_grip_action_indicies"] = rgai
        continuous_action = np.concatenate(
            [
                act_results.observation_elements["attention_coordinate"]
                .cpu()
                .numpy()[0],
                utils.discrete_euler_to_quaternion(
                    rgai[-4:-1], self._rotation_resolution
                ),
                rgai[-1:],
                [ignore_collisions],
            ]
        )
        return ActResult(
            continuous_action, observation_elements=observation_elements, info=infos
        )

    def update_summaries(self) -> List[Summary]:
        """Concatenate the update summaries of all wrapped agents."""
        summaries = []
        for qa in self._qattention_agents:
            summaries.extend(qa.update_summaries())
        return summaries

    def act_summaries(self) -> List[Summary]:
        """Concatenate the act summaries of all wrapped agents."""
        s = []
        for qa in self._qattention_agents:
            s.extend(qa.act_summaries())
        return s

    def load_weights(self, savedir: str):
        """Ask every wrapped agent to load its weights from ``savedir``."""
        for qa in self._qattention_agents:
            qa.load_weights(savedir)

    def save_weights(self, savedir: str):
        """Ask every wrapped agent to save its weights into ``savedir``."""
        for qa in self._qattention_agents:
            qa.save_weights(savedir)
b/external/peract_bimanual/agents/replay_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4c1cf884fb12ce6592959b694f2302691462bce4 --- /dev/null +++ b/external/peract_bimanual/agents/replay_utils.py @@ -0,0 +1,643 @@ +import logging +from typing import List + +import numpy as np +from rlbench.backend.observation import Observation +from rlbench.observation_config import ObservationConfig +import rlbench.utils as rlbench_utils +from rlbench.demo import Demo +from yarr.replay_buffer.replay_buffer import ReplayBuffer + +from helpers import demo_loading_utils, utils +from helpers import observation_utils +from helpers.clip.core.clip import tokenize + + +from yarr.replay_buffer.prioritized_replay_buffer import ObservationElement +from yarr.replay_buffer.replay_buffer import ReplayElement +from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer + + +import torch +from torch.multiprocessing import Process, Value, Manager +from helpers.clip.core.clip import build_model, load_clip +from omegaconf import DictConfig + + +REWARD_SCALE = 100.0 +LOW_DIM_SIZE = 4 + + +def create_replay(cfg, replay_path): + if cfg.method.robot_name == "bimanual": + return create_bimanual_replay( + cfg.replay.batch_size, + cfg.replay.timesteps, + cfg.replay.prioritisation, + cfg.replay.task_uniform, + replay_path if cfg.replay.use_disk else None, + cfg.rlbench.cameras, + cfg.method.voxel_sizes, + cfg.rlbench.camera_resolution, + ) + else: + return create_unimanual_replay( + cfg.replay.batch_size, + cfg.replay.timesteps, + cfg.replay.prioritisation, + cfg.replay.task_uniform, + replay_path if cfg.replay.use_disk else None, + cfg.rlbench.cameras, + cfg.method.voxel_sizes, + cfg.rlbench.camera_resolution, + ) + + +def create_bimanual_replay( + batch_size: int, + timesteps: int, + prioritisation: bool, + task_uniform: bool, + save_dir: str, + cameras: list, + voxel_sizes, + image_size=[128, 128], + replay_size=3e5, +): + trans_indicies_size = 3 * 
len(voxel_sizes) + rot_and_grip_indicies_size = 3 + 1 + gripper_pose_size = 7 + ignore_collisions_size = 1 + max_token_seq_len = 77 + lang_feat_dim = 1024 + lang_emb_dim = 512 + + # low_dim_state + observation_elements = [] + observation_elements.append( + ObservationElement("right_low_dim_state", (LOW_DIM_SIZE,), np.float32) + ) + observation_elements.append( + ObservationElement("left_low_dim_state", (LOW_DIM_SIZE,), np.float32) + ) + + # rgb, depth, point cloud, intrinsics, extrinsics + for cname in cameras: + observation_elements.append( + # color, height, width + ObservationElement( + "%s_rgb" % cname, + ( + 3, + image_size[1], + image_size[0], + ), + np.float32, + ) + ) + observation_elements.append( + ObservationElement( + "%s_point_cloud" % cname, (3, image_size[1], image_size[0]), np.float16 + ) + ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames + observation_elements.append( + ObservationElement( + "%s_camera_extrinsics" % cname, + ( + 4, + 4, + ), + np.float32, + ) + ) + observation_elements.append( + ObservationElement( + "%s_camera_intrinsics" % cname, + ( + 3, + 3, + ), + np.float32, + ) + ) + + # discretized translation, discretized rotation, discrete ignore collision, 6-DoF gripper pose, and pre-trained language embeddings + for robot_name in ["right", "left"]: + observation_elements.extend( + [ + ReplayElement( + f"{robot_name}_trans_action_indicies", + (trans_indicies_size,), + np.int32, + ), + ReplayElement( + f"{robot_name}_rot_grip_action_indicies", + (rot_and_grip_indicies_size,), + np.int32, + ), + ReplayElement( + f"{robot_name}_ignore_collisions", + (ignore_collisions_size,), + np.int32, + ), + ReplayElement( + f"{robot_name}_gripper_pose", (gripper_pose_size,), np.float32 + ), + ] + ) + + observation_elements.extend( + [ + ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32), + ReplayElement( + "lang_token_embs", + ( + max_token_seq_len, + lang_emb_dim, + ), + np.float32, + ), # extracted 
from CLIP's language encoder + ReplayElement("task", (), str), + ReplayElement( + "lang_goal", (1,), object + ), # language goal string for debugging and visualization + ] + ) + + extra_replay_elements = [ + ReplayElement("demo", (), bool), + ] + + replay_buffer = TaskUniformReplayBuffer( + save_dir=save_dir, + batch_size=batch_size, + timesteps=timesteps, + replay_capacity=int(replay_size), + action_shape=(8 * 2,), + action_dtype=np.float32, + reward_shape=(), + reward_dtype=np.float32, + update_horizon=1, + observation_elements=observation_elements, + extra_replay_elements=extra_replay_elements, + ) + return replay_buffer + + +def create_unimanual_replay( + batch_size: int, + timesteps: int, + prioritisation: bool, + task_uniform: bool, + save_dir: str, + cameras: list, + voxel_sizes, + image_size=[128, 128], + replay_size=3e5, +): + trans_indicies_size = 3 * len(voxel_sizes) + rot_and_grip_indicies_size = 3 + 1 + gripper_pose_size = 7 + ignore_collisions_size = 1 + max_token_seq_len = 77 + lang_feat_dim = 1024 + lang_emb_dim = 512 + + # low_dim_state + observation_elements = [] + observation_elements.append( + ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32) + ) + + # rgb, depth, point cloud, intrinsics, extrinsics + for cname in cameras: + observation_elements.append( + ObservationElement( + "%s_rgb" % cname, + ( + 3, + *image_size, + ), + np.float32, + ) + ) + observation_elements.append( + ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32) + ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames + observation_elements.append( + ObservationElement( + "%s_camera_extrinsics" % cname, + ( + 4, + 4, + ), + np.float32, + ) + ) + observation_elements.append( + ObservationElement( + "%s_camera_intrinsics" % cname, + ( + 3, + 3, + ), + np.float32, + ) + ) + + # discretized translation, discretized rotation, discrete ignore collision, 6-DoF gripper pose, and pre-trained language embeddings + 
def _get_action(
    obs_tp1: Observation,
    obs_tm1: Observation,
    rlbench_scene_bounds: List[float],  # metric 3D bounds of the scene
    voxel_sizes: List[int],
    bounds_offset: List[float],
    rotation_resolution: int,
    crop_augmentation: bool,
):
    """Discretise the gripper pose of ``obs_tp1`` into action indices.

    Returns a 5-tuple:
        * translation voxel indices, three per entry of ``voxel_sizes``;
        * discrete rotation indices followed by ``int(gripper_open)``;
        * ``int(obs_tm1.ignore_collisions)``;
        * ``obs_tp1.gripper_pose`` with ``float(gripper_open)`` appended;
        * one attention coordinate per entry of ``voxel_sizes``.
    """
    ignore_collisions = int(obs_tm1.ignore_collisions)
    bounds = np.array(rlbench_scene_bounds)
    attention_coordinate = obs_tp1.gripper_pose[:3]

    # Flip the normalised quaternion when its last component is negative.
    quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
    if quat[-1] < 0:
        quat = -quat
    disc_rot = utils.correct_rotation_instability(
        utils.quaternion_to_discrete_euler(quat, rotation_resolution),
        rotation_resolution,
    )

    trans_indicies = []
    attention_coordinates = []
    # only single voxelization-level is used in PerAct
    for depth, vox_size in enumerate(voxel_sizes):
        if depth > 0:
            half_extent = bounds_offset[depth - 1]
            if crop_augmentation:
                shift = half_extent * 0.75
                # In-place on purpose: this array is the one appended on the
                # previous iteration, so the stored coordinate moves with it.
                # NOTE(review): presumably consumers of attention_coordinates
                # rely on that -- confirm before changing.
                attention_coordinate += np.random.uniform(-shift, shift, size=(3,))
            bounds = np.concatenate(
                [attention_coordinate - half_extent, attention_coordinate + half_extent]
            )
        index = utils.point_to_voxel_index(obs_tp1.gripper_pose[:3], vox_size, bounds)
        trans_indicies.extend(index.tolist())
        res = (bounds[3:] - bounds[:3]) / vox_size
        attention_coordinate = bounds[:3] + res * index
        attention_coordinates.append(attention_coordinate)

    rot_and_grip_indicies = disc_rot.tolist()
    rot_and_grip_indicies.append(int(obs_tp1.gripper_open))
    grip = float(obs_tp1.gripper_open)
    action = np.concatenate([obs_tp1.gripper_pose, np.array([grip])])
    return (
        trans_indicies,
        rot_and_grip_indicies,
        ignore_collisions,
        action,
        attention_coordinates,
    )
+ left_action, + left_attention_coordinates, + ) = _get_action( + obs_tp1.left, + obs_tm1.left, + rlbench_scene_bounds, + voxel_sizes, + bounds_offset, + rotation_resolution, + crop_augmentation, + ) + + action = np.append(right_action, left_action) + + right_ignore_collisions = np.array([right_ignore_collisions]) + left_ignore_collisions = np.array([left_ignore_collisions]) + + elif robot_name == "unimanual": + ( + trans_indicies, + rot_grip_indicies, + ignore_collisions, + action, + attention_coordinates, + ) = _get_action( + obs_tp1, + obs_tm1, + rlbench_scene_bounds, + voxel_sizes, + bounds_offset, + rotation_resolution, + crop_augmentation, + ) + gripper_pose = obs_tp1.gripper_pose + elif obs_tp1.is_bimanual and robot_name == "right": + ( + trans_indicies, + rot_grip_indicies, + ignore_collisions, + action, + attention_coordinates, + ) = _get_action( + obs_tp1.right, + obs_tm1.right, + rlbench_scene_bounds, + voxel_sizes, + bounds_offset, + rotation_resolution, + crop_augmentation, + ) + gripper_pose = obs_tp1.right.gripper_pose + elif obs_tp1.is_bimanual and robot_name == "left": + ( + trans_indicies, + rot_grip_indicies, + ignore_collisions, + action, + attention_coordinates, + ) = _get_action( + obs_tp1.left, + obs_tm1.left, + rlbench_scene_bounds, + voxel_sizes, + bounds_offset, + rotation_resolution, + crop_augmentation, + ) + gripper_pose = obs_tp1.left.gripper_pose + else: + logging.error("Invalid robot name %s", cfg.method.robot_name) + raise Exception("Invalid robot name.") + + terminal = k == len(episode_keypoints) - 1 + reward = float(terminal) * REWARD_SCALE if terminal else 0 + + obs_dict = observation_utils.extract_obs( + obs, + t=k, + prev_action=prev_action, + cameras=cameras, + episode_length=cfg.rlbench.episode_length, + robot_name=robot_name, + ) + tokens = tokenize([description]).numpy() + token_tensor = torch.from_numpy(tokens).to(device) + sentence_emb, token_embs = clip_model.encode_text_with_embeddings(token_tensor) + 
obs_dict["lang_goal_emb"] = sentence_emb[0].float().detach().cpu().numpy() + obs_dict["lang_token_embs"] = token_embs[0].float().detach().cpu().numpy() + + prev_action = np.copy(action) + + others = {"demo": True} + if robot_name == "bimanual": + final_obs = { + "right_trans_action_indicies": right_trans_indicies, + "right_rot_grip_action_indicies": right_rot_grip_indicies, + "right_gripper_pose": obs_tp1.right.gripper_pose, + "left_trans_action_indicies": left_trans_indicies, + "left_rot_grip_action_indicies": left_rot_grip_indicies, + "left_gripper_pose": obs_tp1.left.gripper_pose, + "task": task, + "lang_goal": np.array([description], dtype=object), + } + else: + final_obs = { + "trans_action_indicies": trans_indicies, + "rot_grip_action_indicies": rot_grip_indicies, + "gripper_pose": gripper_pose, + "task": task, + "lang_goal": np.array([description], dtype=object), + } + + others.update(final_obs) + others.update(obs_dict) + + timeout = False + replay.add(action, reward, terminal, timeout, **others) + obs = obs_tp1 + + # final step + obs_dict_tp1 = observation_utils.extract_obs( + obs_tp1, + t=k + 1, + prev_action=prev_action, + cameras=cameras, + episode_length=cfg.rlbench.episode_length, + robot_name=cfg.method.robot_name, + ) + obs_dict_tp1["lang_goal_emb"] = sentence_emb[0].float().detach().cpu().numpy() + obs_dict_tp1["lang_token_embs"] = token_embs[0].float().detach().cpu().numpy() + + obs_dict_tp1.pop("wrist_world_to_cam", None) + obs_dict_tp1.update(final_obs) + replay.add_final(**obs_dict_tp1) + + +def fill_replay( + cfg: DictConfig, + obs_config: ObservationConfig, + rank: int, + replay: ReplayBuffer, + task: str, + clip_model=None, + device="cpu", +): + num_demos = cfg.rlbench.demos + demo_augmentation = cfg.method.demo_augmentation + demo_augmentation_every_n = cfg.method.demo_augmentation_every_n + keypoint_method = cfg.method.keypoint_method + + if clip_model is None: + model, _ = load_clip("RN50", jit=False, device=device) + clip_model = 
build_model(model.state_dict()) + clip_model.to(device) + del model + + logging.debug("Filling %s replay ..." % task) + for d_idx in range(num_demos): + # load demo from disk + demo = rlbench_utils.get_stored_demos( + amount=1, + image_paths=False, + dataset_root=cfg.rlbench.demo_path, + variation_number=-1, + task_name=task, + obs_config=obs_config, + random_selection=False, + from_episode_number=d_idx, + )[0] + + descs = demo._observations[0].misc["descriptions"] + + # extract keypoints (a.k.a keyframes) + episode_keypoints = demo_loading_utils.keypoint_discovery( + demo, method=keypoint_method + ) + + if rank == 0: + logging.info( + f"Loading Demo({d_idx}) - found {len(episode_keypoints)} keypoints - {task}" + ) + + for i in range(len(demo) - 1): + if not demo_augmentation and i > 0: + break + if i % demo_augmentation_every_n != 0: + continue + + obs = demo[i] + desc = descs[0] + # if our starting point is past one of the keypoints, then remove it + while len(episode_keypoints) > 0 and i >= episode_keypoints[0]: + episode_keypoints = episode_keypoints[1:] + if len(episode_keypoints) == 0: + break + _add_keypoints_to_replay( + cfg, + task, + replay, + obs, + demo, + episode_keypoints, + description=desc, + clip_model=clip_model, + device=device, + ) + logging.debug("Replay %s filled with demos." 
% task) + + +def fill_multi_task_replay( + cfg: DictConfig, + obs_config: ObservationConfig, + rank: int, + replay: ReplayBuffer, + tasks: List[str], + clip_model=None, +): + tasks = cfg.rlbench.tasks + + manager = Manager() + store = manager.dict() + + # create a MP dict for storing indicies + # TODO(mohit): this shouldn't be initialized here + del replay._task_idxs + task_idxs = manager.dict() + replay._task_idxs = task_idxs + replay._create_storage(store) + replay.add_count = Value("i", 0) + + # fill replay buffer in parallel across tasks + max_parallel_processes = cfg.replay.max_parallel_processes + processes = [] + n = np.arange(len(tasks)) + split_n = utils.split_list(n, max_parallel_processes) + for split in split_n: + for e_idx, task_idx in enumerate(split): + task = tasks[int(task_idx)] + model_device = torch.device( + "cuda:%s" % (e_idx % torch.cuda.device_count()) + if torch.cuda.is_available() + else "cpu" + ) + p = Process( + target=fill_replay, + args=(cfg, obs_config, rank, replay, task, clip_model, model_device), + ) + + p.start() + processes.append(p) + + for p in processes: + p.join() diff --git a/external/peract_bimanual/agents/rvt/__init__.py b/external/peract_bimanual/agents/rvt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2532f671ea3b1d4ade07ca26e13c5225f6d805d8 --- /dev/null +++ b/external/peract_bimanual/agents/rvt/__init__.py @@ -0,0 +1 @@ +import agents.rvt.launch_utils diff --git a/external/peract_bimanual/agents/rvt/launch_utils.py b/external/peract_bimanual/agents/rvt/launch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..33faba23c1953e0a5af61b725107aeb1a527fa51 --- /dev/null +++ b/external/peract_bimanual/agents/rvt/launch_utils.py @@ -0,0 +1,168 @@ +import os +from typing import List +import torch +import numpy as np + +from omegaconf import DictConfig + +from yarr.agents.agent import Agent +from yarr.agents.agent import ActResult +from yarr.agents.agent import Summary 
def create_agent(cfg: DictConfig):
    """Build the RVT agent described by ``cfg`` and wrap it for preprocessing.

    The RVT experiment defaults are overridden with the replay batch size and
    the comma-joined task list; the MVT defaults are overridden with the
    proprioception size from ``cfg.method.low_dim_size``. Both configs are
    frozen before being handed to the wrapper.

    Args:
        cfg: Experiment config providing ``replay.batch_size``,
            ``rlbench``, ``method.low_dim_size`` and
            ``framework.checkpoint_name_prefix``.

    Returns:
        A ``PreprocessAgent`` whose ``pose_agent`` is an ``RVTAgentWrapper``.
    """
    exp_cfg = exp_cfg_mod.get_cfg_defaults()
    exp_cfg.bs = cfg.replay.batch_size
    exp_cfg.tasks = ",".join(cfg.rlbench.tasks)

    exp_cfg.freeze()

    mvt_cfg = mvt_cfg_mod.get_cfg_defaults()
    mvt_cfg.proprio_dim = cfg.method.low_dim_size
    mvt_cfg.freeze()

    agent = RVTAgentWrapper(
        cfg.framework.checkpoint_name_prefix, cfg.rlbench, mvt_cfg, exp_cfg
    )

    preprocess_agent = PreprocessAgent(pose_agent=agent)
    return preprocess_agent


class RVTAgentWrapper(Agent):
    """Adapter exposing an RVT agent through the YARR ``Agent`` interface.

    The underlying ``rvt_agent.RVTAgent`` is only created in :meth:`build`;
    until then ``self.rvt_agent`` is ``None``.
    """

    def __init__(self, checkpoint_name_prefix, rlbench_cfg, mvt_cfg, exp_cfg):
        """Store the configs; the checkpoint file is ``<prefix>.pt``."""
        self._checkpoint_filename = f"{checkpoint_name_prefix}.pt"
        self.rvt_agent = None
        self.rlbench_cfg = rlbench_cfg
        self.mvt_cfg = mvt_cfg
        self.exp_cfg = exp_cfg
        # Latest mean of every loss reported by the RVT agent, keyed by name.
        self._summaries = {}

    def build(self, training: bool, device=None) -> None:
        """Instantiate the MVT network and the RVT agent on ``device``.

        Args:
            training: When true the network is wrapped in
                ``DistributedDataParallel``.
            device: CUDA device; an ``int`` index is converted to the
                ``"cuda:<index>"`` string form after being selected as the
                current device.
        """
        torch.cuda.set_device(device)
        torch.cuda.empty_cache()

        if isinstance(device, int):
            device = f"cuda:{device}"

        rvt = MVT(
            renderer_device=device,
            **self.mvt_cfg,
        )
        rvt = rvt.to(device)

        if training:
            rvt = DDP(rvt, device_ids=[device])

        self.rvt_agent = rvt_agent.RVTAgent(
            network=rvt,
            # image_resolution=self.rlbench_cfg.camera_resolution,
            add_lang=self.mvt_cfg.add_lang,
            scene_bounds=self.rlbench_cfg.scene_bounds,
            cameras=self.rlbench_cfg.cameras,
            log_dir="/tmp/eval_run",
            **self.exp_cfg.peract,
            **self.exp_cfg.rvt,
        )

        self.rvt_agent.build(training, device)

    def update(self, step: int, replay_sample: dict) -> dict:
        """Run one RVT training step on ``replay_sample``.

        ``replay_sample`` is modified in place: every value gets an extra
        axis inserted at position 1, ``lang_goal_embs`` is aliased to
        ``lang_token_embs`` and ``tasks`` is set to the configured task list.

        Returns:
            ``{"total_losses": <total loss reported by the RVT agent>}``.
        """
        for k, v in replay_sample.items():
            replay_sample[k] = v.unsqueeze(1)
        # RVT is based on the PerAct's Colab version.
        replay_sample["lang_goal_embs"] = replay_sample["lang_token_embs"]
        replay_sample["tasks"] = self.exp_cfg.tasks.split(",")

        update_dict = self.rvt_agent.update(step, replay_sample)

        # Cache the mean of each logged loss for ``update_summaries``.
        for key, val in self.rvt_agent.loss_log.items():
            self._summaries[key] = np.mean(np.array(val))

        return {
            "total_losses": update_dict["total_loss"],
        }

    def act(self, step: int, observation: dict, deterministic: bool) -> ActResult:
        """Delegate action selection to the wrapped RVT agent."""
        return self.rvt_agent.act(step, observation, deterministic)

    def reset(self) -> None:
        """Reset the wrapped RVT agent."""
        self.rvt_agent.reset()

    def update_summaries(self) -> List[Summary]:
        """Return the cached loss means as ``RVT/<name>`` scalar summaries."""
        summaries = []
        for k, v in self._summaries.items():
            summaries.append(ScalarSummary(f"RVT/{k}", v))
        return summaries

    def act_summaries(self) -> List[Summary]:
        """No summaries are produced while acting."""
        return []

    def load_weights(self, savedir: str) -> None:
        """
        copied from RVT

        Restore the network, optimizer and LR-scheduler state from
        ``<savedir>/<checkpoint file>`` and then load CLIP.
        """
        # NOTE(review): the checkpoint is always mapped to "cuda:0", whatever
        # device the agent was built on -- confirm this is intended.
        device = torch.device("cuda:0")
        weight_file = os.path.join(savedir, self._checkpoint_filename)
        state_dict = torch.load(weight_file, map_location=device)

        model = self.rvt_agent._network
        optimizer = self.rvt_agent._optimizer
        lr_sched = self.rvt_agent._lr_sched

        # Checkpoints store the bare module, so unwrap DDP before loading.
        if isinstance(model, DDP):
            model = model.module

        model.load_state_dict(state_dict["model_state"])
        optimizer.load_state_dict(state_dict["optimizer_state"])
        lr_sched.load_state_dict(state_dict["lr_sched_state"])

        return self.rvt_agent.load_clip()

    def save_weights(self, savedir: str) -> None:
        """Write network, optimizer and LR-scheduler state to ``savedir``.

        The directory is created when missing. A DDP-wrapped network is
        unwrapped first so the saved keys match what ``load_weights`` expects.
        """
        os.makedirs(savedir, exist_ok=True)

        weight_file = os.path.join(savedir, self._checkpoint_filename)

        model = self.rvt_agent._network
        optimizer = self.rvt_agent._optimizer
        lr_sched = self.rvt_agent._lr_sched

        if isinstance(model, DDP):
            model = model.module

        model_state = model.state_dict()

        torch.save(
            {
                "model_state": model_state,
                "optimizer_state": optimizer.state_dict(),
                "lr_sched_state": lr_sched.state_dict(),
            },
            weight_file,
        )
+rlbench: + task_name: "multi" + tasks: [open_drawer,slide_block_to_color_target] + demo_path: /my/demo/path + episode_length: 25 + cameras: ["over_shoulder_left", "over_shoulder_right", "overhead", "wrist_right", "wrist_left", "front"] + camera_resolution: [128, 128] + scene_bounds: [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6] + include_lang_goal_in_obs: True + time_in_state: True + headless: True + gripper_mode: 'Discrete' + arm_action_mode: 'EndEffectorPoseViaPlanning' + action_mode: 'MoveArmThenGripper' + +framework: + tensorboard_logging: True + csv_logging: True + gpu: 0 + logdir: '/tmp/arm_test/' + start_seed: 0 + record_every_n: 5 + + eval_envs: 1 + eval_from_eps_number: 0 + eval_episodes: 5 + eval_type: 'last' # or 'best', 'missing', 'all', or an integer checkpoint step + eval_save_metrics: True + +cinematic_recorder: + enabled: False + camera_resolution: [1280, 720] + fps: 30 + rotate_speed: 0.005 + save_path: '/tmp/videos/' diff --git a/external/peract_bimanual/conf/hydra/job_logging/custom.yaml b/external/peract_bimanual/conf/hydra/job_logging/custom.yaml new file mode 100644 index 0000000000000000000000000000000000000000..742d1e311f48645a15cded235e727c915b43e690 --- /dev/null +++ b/external/peract_bimanual/conf/hydra/job_logging/custom.yaml @@ -0,0 +1,12 @@ +version: 1 +formatters: + simple: + format: '[%(levelname)s] - %(message)s' +handlers: + rich_console: + class: rich.logging.RichHandler +root: + handlers: [rich_console] + + +disable_existing_loggers: false diff --git a/external/peract_bimanual/conf/method/ACT_BC_LANG.yaml b/external/peract_bimanual/conf/method/ACT_BC_LANG.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4917f082a301bf11c8dd2c06529b389c3c744a93 --- /dev/null +++ b/external/peract_bimanual/conf/method/ACT_BC_LANG.yaml @@ -0,0 +1,51 @@ +# @package _group_ + +name: 'ACT_BC_LANG' + +# Agent +robot_name: 'bimanual' +agent_type: 'bimanual' + + +train_demo_path: "/home/markus/rlbench_data_v2_128/train/" + +activation: lrelu +lr: 1e-4 +weight_decay: 
0.000001 +grad_clip: 0.1 +demo_augmentation: True +demo_augmentation_every_n: 10 + +prev_action_horizon: 1 +next_action_horizon: 10 + +# hyperparameters +lr_backbone: 1e-5 +backbone: resnet18 +dilation: False +position_embedding: sine +kl_weight: 100 +chunk_size: ${method.next_action_horizon} + +# transformer +input_dim: 16 # 7 revolute joints + 1 gripper joints +enc_layers: 4 +dec_layers: 7 +dim_feedforward: 3200 +hidden_dim: 512 +dropout: 0.1 +nheads: 8 +num_queries: ${method.next_action_horizon} +pre_norm: False + +# unused +masks: False + +# legacy +camera_names: ${rlbench.cameras} + +# ..todo:: also set the following + ++rlbench.episode_length: 400 ++rlbench.arm_action_mode: JointPosition ++rlbench.action_mode: JointPositionActionMode diff --git a/external/peract_bimanual/conf/method/ARM.yaml b/external/peract_bimanual/conf/method/ARM.yaml new file mode 100644 index 0000000000000000000000000000000000000000..58720f8e5eef553a7dd415121cf9d3d0a67f54d7 --- /dev/null +++ b/external/peract_bimanual/conf/method/ARM.yaml @@ -0,0 +1,24 @@ +# @package _group_ + +name: 'ARM' +activation: lrelu +q_conf: True +alpha: 0.05 +alpha_lr: 0.0001 +alpha_auto_tune: False +next_best_pose_critic_lr: 0.0025 +next_best_pose_actor_lr: 0.001 +next_best_pose_critic_weight_decay: 0.00001 +next_best_pose_actor_weight_decay: 0.00001 +crop_shape: [16, 16] +next_best_pose_tau: 0.005 +next_best_pose_critic_grad_clip: 5 +next_best_pose_actor_grad_clip: 5 +qattention_grad_clip: 5 +qattention_tau: 0.005 +qattention_lr: 0.0005 +qattention_weight_decay: 0.00001 +qattention_lambda_qreg: 0.0000001 + +demo_augmentation: True +demo_augmentation_every_n: 10 diff --git a/external/peract_bimanual/conf/method/BC_LANG.yaml b/external/peract_bimanual/conf/method/BC_LANG.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9dd27e67861b2b6724f05dadaaedfe5c7ef3ad06 --- /dev/null +++ b/external/peract_bimanual/conf/method/BC_LANG.yaml @@ -0,0 +1,9 @@ +# @package _group_ + +name: 'BC_LANG' 
+activation: lrelu +lr: 0.0005 +weight_decay: 0.000001 +grad_clip: 0.1 +demo_augmentation: True +demo_augmentation_every_n: 10 diff --git a/external/peract_bimanual/conf/method/BIMANUAL_PERACT.yaml b/external/peract_bimanual/conf/method/BIMANUAL_PERACT.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0a66a21a327cdc981329eb1f991d9080d020e177 --- /dev/null +++ b/external/peract_bimanual/conf/method/BIMANUAL_PERACT.yaml @@ -0,0 +1,70 @@ +# @package _group_ + +name: 'BIMANUAL_PERACT' + +# Agent +robot_name: 'bimanual' +agent_type: 'bimanual' + + +# Voxelization +image_crop_size: 64 +bounds_offset: [0.15] +voxel_sizes: [100] +include_prev_layer: False + +# Perceiver +num_latents: 2048 +latent_dim: 512 +transformer_depth: 6 +transformer_iterations: 1 +cross_heads: 1 +cross_dim_head: 64 +latent_heads: 8 +latent_dim_head: 64 +pos_encoding_with_lang: True +conv_downsample: True +lang_fusion_type: 'seq' # or 'concat' +voxel_patch_size: 5 +voxel_patch_stride: 5 +final_dim: 64 +low_dim_size: 8 + + +# Training +input_dropout: 0.1 +attn_dropout: 0.1 +decoder_dropout: 0.0 + +lr: 0.0005 +lr_scheduler: False +num_warmup_steps: 3000 +optimizer: 'lamb' # or 'adam' + +lambda_weight_l2: 0.000001 +trans_loss_weight: 1.0 +rot_loss_weight: 1.0 +grip_loss_weight: 1.0 +collision_loss_weight: 1.0 +rotation_resolution: 5 + +# Network +activation: lrelu +norm: None + +# Augmentation +crop_augmentation: True +transform_augmentation: + apply_se3: True + aug_xyz: [0.125, 0.125, 0.125] + aug_rpy: [0.0, 0.0, 45.0] + aug_rot_resolution: ${method.rotation_resolution} + +demo_augmentation: True +demo_augmentation_every_n: 10 + +# Ablations +no_skip_connection: False +no_perceiver: False +no_language: False +keypoint_method: 'heuristic' diff --git a/external/peract_bimanual/conf/method/C2FARM_LINGUNET_BC.yaml b/external/peract_bimanual/conf/method/C2FARM_LINGUNET_BC.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..4019dc480182163c5fdf9017d9d9034c227172eb --- /dev/null +++ b/external/peract_bimanual/conf/method/C2FARM_LINGUNET_BC.yaml @@ -0,0 +1,40 @@ +# @package _group_ + +name: 'C2FARM_LINGUNET_BC' + +# Voxelization +image_crop_size: 64 +bounds_offset: [0.15] +voxel_sizes: [32, 32] +include_prev_layer: False + +# Training +lr: 0.0005 +lr_scheduler: False +num_warmup_steps: 10000 + +lambda_weight_l2: 0.000001 +trans_loss_weight: 1.0 +rot_loss_weight: 1.0 +grip_loss_weight: 1.0 +collision_loss_weight: 1.0 +rotation_resolution: 5 + +# Network +activation: lrelu +norm: None + +# Augmentation +crop_augmentation: True +transform_augmentation: + apply_se3: True + aug_xyz: [0.125, 0.125, 0.125] + aug_rpy: [0.0, 0.0, 45.0] + aug_rot_resolution: ${method.rotation_resolution} + +demo_augmentation: True +demo_augmentation_every_n: 10 +exploration_strategy: gaussian + +# Ablations +keypoint_method: 'heuristic' \ No newline at end of file diff --git a/external/peract_bimanual/conf/method/PERACT_BC.yaml b/external/peract_bimanual/conf/method/PERACT_BC.yaml new file mode 100644 index 0000000000000000000000000000000000000000..08d3e1bd6c34eb8be89e04a41b5136eac42e530e --- /dev/null +++ b/external/peract_bimanual/conf/method/PERACT_BC.yaml @@ -0,0 +1,68 @@ +# @package _group_ + +name: 'PERACT_BC' + +# Agent +agent_type: 'leader_follower' +robot_name: 'bimanual' + +# Voxelization +image_crop_size: 64 +bounds_offset: [0.15] +voxel_sizes: [100] +include_prev_layer: False + +# Perceiver +num_latents: 2048 +latent_dim: 512 +transformer_depth: 6 +transformer_iterations: 1 +cross_heads: 1 +cross_dim_head: 64 +latent_heads: 8 +latent_dim_head: 64 +pos_encoding_with_lang: True +conv_downsample: True +lang_fusion_type: 'seq' # or 'concat' +voxel_patch_size: 5 +voxel_patch_stride: 5 +final_dim: 64 +low_dim_size: 4 + +# Training +input_dropout: 0.1 +attn_dropout: 0.1 +decoder_dropout: 0.0 + +lr: 0.0005 +lr_scheduler: False +num_warmup_steps: 3000 +optimizer: 
'lamb' # or 'adam' + +lambda_weight_l2: 0.000001 +trans_loss_weight: 1.0 +rot_loss_weight: 1.0 +grip_loss_weight: 1.0 +collision_loss_weight: 1.0 +rotation_resolution: 5 + +# Network +activation: lrelu +norm: None + +# Augmentation +crop_augmentation: True +transform_augmentation: + apply_se3: True + aug_xyz: [0.125, 0.125, 0.125] + aug_rpy: [0.0, 0.0, 45.0] + aug_rot_resolution: ${method.rotation_resolution} + +demo_augmentation: True +demo_augmentation_every_n: 10 + +# Ablations +no_skip_connection: False +no_perceiver: False +no_language: False +keypoint_method: 'heuristic' diff --git a/external/peract_bimanual/conf/method/RVT.yaml b/external/peract_bimanual/conf/method/RVT.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6b226db86314c034ef326ca2ce0d00c1804306f0 --- /dev/null +++ b/external/peract_bimanual/conf/method/RVT.yaml @@ -0,0 +1,69 @@ +# @package _group_ + +name: 'RVT' + +# Agent +agent_type: 'leader_follower' +robot_name: 'bimanual' + +# Voxelization +image_crop_size: 64 +bounds_offset: [0.15] +voxel_sizes: [100] +include_prev_layer: False + +low_dim_size: 4 + +# Perceiver +num_latents: 2048 +latent_dim: 512 +transformer_depth: 6 +transformer_iterations: 1 +cross_heads: 1 +cross_dim_head: 64 +latent_heads: 8 +latent_dim_head: 64 +pos_encoding_with_lang: True +conv_downsample: True +lang_fusion_type: 'seq' # or 'concat' +voxel_patch_size: 5 +voxel_patch_stride: 5 +final_dim: 64 + +# Training +input_dropout: 0.1 +attn_dropout: 0.1 +decoder_dropout: 0.0 + +lr: 0.0005 +lr_scheduler: False +num_warmup_steps: 3000 +optimizer: 'lamb' # or 'adam' + +lambda_weight_l2: 0.000001 +trans_loss_weight: 1.0 +rot_loss_weight: 1.0 +grip_loss_weight: 1.0 +collision_loss_weight: 1.0 +rotation_resolution: 5 + +# Network +activation: lrelu +norm: None + +# Augmentation +crop_augmentation: True +transform_augmentation: + apply_se3: True + aug_xyz: [0.125, 0.125, 0.125] + aug_rpy: [0.0, 0.0, 45.0] + aug_rot_resolution: ${method.rotation_resolution} + 
+demo_augmentation: True +demo_augmentation_every_n: 10 + +# Ablations +no_skip_connection: False +no_perceiver: False +no_language: False +keypoint_method: 'heuristic' diff --git a/external/peract_bimanual/conf/method/VIT_BC_LANG.yaml b/external/peract_bimanual/conf/method/VIT_BC_LANG.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f7949ada14c854682ad0cfa7b77305fc1ff8a2b7 --- /dev/null +++ b/external/peract_bimanual/conf/method/VIT_BC_LANG.yaml @@ -0,0 +1,9 @@ +# @package _group_ + +name: 'VIT_BC_LANG' +activation: lrelu +lr: 0.0005 +weight_decay: 0.000001 +grad_clip: 0.1 +demo_augmentation: True +demo_augmentation_every_n: 10 diff --git a/external/peract_bimanual/eval.py b/external/peract_bimanual/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..1d939cd1089e8c8c1159480bb3bec75cb11f6a7d --- /dev/null +++ b/external/peract_bimanual/eval.py @@ -0,0 +1,290 @@ +import gc +import logging +import os +import sys + +import peract_config + +import hydra +import numpy as np +import torch +import pandas as pd +from omegaconf import DictConfig, OmegaConf, ListConfig +from rlbench.action_modes.action_mode import BimanualMoveArmThenGripper +from rlbench.action_modes.action_mode import BimanualJointPositionActionMode +from rlbench.action_modes.arm_action_modes import BimanualEndEffectorPoseViaPlanning +from rlbench.action_modes.arm_action_modes import BimanualJointPosition, JointPosition +from rlbench.action_modes.gripper_action_modes import BimanualDiscrete +from rlbench.action_modes.action_mode import MoveArmThenGripper +from rlbench.action_modes.arm_action_modes import EndEffectorPoseViaPlanning +from rlbench.action_modes.gripper_action_modes import Discrete + +from rlbench.backend import task as rlbench_task +from rlbench.backend.utils import task_file_to_task_class +from yarr.runners.independent_env_runner import IndependentEnvRunner +from yarr.utils.stat_accumulator import SimpleAccumulator + +from helpers import utils 
def _list_checkpoint_steps(weightsdir):
    """Return the entries of ``weightsdir`` parsed as ints, in ascending order."""
    return sorted(map(int, os.listdir(weightsdir)))


def eval_seed(
    train_cfg, eval_cfg, logdir, env_device, multi_task, seed, env_config
) -> None:
    """Evaluate the checkpoints of one training seed.

    Which checkpoints are evaluated depends on ``eval_cfg.framework.eval_type``:

    * ``"missing"``: every checkpoint without a row in ``eval_data.csv``;
    * ``"best"``: per task, the existing checkpoint with the highest return
      recorded in ``eval_data.csv`` (ties resolve to the later row);
    * ``"last"``: the checkpoint with the highest step;
    * ``"all"``: every checkpoint;
    * an ``int``: that specific checkpoint step.

    The selected checkpoints are split into batches with ``utils.split_list``
    and ``eval_cfg.framework.eval_envs``; each batch runs as parallel
    processes and is joined before the next batch starts.

    Args:
        train_cfg: Training config used to build the agent; its
            ``method.robot_name`` is overwritten with the eval value.
        eval_cfg: Evaluation config.
        logdir: Seed directory containing ``weights/`` and ``eval_data.csv``.
        env_device: Device handed to the environment runner.
        multi_task: Whether more than one task is evaluated.
        seed: Unused here; kept for interface compatibility.
        env_config: Tuple of environment arguments forwarded to the runner.

    Raises:
        FileNotFoundError: ``eval_type`` is ``"last"`` and ``weights/`` is empty.
        ValueError: ``eval_type`` is ``"best"`` and no evaluated step still
            has a checkpoint.
        Exception: ``eval_type`` is ``"best"`` without ``eval_data.csv``, or
            ``eval_type`` is not recognised.
    """
    tasks = eval_cfg.rlbench.tasks
    rg = RolloutGenerator()

    train_cfg.method.robot_name = eval_cfg.method.robot_name

    agent = agent_factory.create_agent(train_cfg)
    stat_accum = SimpleAccumulator(eval_video_fps=30)

    weightsdir = os.path.join(logdir, "weights")

    env_runner = IndependentEnvRunner(
        train_env=None,
        agent=agent,
        train_replay_buffer=None,
        num_train_envs=0,
        num_eval_envs=eval_cfg.framework.eval_envs,
        rollout_episodes=99999,
        eval_episodes=eval_cfg.framework.eval_episodes,
        training_iterations=train_cfg.framework.training_iterations,
        eval_from_eps_number=eval_cfg.framework.eval_from_eps_number,
        episode_length=eval_cfg.rlbench.episode_length,
        stat_accumulator=stat_accum,
        weightsdir=weightsdir,
        logdir=logdir,
        env_device=env_device,
        rollout_generator=rg,
        num_eval_runs=len(tasks),
        multi_task=multi_task,
    )

    env_runner._on_thread_start = peract_config.config_logging

    manager = mp.Manager()
    save_load_lock = manager.Lock()
    writer_lock = manager.Lock()

    eval_type = eval_cfg.framework.eval_type
    env_data_csv_file = os.path.join(logdir, "eval_data.csv")

    # evaluate all checkpoints (0, 1000, ...) which don't have results, i.e. validation phase
    if eval_type == "missing":
        weight_folders = _list_checkpoint_steps(weightsdir)

        if os.path.exists(env_data_csv_file):
            env_dict = pd.read_csv(env_data_csv_file).to_dict()
            evaluated_weights = set(map(int, env_dict["step"].values()))
            weight_folders = [w for w in weight_folders if w not in evaluated_weights]

        print("Missing weights: ", weight_folders)

    # pick the best checkpoint from validation and evaluate, i.e. test phase
    elif eval_type == "best":
        if not os.path.exists(env_data_csv_file):
            raise Exception("No existing eval_data.csv file found in %s" % logdir)

        env_dict = pd.read_csv(env_data_csv_file).to_dict()
        existing_weights = set(_list_checkpoint_steps(weightsdir))
        weights = list(env_dict["step"].values())
        task_weights = {}
        for task in tasks:
            # Multi-task runs log one return column per task.
            if len(tasks) > 1:
                task_score = list(env_dict["eval_envs/return/%s" % task].values())
            else:
                task_score = list(env_dict["eval_envs/return"].values())

            # Only steps whose checkpoint is still on disk are candidates.
            avail_weights, avail_task_scores = [], []
            for step, score in zip(weights, task_score):
                if step in existing_weights:
                    avail_weights.append(step)
                    avail_task_scores.append(score)

            if not avail_weights:
                raise ValueError(
                    "None of the steps in %s has a checkpoint in %s"
                    % (env_data_csv_file, weightsdir)
                )

            # On ties keep the last (i.e. latest-logged) best checkpoint.
            scores = np.asarray(avail_task_scores)
            best_idx = int(np.flatnonzero(scores == np.amax(scores))[-1])
            task_weights[task] = avail_weights[best_idx]

        weight_folders = [task_weights]
        print("Best weights:", weight_folders)

    # evaluate only the last checkpoint
    elif eval_type == "last":
        weight_folders = _list_checkpoint_steps(weightsdir)
        if not weight_folders:
            raise FileNotFoundError("No checkpoints found in %s" % weightsdir)
        weight_folders = [weight_folders[-1]]
        print("Last weight:", weight_folders)

    elif eval_type == "all":
        weight_folders = _list_checkpoint_steps(weightsdir)

    # evaluate a specific checkpoint
    # (bool is excluded explicitly because it is a subclass of int)
    elif isinstance(eval_type, int) and not isinstance(eval_type, bool):
        weight_folders = [int(eval_type)]
        print("Weight:", weight_folders)

    else:
        raise Exception("Unknown eval type")

    if len(weight_folders) == 0:
        logging.info(
            "No weights to evaluate. Results are already available in eval_data.csv"
        )
        sys.exit(0)

    # Guard the round-robin GPU assignment against hosts without CUDA devices,
    # where ``device_count()`` is 0 and the modulo would divide by zero.
    num_devices = max(torch.cuda.device_count(), 1)

    # evaluate several checkpoints in parallel
    # NOTE: in multi-task settings, each task is evaluated serially, which makes everything slow!
    split_n = utils.split_list(weight_folders, eval_cfg.framework.eval_envs)
    for split in split_n:
        processes = []
        for e_idx, weight in enumerate(split):
            p = mp.Process(
                target=env_runner.start,
                args=(
                    weight,
                    save_load_lock,
                    writer_lock,
                    env_config,
                    e_idx % num_devices,
                    eval_cfg.framework.eval_save_metrics,
                    eval_cfg.cinematic_recorder,
                ),
            )
            p.start()
            processes.append(p)
        for p in processes:
            p.join()

    del env_runner
    del agent
    gc.collect()
    torch.cuda.empty_cache()


@hydra.main(config_name="eval", config_path="conf")
def main(eval_cfg: DictConfig) -> None:
    """Hydra entry point: validate the configs and evaluate one seed.

    Loads the training config saved next to the checkpoints, checks that it
    is compatible with ``eval_cfg``, builds the observation config, action
    mode and task classes, and hands everything to :func:`eval_seed`.

    Raises:
        Exception: the seed directory has no ``config.yaml``.
        AssertionError: method name, agent type or task list do not match the
            training config.
        ValueError: a requested task has no task file.
    """
    logging.info("\n" + OmegaConf.to_yaml(eval_cfg))

    start_seed = eval_cfg.framework.start_seed
    logdir = os.path.join(
        eval_cfg.framework.logdir,
        eval_cfg.rlbench.task_name,
        eval_cfg.method.name,
        "seed%d" % start_seed,
    )

    train_config_path = os.path.join(logdir, "config.yaml")

    if os.path.exists(train_config_path):
        with open(train_config_path, "r") as f:
            train_cfg = OmegaConf.load(f)
    else:
        raise Exception(f"Missing seed{start_seed}/config.yaml. Logdir is {logdir}")

    # sanity checks
    assert train_cfg.method.name == eval_cfg.method.name
    assert train_cfg.method.agent_type == eval_cfg.method.agent_type
    for task in eval_cfg.rlbench.tasks:
        assert task in train_cfg.rlbench.tasks

    env_device = utils.get_device(eval_cfg.framework.gpu)
    logging.info("Using env device %s." % str(env_device))

    # NOTE(security): ``eval`` resolves the class names given in the config
    # against this module's namespace. It executes arbitrary expressions, so
    # the eval config must come from a trusted source.
    gripper_mode = eval(eval_cfg.rlbench.gripper_mode)()
    arm_action_mode = eval(eval_cfg.rlbench.arm_action_mode)()
    action_mode = eval(eval_cfg.rlbench.action_mode)(arm_action_mode, gripper_mode)

    is_bimanual = eval_cfg.method.robot_name == "bimanual"

    if is_bimanual:
        # TODO: automate instantiation with eval
        task_path = rlbench_task.BIMANUAL_TASKS_PATH
    else:
        task_path = rlbench_task.TASKS_PATH

    task_files = [
        t.replace(".py", "")
        for t in os.listdir(task_path)
        if t != "__init__.py" and t.endswith(".py")
    ]
    # A single camera given as a scalar is normalised to a one-element list.
    eval_cfg.rlbench.cameras = (
        eval_cfg.rlbench.cameras
        if isinstance(eval_cfg.rlbench.cameras, ListConfig)
        else [eval_cfg.rlbench.cameras]
    )
    obs_config = observation_utils.create_obs_config(
        eval_cfg.rlbench.cameras,
        eval_cfg.rlbench.camera_resolution,
        eval_cfg.method.name,
        eval_cfg.method.robot_name,
    )

    if eval_cfg.cinematic_recorder.enabled:
        obs_config.record_gripper_closing = True

    multi_task = len(eval_cfg.rlbench.tasks) > 1

    tasks = eval_cfg.rlbench.tasks
    task_classes = []
    for task in tasks:
        if task not in task_files:
            raise ValueError("Task %s not recognised!." % task)
        task_classes.append(task_file_to_task_class(task, is_bimanual))

    # single-task or multi-task
    # (the two tuples differ: multi-task passes the list of task classes and
    # additionally ``eval_episodes``)
    if multi_task:
        env_config = (
            task_classes,
            obs_config,
            action_mode,
            eval_cfg.rlbench.demo_path,
            eval_cfg.rlbench.episode_length,
            eval_cfg.rlbench.headless,
            eval_cfg.framework.eval_episodes,
            train_cfg.rlbench.include_lang_goal_in_obs,
            eval_cfg.rlbench.time_in_state,
            eval_cfg.framework.record_every_n,
        )
    else:
        env_config = (
            task_classes[0],
            obs_config,
            action_mode,
            eval_cfg.rlbench.demo_path,
            eval_cfg.rlbench.episode_length,
            eval_cfg.rlbench.headless,
            train_cfg.rlbench.include_lang_goal_in_obs,
            eval_cfg.rlbench.time_in_state,
            eval_cfg.framework.record_every_n,
        )

    logging.info("Evaluating seed %d." % start_seed)
    eval_seed(
        train_cfg,
        eval_cfg,
        logdir,
        env_device,
        multi_task,
        start_seed,
        env_config,
    )
class AttentionImageGoal(Attention):
    """Attention (a.k.a Pick) with image-goals module."""

    def __init__(self, stream_fcn, in_shape, n_rotations, preprocess, cfg, device):
        super().__init__(stream_fcn, in_shape, n_rotations, preprocess, cfg, device)

    def forward(self, inp_img, goal_img, softmax=True):
        """Forward pass.

        The padded input image is multiplied element-wise by the padded goal
        image, replicated once per rotation, rotated, passed through the
        attention stream, rotated back and cropped to the unpadded extent.

        Args:
            inp_img: Observation as a numpy array (presumably H x W x C --
                TODO confirm against caller).
            goal_img: Goal image, padded the same way as ``inp_img``.
            softmax: When true, the flattened logits are normalised with a
                softmax and reshaped; otherwise the flat ``(1, N)`` logits
                are returned.
        """
        float_on_device = dict(dtype=torch.float, device=self.device)

        # Pad both images identically and prepend a batch axis of size 1.
        padded_input = np.pad(inp_img, self.padding, mode="constant")[np.newaxis]
        padded_goal = np.pad(goal_img, self.padding, mode="constant")[np.newaxis]

        input_tensor = torch.from_numpy(padded_input).to(**float_on_device)
        goal_tensor = torch.from_numpy(padded_goal.copy()).to(**float_on_device)
        fused = input_tensor * goal_tensor

        # Rotation pivot: centre of the padded image.
        pivot = np.array(padded_input.shape[1:3]) // 2

        # Channels-first, one copy per rotation, then rotate each copy.
        fused = fused.permute(0, 3, 1, 2).repeat(self.n_rotations, 1, 1, 1)
        rotated = self.rotator(fused, pivot=pivot)

        # Forward pass through the attention stream, one rotation at a time.
        logits = torch.cat([self.attend(view) for view in rotated], dim=0)

        # Rotate back output and strip the padding added above.
        logits = torch.cat(self.rotator(logits, reverse=True, pivot=pivot), dim=0)
        start = self.padding[:2, 0]
        stop = start + inp_img.shape[:2]
        logits = logits[:, :, start[0] : stop[0], start[1] : stop[1]]

        logits = logits.permute(1, 2, 3, 0)  # D H W C
        flat = logits.reshape(1, logits.numel())
        if not softmax:
            return flat
        return F.softmax(flat, dim=-1).reshape(logits.shape[1:])
"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", +} + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict( + [ + ("-1", nn.AvgPool2d(stride)), + ( + "0", + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False, + ), + ), + ("1", nn.BatchNorm2d(planes * self.expansion)), + ] + ) + ) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = 
class AttentionPool2d(nn.Module):
    """Pool a 2D feature map into one vector per sample with QKV attention.

    The spatial positions are flattened into a token sequence, their mean is
    prepended as an extra token, and the attention output at that first
    position is returned.
    """

    def __init__(
        self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
    ):
        super().__init__()
        # One positional slot per spatial location plus one for the mean token.
        num_tokens = spacial_dim**2 + 1
        out_features = output_dim or embed_dim

        self.positional_embedding = nn.Parameter(
            torch.randn(num_tokens, embed_dim) / embed_dim**0.5
        )
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads

    def forward(self, x):
        """Return the attention-pooled features of the NCHW tensor `x`."""
        # NCHW -> (HW)NC
        flat = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
        tokens = flat.permute(2, 0, 1)

        # Prepend the mean over all positions: (HW+1)NC
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)

        # The three projections are passed separately, so their biases are
        # handed over as one concatenated q/k/v bias vector.
        qkv_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])

        attended, _ = F.multi_head_attention_forward(
            query=tokens,
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=qkv_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )

        # Position 0 is the prepended mean token.
        return attended[0]
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d( + input_resolution // 32, embed_dim, heads, output_dim + ) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.prepool(x) + x = self.attnpool(x) + return x + + def prepool(self, x): + def stem(x): + for conv, bn in [ + (self.conv1, self.bn1), + (self.conv2, self.bn2), + (self.conv3, self.bn3), + ]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = 
class LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and returns the caller's dtype.

    Used so that half-precision activations are normalized in full precision.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
attn_mask) for _ in range(layers)] + ) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisualTransformer(nn.Module): + def __init__( + self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + ): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width) + ) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [ + self.class_embedding.to(x.dtype) + + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device + ), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + def forward_spatial(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [ + self.class_embedding.to(x.dtype) + + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device + ), + x, + ], + dim=1, + ) 
# shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x)[:, 1:] + return x + + +class CLIP(nn.Module): + def __init__( + self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width, + ) + else: + vision_heads = vision_width // 64 + self.visual = VisualTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width) + ) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([])) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = 
self.visual.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [ + self.visual.layer1, + self.visual.layer2, + self.visual.layer3, + self.visual.layer4, + ]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers) ** -0.5 + ) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding 
(eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def encode_text_with_embeddings(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + emb = x.clone() + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x, emb + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logit_scale * text_features @ image_features.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [ + *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], + "in_proj_bias", + "bias_k", + "bias_v", + ]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + 
def _sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of the file at `path`, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")) -> str:
    """Download `url` into `root` and return the local file path.

    The expected SHA256 digest is taken from the second-to-last path
    component of `url`. A file already present under `root` whose digest
    matches is reused without touching the network; a mismatching one
    triggers a warning and a fresh download.

    Args:
        url: Location of the checkpoint; its basename is the local file name.
        root: Directory used as download cache (created if missing).

    Returns:
        Path of the verified local file.

    Raises:
        RuntimeError: If the target path exists but is not a regular file, or
            if the freshly downloaded file does not match the expected digest.
    """
    # `import urllib` alone does not guarantee the `request` submodule is
    # loaded, so import it explicitly where it is needed.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if _sha256_of_file(download_target) == expected_sha256:
            return download_target
        warnings.warn(
            f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
        )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # Servers may omit Content-Length; tqdm accepts total=None for that.
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if _sha256_of_file(download_target) != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match"
        )

    return download_target
the device names + device_holder = torch.jit.trace( + lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] + ) + device_node = [ + n + for n in device_holder.graph.findAllNodes("prim::Constant") + if "Device" in repr(n) + ][-1] + + def patch_device(module): + graphs = [module.graph] if hasattr(module, "graph") else [] + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith( + "cuda" + ): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace( + lambda: torch.ones([]).float(), example_inputs=[] + ) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + graphs = [module.graph] if hasattr(module, "graph") else [] + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [ + 1, + 2, + ]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, transform + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77): + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > 
class Fusion(nn.Module):
    """Base Fusion Class

    Subclasses combine a spatial feature map `x1` with a second input `x2`
    by implementing `forward`.
    """

    def __init__(self, input_dim=3):
        super().__init__()
        self.input_dim = input_dim

    def tile_x2(self, x1, x2, x2_proj=None):
        """Broadcast `x2` over the last two (spatial) dimensions of `x1`.

        Args:
            x1: Tensor whose last two dimensions give the tiling size.
            x2: Tensor to tile; two trailing singleton dimensions are added
                and then repeated. Assumes `x2` is (batch, channels) - TODO
                confirm against callers.
            x2_proj: Optional callable applied to `x2` before tiling.

        Returns:
            The (optionally projected) `x2` repeated to the spatial size of `x1`.
        """
        # Compare against None explicitly: the truth value of a callable is
        # unrelated to whether a projection was supplied (a module defining
        # __len__ can be falsy and would otherwise be skipped silently).
        if x2_proj is not None:
            x2 = x2_proj(x2)

        x2 = x2.unsqueeze(-1).unsqueeze(-1)
        x2 = x2.repeat(1, 1, x1.shape[-2], x1.shape[-1])
        return x2

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        """Fuse `x1` and `x2`; must be provided by subclasses."""
        raise NotImplementedError()
__init__(self, input_dim=3): + super(FusionMult, self).__init__(input_dim=input_dim) + + def forward(self, x1, x2, x2_mask=None, x2_proj=None): + if x1.shape != x2.shape and len(x1.shape) != len(x2.shape): + x2 = self.tile_x2(x1, x2, x2_proj) + return x1 * x2 + + +class FusionMax(Fusion): + """max(x1, x2)""" + + def __init__(self, input_dim=3): + super(FusionMax, self).__init__(input_dim=input_dim) + + def forward(self, x1, x2, x2_mask=None, x2_proj=None): + if x1.shape != x2.shape and len(x1.shape) != len(x2.shape): + x2 = self.tile_x2(x1, x2, x2_proj) + return torch.max(x1, x2) + + +class FusionConcat(Fusion): + """[x1; x2]""" + + def __init__(self, input_dim=3): + super(FusionConcat, self).__init__(input_dim=input_dim) + + def forward(self, x1, x2, x2_mask=None, x2_proj=None): + if x1.shape != x2.shape and len(x1.shape) != len(x2.shape): + x2 = self.tile_x2(x1, x2, x2_proj) + return torch.cat([x1, x2], dim=1) + + +class FusionConv(Fusion): + """1x1 convs after [x1; x2]""" + + def __init__(self, input_dim=3): + super(FusionConv, self).__init__(input_dim=input_dim) + self.conv = nn.Sequential( + nn.ReLU(True), + nn.Conv2d(input_dim * 2, input_dim, kernel_size=1, bias=False), + ) + + def forward(self, x1, x2, x2_mask=None, x2_proj=None): + if x1.shape != x2.shape and len(x1.shape) != len(x2.shape): + x2 = self.tile_x2(x1, x2, x2_proj) + x = torch.cat([x1, x2], dim=1) # [B, 2C, H, W] + x = self.conv(x) # [B, C, H, W] + return x + + +class FusionConvLat(Fusion): + """1x1 convs after [x1; x2] for lateral fusion""" + + def __init__(self, input_dim=3, output_dim=3): + super(FusionConvLat, self).__init__(input_dim=input_dim) + self.conv = nn.Sequential( + nn.ReLU(True), nn.Conv2d(input_dim, output_dim, kernel_size=1, bias=False) + ) + + def forward(self, x1, x2, x2_mask=None, x2_proj=None): + if x1.shape != x2.shape and len(x1.shape) != len(x2.shape): + x2 = self.tile_x2(x1, x2, x2_proj) + x = torch.cat([x1, x2], dim=1) # [B, input_dim, H, W] + x = self.conv(x) # [B, 
class FusionMultWord(nn.Module):
    """Product with weighted-sum of words"""

    def __init__(self, input_dim=3):
        super().__init__()
        self.input_dim = input_dim

    def forward(self, x1, x2, x2_mask=None, x2_proj=None):
        """Multiply `x1` by every used word vector and average the products.

        Args:
            x1: Feature map of shape (B, D, H, W).
            x2: Word features indexed as `x2[:, t]` for word `t`. Each word
                vector is tiled `B` times along the batch axis, so this
                presumably expects a single sentence (leading dim 1) - TODO
                confirm against callers.
            x2_mask: Optional mask; the number of non-zero entries is used as
                the number of leading words to fuse. When omitted, all
                `x2.shape[1]` words are used.
            x2_proj: Optional callable applied to each word vector first.

        Returns:
            Tensor shaped like `x1`: the mean over words of `x1 * word`.

        Raises:
            ValueError: If no word is selected (the mean would be 0 / 0).
        """
        B, _, H, W = x1.shape
        if x2_mask is None:
            x2_len = x2.shape[1]
        else:
            x2_len = int(x2_mask.count_nonzero())
        if x2_len == 0:
            raise ValueError("FusionMultWord needs at least one unmasked word in x2")

        weighted_x1 = torch.zeros_like(x1)
        for t in range(x2_len):
            x2_t = x2_proj(x2[:, t]) if x2_proj is not None else x2[:, t]
            x2_t = x2_t.unsqueeze(-1).unsqueeze(-1).repeat(B, 1, H, W)
            weighted_x1 += x1 * x2_t
        weighted_x1 /= x2_len
        return weighted_x1
Adapted from: https://github.com/openai/CLIP/blob/main/clip/model.py#L56""" + + def __init__( + self, + spacial_dim=7, + embed_dim=1024, + num_heads=32, + output_dim=1024, + lang_dim=512, + lang_max_tokens=77, + ): + super().__init__() + self.embed_dim = embed_dim + self.lang_dim = lang_dim + self.lang_max_tokens = lang_max_tokens + self.num_heads = num_heads + self.lang_proj = nn.Linear(self.lang_dim, embed_dim) + self.vision_positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2, embed_dim) / embed_dim**0.5 + ) + self.lang_positional_embedding = nn.Parameter( + torch.randn(lang_max_tokens, embed_dim) / embed_dim**0.5 + ) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + + def forward(self, x, l, l_mask): + # reshape vision features + x_shape = x.shape + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( + 2, 0, 1 + ) # NCHW -> (HW)NC + x = x + self.vision_positional_embedding[: x.shape[0], None, :].to( + x.dtype + ) # (HW)NC + + # project language + l = l.permute(1, 0, 2) + l_shape = l.shape + l = l.reshape(-1, self.lang_dim) + l = self.lang_proj(l) + l = l.reshape(l_shape[0], l_shape[1], self.embed_dim) + l = l + self.lang_positional_embedding[:, None, :].to(l.dtype) + + # hard language mask + l_len = int(l_mask.count_nonzero()) + l = l[:l_len] + l = l.repeat(1, x.shape[1], 1) + + x, _ = F.multi_head_attention_forward( + query=x, + key=l, + value=l, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] + ), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + 
use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + + x = x.permute(1, 2, 0) + x = x.reshape(x_shape) + return x + + +class FusionMultiHeadedWordAttention(nn.Module): + """Multi-Headed Word Attention that uses Cross Modal Attention at different scales""" + + def __init__(self, input_dim=3): + super().__init__() + self.input_dim = input_dim + self.attn1 = CrossModalAttention2d( + spacial_dim=7, embed_dim=1024, output_dim=1024 + ) + self.attn2 = CrossModalAttention2d( + spacial_dim=14, embed_dim=512, output_dim=512 + ) + self.attn3 = CrossModalAttention2d( + spacial_dim=28, embed_dim=256, output_dim=256 + ) + + self.multi_headed_attns = { + 1024: self.attn1, + 512: self.attn2, + 256: self.attn3, + } + + def forward(self, x1, x2, x2_mask=None, x2_proj=None): + emb_dim = x1.shape[1] + x = self.multi_headed_attns[emb_dim](x1, x2, x2_mask) + return x + + +names = { + "add": FusionAdd, + "mult": FusionMult, + "mult_word": FusionMultWord, + "film": FusionFiLM, + "max": FusionMax, + "concat": FusionConcat, + "conv": FusionConv, + "deep_conv": FusionDeepConv, + "word_attn": FusionWordAttention, + "sent_attn": FusionSentenceAttention, + "multi_headed_word_attn": FusionMultiHeadedWordAttention, +} diff --git a/external/peract_bimanual/helpers/clip/core/resnet.py b/external/peract_bimanual/helpers/clip/core/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a8d3954f19eb72b3da0638001963c5409a22a9b1 --- /dev/null +++ b/external/peract_bimanual/helpers/clip/core/resnet.py @@ -0,0 +1,169 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class IdentityBlock(nn.Module): + def __init__( + self, in_planes, filters, kernel_size, stride=1, final_relu=True, batchnorm=True + ): + super(IdentityBlock, self).__init__() + self.final_relu = final_relu + self.batchnorm = batchnorm + + filters1, filters2, filters3 = filters + self.conv1 = nn.Conv2d(in_planes, filters1, kernel_size=1, bias=False) + self.bn1 = 
class ConvBlock(nn.Module):
    """Bottleneck residual block with a projection shortcut.

    Main path: 1x1 conv -> ``kernel_size`` conv (carrying ``stride``) -> 1x1 conv,
    each followed by batch norm (or identity when ``batchnorm`` is False).
    Shortcut path: a strided 1x1 conv (plus optional batch norm) so the input
    can be added to the main path even when channels or resolution change.

    Args:
        in_planes: number of input channels.
        filters: three output-channel counts, one per convolution.
        kernel_size: kernel size of the middle convolution. Odd sizes keep the
            main path and the shortcut at the same spatial size.
        stride: stride of the middle convolution and of the shortcut.
        final_relu: apply ReLU after the residual addition.
        batchnorm: use ``BatchNorm2d`` after each convolution.
    """

    def __init__(
        self, in_planes, filters, kernel_size, stride=1, final_relu=True, batchnorm=True
    ):
        super(ConvBlock, self).__init__()
        self.final_relu = final_relu
        self.batchnorm = batchnorm

        # Padding follows the kernel size (1 for the usual 3x3 kernel). A fixed
        # padding of 1 only lines up with the 1x1 shortcut when kernel_size == 3.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)

        filters1, filters2, filters3 = filters
        self.conv1 = nn.Conv2d(in_planes, filters1, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(filters1) if self.batchnorm else nn.Identity()
        self.conv2 = nn.Conv2d(
            filters1,
            filters2,
            kernel_size=kernel_size,
            dilation=1,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(filters2) if self.batchnorm else nn.Identity()
        self.conv3 = nn.Conv2d(filters2, filters3, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity()

        self.shortcut = nn.Sequential(
            nn.Conv2d(in_planes, filters3, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity(),
        )

    def forward(self, x):
        """Return ``main(x) + shortcut(x)``, with a final ReLU if enabled."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out = out + self.shortcut(x)
        if self.final_relu:
            out = F.relu(out)
        return out
self.input_dim = input_shape[-1] + self.output_dim = output_dim + self.cfg = cfg + self.device = device + self.batchnorm = self.cfg["train"]["batchnorm"] + self.preprocess = preprocess + + self.layers = self._make_layers() + + def _make_layers(self): + layers = nn.Sequential( + # conv1 + nn.Conv2d(self.input_dim, 64, stride=1, kernel_size=3, padding=1), + nn.BatchNorm2d(64) if self.batchnorm else nn.Identity(), + nn.ReLU(True), + # fcn + ConvBlock( + 64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm + ), + IdentityBlock( + 64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm + ), + ConvBlock( + 64, [128, 128, 128], kernel_size=3, stride=2, batchnorm=self.batchnorm + ), + IdentityBlock( + 128, [128, 128, 128], kernel_size=3, stride=1, batchnorm=self.batchnorm + ), + ConvBlock( + 128, [256, 256, 256], kernel_size=3, stride=2, batchnorm=self.batchnorm + ), + IdentityBlock( + 256, [256, 256, 256], kernel_size=3, stride=1, batchnorm=self.batchnorm + ), + ConvBlock( + 256, [512, 512, 512], kernel_size=3, stride=2, batchnorm=self.batchnorm + ), + IdentityBlock( + 512, [512, 512, 512], kernel_size=3, stride=1, batchnorm=self.batchnorm + ), + # head + ConvBlock( + 512, [256, 256, 256], kernel_size=3, stride=1, batchnorm=self.batchnorm + ), + IdentityBlock( + 256, [256, 256, 256], kernel_size=3, stride=1, batchnorm=self.batchnorm + ), + nn.UpsamplingBilinear2d(scale_factor=2), + ConvBlock( + 256, [128, 128, 128], kernel_size=3, stride=1, batchnorm=self.batchnorm + ), + IdentityBlock( + 128, [128, 128, 128], kernel_size=3, stride=1, batchnorm=self.batchnorm + ), + nn.UpsamplingBilinear2d(scale_factor=2), + ConvBlock( + 128, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm + ), + IdentityBlock( + 64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm + ), + nn.UpsamplingBilinear2d(scale_factor=2), + # conv2 + ConvBlock( + 64, + [16, 16, self.output_dim], + kernel_size=3, + stride=1, + final_relu=False, + 
@lru_cache()
def bytes_to_unicode():
    """Return a dict mapping each byte value (0-255) to a unicode character.

    The reversible BPE codes work on unicode strings, so every byte needs a
    stand-in character. Avoiding UNKs without such a table would require a
    large number of unicode characters in the vocabulary: at something like a
    10B token dataset you end up needing around 5K for decent coverage, which
    is a significant percentage of a normal, say, 32K BPE vocab.

    Bytes in the ranges ``!``-``~``, ``¡``-``¬`` and ``®``-``ÿ`` map to the
    character with the same code point. Every other byte (whitespace and
    control characters that the BPE code cannot handle) is mapped, in
    ascending order, to ``chr(256)``, ``chr(257)``, ...

    The dict lists the self-mapped bytes first, then the shifted ones. The
    result is cached, so all callers share one dict and must not mutate it.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    # Set membership instead of scanning a list that grows inside the loop.
    covered = set(printable)
    byte_values = list(printable)
    code_points = list(printable)
    shifted = 0
    for b in range(2**8):
        if b not in covered:
            byte_values.append(b)
            code_points.append(2**8 + shifted)
            shifted += 1
    return dict(zip(byte_values, (chr(cp) for cp in code_points)))
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer driven by a gzipped merges file.

    Text is cleaned, lower-cased and split with a regex. Each piece is
    rewritten with the ``bytes_to_unicode`` alphabet and then merged pair by
    pair according to the ranks read from ``bpe_path``. The last symbol of
    every word carries the ``</w>`` end-of-word marker; ``decode`` turns that
    marker back into a space.
    """

    def __init__(self, bpe_path: "str | None" = None):
        """Load merges from ``bpe_path`` and build the vocab tables.

        Args:
            bpe_path: path to the gzipped merges file. ``None`` (the default)
                resolves to ``default_bpe()`` at call time.
        """
        if bpe_path is None:
            bpe_path = default_bpe()
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Context manager so the gzip handle is closed after reading.
        with gzip.open(bpe_path) as merges_file:
            merges = merges_file.read().decode("utf-8").split("\n")
        # Skip the header line and keep a fixed number of merges.
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        # Every base symbol exists twice: word-internal and word-final ("</w>").
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Special tokens are pre-seeded so bpe() returns them unsplit.
        self.cache = {
            "<|startoftext|>": "<|startoftext|>",
            "<|endoftext|>": "<|endoftext|>",
        }
        self.pat = re.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )

    def bpe(self, token):
        """Return ``token`` split into space-separated BPE symbols.

        Results are memoised in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            # Merge the adjacent pair with the lowest (best) rank first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # No further occurrence of `first`: copy the rest verbatim.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Return the list of vocabulary ids for ``text``."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        """Return the text for a sequence of vocabulary ids.

        Each end-of-word marker becomes a single space, so the result normally
        ends with a trailing space. Invalid UTF-8 is replaced, not raised.
        """
        text = "".join([self.decoder[token] for token in tokens])
        text = (
            bytearray([self.byte_decoder[c] for c in text])
            .decode("utf-8", errors="replace")
            .replace("</w>", " ")
        )
        return text
def correlate(self, in0, in1, softmax):
    """Cross-correlate feature map ``in0`` with kernels ``in1``.

    ``in1`` is used as the weight of a 2-D convolution over ``in0`` with
    ``pad_size`` padding on each side. The response is resized back to the
    spatial size of ``in0`` and a border of ``pad_size`` pixels is cropped.

    Args:
        in0: input feature tensor passed to ``F.conv2d``.
        in1: kernel tensor passed as the convolution weight.
        softmax: if true, normalise all responses jointly with one softmax.

    Returns:
        The cropped response map. With ``softmax`` the result is reshaped to
        the response shape without its leading dimension, so that branch only
        works when the leading (batch) dimension is 1.
    """
    pad = self.pad_size
    output = F.conv2d(in0, in1, padding=(pad, pad))
    output = F.interpolate(
        output, size=(in0.shape[-2], in0.shape[-1]), mode="bilinear"
    )
    # Explicit end indices: a `pad:-pad` slice is empty when pad == 0.
    height, width = output.shape[-2], output.shape[-1]
    output = output[:, :, pad : height - pad, pad : width - pad]
    if softmax:
        output_shape = output.shape
        output = output.reshape((1, -1))
        output = F.softmax(output, dim=-1)
        output = output.reshape(output_shape[1:])
    return output
class TransportImageGoal(nn.Module):
    """Transport (placing) module conditioned on a goal image.

    Three networks are built from ``models.names[stream_fcn[0]]``:
    ``key_resnet`` runs on the padded input image, ``query_resnet`` on rotated
    crops of the input, and ``goal_resnet`` on both the padded goal image and
    its rotated crops. Goal features are added to the matching input features
    and the fused crops are cross-correlated with the fused image features.
    """

    def __init__(
        self, stream_fcn, in_shape, n_rotations, crop_size, preprocess, cfg, device
    ):
        """Set up padding, crop geometry, the rotator and the three networks.

        Args:
            stream_fcn: pair whose first item is the key into ``models.names``.
            in_shape: shape of input image.
            n_rotations: number of rotations of convolving kernel.
            crop_size: crop size around pick argmax used as convolving kernel.
            preprocess: function to preprocess input images (stored only).
            cfg: config mapping; ``cfg["train"]["batchnorm"]`` is read here.
            device: device the input tensors are moved to in ``forward``.
        """
        super().__init__()

        self.iters = 0
        self.stream_fcn = stream_fcn
        self.n_rotations = n_rotations
        self.crop_size = crop_size  # crop size must be N*16 (e.g. 96)
        self.preprocess = preprocess
        self.cfg = cfg
        self.device = device
        self.batchnorm = self.cfg["train"]["batchnorm"]

        # Pad the first two image axes by half a crop; the third axis is untouched.
        self.pad_size = int(self.crop_size / 2)
        self.padding = np.zeros((3, 2), dtype=int)
        self.padding[:2, :] = self.pad_size

        in_shape = np.array(in_shape)
        in_shape = tuple(in_shape)
        self.in_shape = in_shape

        # Crop before network (default for Transporters CoRL 2020).
        self.kernel_shape = (self.crop_size, self.crop_size, self.in_shape[2])

        # Keep dimensions that are already set on the instance.
        if not hasattr(self, "output_dim"):
            self.output_dim = 3
        if not hasattr(self, "kernel_dim"):
            self.kernel_dim = 3

        self.rotator = utils.ImageRotator(self.n_rotations)

        self._build_nets()

    def _build_nets(self):
        """Instantiate the key, query and goal networks from ``models.names``."""
        stream_one_fcn, _ = self.stream_fcn
        model = models.names[stream_one_fcn]
        self.key_resnet = model(self.in_shape, self.output_dim, self.cfg, self.device)
        self.query_resnet = model(self.in_shape, self.kernel_dim, self.cfg, self.device)
        self.goal_resnet = model(self.in_shape, self.output_dim, self.cfg, self.device)
        print(f"Transport FCN: {stream_one_fcn}")

    def correlate(self, in0, in1, softmax):
        """Cross-correlate ``in0`` with kernels ``in1`` and crop the padding.

        With ``softmax`` all responses are normalised jointly and the leading
        dimension is dropped, so that branch only works when it is 1.
        """
        pad = self.pad_size
        output = F.conv2d(in0, in1, padding=(pad, pad))
        output = F.interpolate(
            output, size=(in0.shape[-2], in0.shape[-1]), mode="bilinear"
        )
        # Explicit end indices: a `pad:-pad` slice is empty when pad == 0.
        height, width = output.shape[-2], output.shape[-1]
        output = output[:, :, pad : height - pad, pad : width - pad]
        if softmax:
            output_shape = output.shape
            output = output.reshape((1, -1))
            output = F.softmax(output, dim=-1)
            output = output.reshape(output_shape[1:])
        return output

    def _pad_to_tensor(self, img):
        """Zero-pad a numpy image and return a float tensor with axes (0, 3, 1, 2)."""
        padded = np.pad(img, self.padding, mode="constant")
        batched = padded.reshape((1,) + padded.shape)
        tensor = torch.from_numpy(batched.copy()).to(
            dtype=torch.float, device=self.device
        )
        return tensor.permute(0, 3, 1, 2)

    def _rotated_crops(self, tensor, pivot):
        """Rotate ``n_rotations`` copies of ``tensor`` about ``pivot`` and crop each."""
        hcrop = self.pad_size
        stack = tensor.repeat(self.n_rotations, 1, 1, 1)
        stack = self.rotator(stack, pivot=pivot)
        stack = torch.cat(stack, dim=0)
        return stack[
            :,
            :,
            pivot[0] - hcrop : pivot[0] + hcrop,
            pivot[1] - hcrop : pivot[1] + hcrop,
        ]

    def forward(self, inp_img, goal_img, p, softmax=True):
        """Forward pass.

        Args:
            inp_img: input image as a numpy array with three axes.
            goal_img: goal image, padded and processed like ``inp_img``.
            p: pick location; ``p[0]`` and ``p[1]`` index the first two image axes.
            softmax: forwarded to ``correlate``.

        Returns:
            The correlation of the fused image features with the fused crops.
        """
        in_tensor = self._pad_to_tensor(inp_img)
        goal_tensor = self._pad_to_tensor(goal_img)

        # Rotation pivot, shifted into padded-image coordinates.
        pv = np.array([p[0], p[1]]) + self.pad_size

        in_crop = self._rotated_crops(in_tensor, pv)
        goal_crop = self._rotated_crops(goal_tensor, pv)

        # Call order is kept: key, goal (full image), query, goal (crops).
        in_logits = self.key_resnet(in_tensor)
        goal_logits = self.goal_resnet(goal_tensor)
        kernel_crop = self.query_resnet(in_crop)
        goal_kernel = self.goal_resnet(goal_crop)

        # Fuse goal and transport features by addition.
        # (Mohit: why doesn't multiply work? :( )
        goal_x_in_logits = goal_logits + in_logits
        goal_x_kernel = goal_kernel + kernel_crop

        # TODO(Mohit): Crop after network. Broken for now.

        return self.correlate(goal_x_in_logits, goal_x_kernel, softmax)
class Up(nn.Module):
    """Upsample the first input, align it with the second, then double conv."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if not bilinear:
            # Learned upsampling: the transposed conv halves the channel count.
            self.up = nn.ConvTranspose2d(
                in_channels, in_channels // 2, kernel_size=2, stride=2
            )
            self.conv = DoubleConv(in_channels, out_channels)
            return

        # Bilinear upsampling keeps the channel count, so the double conv
        # reduces it through a narrower middle layer.
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of ``x2``, concatenate and convolve.

        Size differences along the last two axes are split between the two
        sides, with the extra pixel going to the end. ``x2`` comes first in
        the channel concatenation.
        """
        upsampled = self.up(x1)

        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        return self.conv(torch.cat([x2, upsampled], dim=1))
b/external/peract_bimanual/helpers/custom_rlbench_env.py @@ -0,0 +1,576 @@ +from typing import Type, List + +import numpy as np +from rlbench import ObservationConfig, ActionMode +from rlbench.backend.exceptions import InvalidActionError +from rlbench.backend.observation import ( + BimanualObservation, + Observation, + UnimanualObservation, +) +from rlbench.backend.task import Task +from yarr.agents.agent import ActResult, VideoSummary, TextSummary +from yarr.envs.rlbench_env import RLBenchEnv, MultiTaskRLBenchEnv +from yarr.utils.observation_type import ObservationElement +from yarr.utils.transition import Transition +from yarr.utils.process_str import change_case + +from pyrep.const import RenderMode +from pyrep.errors import IKError, ConfigurationPathError +from pyrep.objects import VisionSensor, Dummy + +import logging + + +class CustomRLBenchEnv(RLBenchEnv): + def __init__( + self, + task_class: Type[Task], + observation_config: ObservationConfig, + action_mode: ActionMode, + episode_length: int, + dataset_root: str = "", + channels_last: bool = False, + reward_scale=100.0, + headless: bool = True, + time_in_state: bool = False, + include_lang_goal_in_obs: bool = False, + record_every_n: int = 20, + ): + super(CustomRLBenchEnv, self).__init__( + task_class, + observation_config, + action_mode, + dataset_root, + channels_last, + headless=headless, + include_lang_goal_in_obs=include_lang_goal_in_obs, + ) + self._reward_scale = reward_scale + self._episode_index = 0 + self._record_current_episode = False + self._record_cam = None + self._previous_obs, self._previous_obs_dict = None, None + self._recorded_images = [] + self._episode_length = episode_length + self._time_in_state = time_in_state + self._record_every_n = record_every_n + self._i = 0 + self._error_type_counts = { + "IKError": 0, + "ConfigurationPathError": 0, + "InvalidActionError": 0, + } + self._last_exception = None + + @property + def observation_elements(self) -> List[ObservationElement]: + 
def extract_obs_bimanual(self, obs: BimanualObservation, t=None, prev_action=None):
    """Build the observation dict for a two-arm observation.

    For both arms the joint velocities, gripper pose, gripper matrix and joint
    positions are set to ``None`` while the parent ``extract_obs`` runs.
    Gripper joint positions are clipped to ``[0.0, 0.04]``. Afterwards the
    pose, matrix and joint positions are put back on ``obs``; joint velocities
    stay ``None``.

    Args:
        obs: observation exposing ``right`` and ``left`` arm data.
        t: time step used for the time feature; defaults to ``self._i``.
        prev_action: accepted for interface compatibility; not used.

    Returns:
        The parent's observation dict, extended with per-arm joint and gripper
        joint positions. When ``self._time_in_state`` is set, a scalar time
        feature is appended to each low-dimensional state vector.
    """
    arms = (obs.right, obs.left)
    # Remember what is blanked for the parent call so it can be restored.
    saved = [
        (arm.gripper_matrix, arm.gripper_pose, arm.joint_positions) for arm in arms
    ]

    for arm in arms:
        arm.joint_velocities = None
        arm.gripper_pose = None
        arm.gripper_matrix = None
        arm.joint_positions = None
        # Check each arm on its own: one gripper may be missing while the
        # other is present, and np.clip(None, ...) would fail.
        if arm.gripper_joint_positions is not None:
            arm.gripper_joint_positions = np.clip(
                arm.gripper_joint_positions, 0.0, 0.04
            )

    try:
        obs_dict = super(CustomRLBenchEnv, self).extract_obs(obs)
    finally:
        # Restore even if the parent raises, so `obs` is never left blanked.
        for arm, (grip_mat, grip_pose, joint_pos) in zip(arms, saved):
            arm.gripper_matrix = grip_mat
            arm.joint_positions = joint_pos
            arm.gripper_pose = grip_pose

    if self._time_in_state:
        step = self._i if t is None else t
        # Guard the one-step episode, which would otherwise divide by zero.
        horizon = max(self._episode_length - 1, 1)
        time = (1.0 - (step / float(horizon))) * 2.0 - 1.0

        if "low_dim_state" in obs_dict:
            obs_dict["low_dim_state"] = np.concatenate(
                [obs_dict["low_dim_state"], [time]]
            ).astype(np.float32)
        else:
            obs_dict["right_low_dim_state"] = np.concatenate(
                [obs_dict["right_low_dim_state"], [time]]
            ).astype(np.float32)
            obs_dict["left_low_dim_state"] = np.concatenate(
                [obs_dict["left_low_dim_state"], [time]]
            ).astype(np.float32)

    obs_dict["left_joint_positions"] = obs.left.joint_positions
    obs_dict["left_gripper_joint_positions"] = obs.left.gripper_joint_positions
    obs_dict["right_joint_positions"] = obs.right.joint_positions
    obs_dict["right_gripper_joint_positions"] = obs.right.gripper_joint_positions

    return obs_dict
self._record_cam.set_explicit_handling(True) + self._record_cam.set_pose(cam_placeholder.get_pose()) + self._record_cam.set_render_mode(RenderMode.OPENGL) + + def reset(self) -> dict: + self._i = 0 + self._previous_obs_dict = super(CustomRLBenchEnv, self).reset() + self._record_current_episode = ( + self.eval and self._episode_index % self._record_every_n == 0 + ) + self._episode_index += 1 + self._recorded_images.clear() + return self._previous_obs_dict + + def register_callback(self, func): + self._task._scene.register_step_callback(func) + + def _my_callback(self): + if self._record_current_episode: + self._record_cam.handle_explicitly() + cap = (self._record_cam.capture_rgb() * 255).astype(np.uint8) + self._recorded_images.append(cap) + + def _append_final_frame(self, success: bool): + self._record_cam.handle_explicitly() + img = (self._record_cam.capture_rgb() * 255).astype(np.uint8) + self._recorded_images.append(img) + final_frames = np.zeros((10,) + img.shape[:2] + (3,), dtype=np.uint8) + # Green/red for success/failure + final_frames[:, :, :, 1 if success else 0] = 255 + self._recorded_images.extend(list(final_frames)) + + def step(self, act_result: ActResult) -> Transition: + action = act_result.action + success = False + obs = self._previous_obs_dict # in case action fails. 
+ + try: + obs, reward, terminal = self._task.step(action) + if reward >= 1: + success = True + reward *= self._reward_scale + else: + reward = 0.0 + obs = self.extract_obs(obs) + self._previous_obs_dict = obs + except (IKError, ConfigurationPathError, InvalidActionError) as e: + terminal = True + reward = 0.0 + + if isinstance(e, IKError): + self._error_type_counts["IKError"] += 1 + elif isinstance(e, ConfigurationPathError): + self._error_type_counts["ConfigurationPathError"] += 1 + elif isinstance(e, InvalidActionError): + self._error_type_counts["InvalidActionError"] += 1 + + self._last_exception = e + + summaries = [] + self._i += 1 + if ( + terminal or self._i == self._episode_length + ) and self._record_current_episode: + self._append_final_frame(success) + vid = np.array(self._recorded_images).transpose((0, 3, 1, 2)) + summaries.append( + VideoSummary( + "episode_rollout_" + ("success" if success else "fail"), vid, fps=30 + ) + ) + + # error summary + error_str = ( + f"Errors - IK : {self._error_type_counts['IKError']}, " + f"ConfigPath : {self._error_type_counts['ConfigurationPathError']}, " + f"InvalidAction : {self._error_type_counts['InvalidActionError']}" + ) + if not success and self._last_exception is not None: + error_str += f"\n Last Exception: {self._last_exception}" + self._last_exception = None + + summaries.append( + TextSummary("errors", f"Success: {success} | " + error_str) + ) + return Transition(obs, reward, terminal, summaries=summaries) + + def reset_to_demo(self, i): + self._i = 0 + # super(CustomRLBenchEnv, self).reset() + + self._task.set_variation(-1) + (d,) = self._task.get_demos( + 1, live_demos=False, random_selection=False, from_episode_number=i + ) + + self._task.set_variation(d.variation_number) + _, obs = self._task.reset_to_demo(d) + self._lang_goal = self._task.get_task_descriptions()[0] + + self._previous_obs_dict = self.extract_obs(obs) + self._record_current_episode = ( + self.eval and self._episode_index % 
self._record_every_n == 0 + ) + self._episode_index += 1 + self._recorded_images.clear() + + return self._previous_obs_dict + + +class CustomMultiTaskRLBenchEnv(MultiTaskRLBenchEnv): + def __init__( + self, + task_classes: List[Type[Task]], + observation_config: ObservationConfig, + action_mode: ActionMode, + episode_length: int, + dataset_root: str = "", + channels_last: bool = False, + reward_scale=100.0, + headless: bool = True, + swap_task_every: int = 1, + time_in_state: bool = False, + include_lang_goal_in_obs: bool = False, + record_every_n: int = 20, + ): + super(CustomMultiTaskRLBenchEnv, self).__init__( + task_classes, + observation_config, + action_mode, + dataset_root, + channels_last, + headless=headless, + swap_task_every=swap_task_every, + include_lang_goal_in_obs=include_lang_goal_in_obs, + ) + self._reward_scale = reward_scale + self._episode_index = 0 + self._record_current_episode = False + self._record_cam = None + self._previous_obs, self._previous_obs_dict = None, None + self._recorded_images = [] + self._episode_length = episode_length + self._time_in_state = time_in_state + self._record_every_n = record_every_n + self._i = 0 + self._error_type_counts = { + "IKError": 0, + "ConfigurationPathError": 0, + "InvalidActionError": 0, + } + self._last_exception = None + + @property + def observation_elements(self) -> List[ObservationElement]: + obs_elems = super(CustomMultiTaskRLBenchEnv, self).observation_elements + for oe in obs_elems: + if "low_dim_state" in oe.name: + # ..todo:: since we have the low_dimensional state separate for both robots this will also work + oe.shape = ( + oe.shape[0] - 7 * 3 + int(self._time_in_state), + ) # remove pose and joint velocities as they will not be included + self.low_dim_state_len = oe.shape[0] + return obs_elems + + def extract_obs(self, obs: Observation, t=None, prev_action=None): + if obs.is_bimanual: + return self.extract_obs_bimanual(obs, t, prev_action) + else: + return self.extract_obs_unimanual(obs, 
t, prev_action) + + def extract_obs_bimanual(self, obs: BimanualObservation, t=None, prev_action=None): + obs.right.joint_velocities = None + right_grip_mat = obs.right.gripper_matrix + right_grip_pose = obs.right.gripper_pose + right_joint_pos = obs.right.joint_positions + obs.right.gripper_pose = None + obs.right.gripper_matrix = None + obs.right.joint_positions = None + + obs.left.joint_velocities = None + left_grip_mat = obs.left.gripper_matrix + left_grip_pose = obs.left.gripper_pose + left_joint_pos = obs.left.joint_positions + obs.left.gripper_pose = None + obs.left.gripper_matrix = None + obs.left.joint_positions = None + + if obs.right.gripper_joint_positions is not None: + obs.right.gripper_joint_positions = np.clip( + obs.right.gripper_joint_positions, 0.0, 0.04 + ) + obs.left.gripper_joint_positions = np.clip( + obs.left.gripper_joint_positions, 0.0, 0.04 + ) + + obs_dict = super(CustomMultiTaskRLBenchEnv, self).extract_obs(obs) + + if self._time_in_state: + time = ( + 1.0 - ((self._i if t is None else t) / float(self._episode_length - 1)) + ) * 2.0 - 1.0 + obs_dict["right_low_dim_state"] = np.concatenate( + [obs_dict["right_low_dim_state"], [time]] + ).astype(np.float32) + obs_dict["left_low_dim_state"] = np.concatenate( + [obs_dict["left_low_dim_state"], [time]] + ).astype(np.float32) + + obs.right.gripper_matrix = right_grip_mat + obs.right.joint_positions = right_joint_pos + obs.right.gripper_pose = right_grip_pose + obs.left.gripper_matrix = left_grip_mat + obs.left.joint_positions = left_joint_pos + obs.left.gripper_pose = left_grip_pose + + obs_dict["left_joint_positions"] = obs.left.joint_positions + obs_dict["left_gripper_joint_positions"] = obs.left.gripper_joint_positions + obs_dict["right_joint_positions"] = obs.right.joint_positions + obs_dict["right_gripper_joint_positions"] = obs.right.gripper_joint_positions + + return obs_dict + + def extract_obs_unimanual(self, obs: Observation, t=None, prev_action=None): + obs.joint_velocities = None 
def launch(self):
    """Launch the environment and hook up the per-step recording callback.

    In eval mode a 320x180 vision sensor is also created for recording
    rollouts. It uses explicit handling and OpenGL rendering, and is placed
    at the pose of the ``cam_cinematic_placeholder`` dummy.
    """
    super(CustomMultiTaskRLBenchEnv, self).launch()
    self._task._scene.register_step_callback(self._my_callback)
    if not self.eval:
        return

    placeholder = Dummy("cam_cinematic_placeholder")
    # The camera base is rotated before the placeholder pose is read.
    base = Dummy("cam_cinematic_base")
    base.rotate([0, 0, np.pi * 0.75])

    self._record_cam = VisionSensor.create([320, 180])
    recorder = self._record_cam
    recorder.set_explicit_handling(True)
    recorder.set_pose(placeholder.get_pose())
    recorder.set_render_mode(RenderMode.OPENGL)
def _append_final_frame(self, success: bool):
    """Record a last camera frame followed by ten solid-colour frames.

    The trailing frames have only the green channel set on success and only
    the red channel set on failure, so the outcome is visible in the video.

    Args:
        success: Whether the episode succeeded.
    """
    camera = self._record_cam
    camera.handle_explicitly()
    last_frame = (camera.capture_rgb() * 255).astype(np.uint8)
    self._recorded_images.append(last_frame)

    marker_shape = (10,) + last_frame.shape[:2] + (3,)
    marker = np.zeros(marker_shape, dtype=np.uint8)
    # Channel 1 is green, channel 0 is red.
    colour_channel = 1 if success else 0
    marker[..., colour_channel] = 255
    self._recorded_images.extend(list(marker))
def reset_to_demo(self, i, variation_number=-1):
    """Reset the environment to the start of stored demo ``i``.

    Switches to a new task once ``self._swap_task_every`` episodes have run
    on the current one. It then loads the stored demo with episode number
    ``i``, applies that demo's variation, resets the task to the demo and
    refreshes the language goal, cached observation and recording state.

    Args:
        i: Episode number of the stored demo to load.
        variation_number: Currently unused; the variation always comes from
            the loaded demo.

    Returns:
        The observation dict extracted from the demo's initial observation.
    """
    task_due_for_swap = self._episodes_this_task == self._swap_task_every
    if task_due_for_swap:
        self._set_new_task()
        self._episodes_this_task = 0
    self._episodes_this_task += 1
    self._i = 0

    # Variation -1 is set before the lookup; the demo's own variation is
    # applied once it is known.
    self._task.set_variation(-1)
    demos = self._task.get_demos(
        1, live_demos=False, random_selection=False, from_episode_number=i
    )
    demo = demos[0]
    self._task.set_variation(demo.variation_number)
    _, first_obs = self._task.reset_to_demo(demo)
    self._lang_goal = self._task.get_task_descriptions()[0]

    obs_dict = self.extract_obs(first_obs)
    self._previous_obs_dict = obs_dict

    record_this_episode = (
        self.eval and self._episode_index % self._record_every_n == 0
    )
    self._record_current_episode = record_this_episode
    self._episode_index += 1
    self._recorded_images.clear()

    return obs_dict
def _arm_is_stopped(demo, i, obs, side, delta):
    """Shared stop test for one arm of a bimanual demo.

    The arm counts as stopped at step ``i`` when all of the following hold:

    * every joint velocity in ``obs`` is within ``delta`` of zero;
    * ``i`` is before the second-to-last step of ``demo``;
    * the arm's ``gripper_open`` equals the value at steps ``i + 1`` and
      ``i - 1``, and steps ``i - 2`` and ``i - 1`` agree with each other.

    For ``i < 2`` the look-back indices are negative and therefore wrap
    around to the end of ``demo``.

    Args:
        demo: Indexable sequence of bimanual observations.
        i: Index of the step being tested.
        obs: The arm's observation at step ``i``.
        side: Attribute name of the arm on each demo step (``"right"`` or
            ``"left"``).
        delta: Absolute tolerance for the zero-velocity check.
    """
    second_to_last = len(demo) - 2
    next_is_final = i == second_to_last

    def gripper_at(index):
        return getattr(demo[index], side).gripper_open

    if i < second_to_last:
        gripper_unchanged = (
            obs.gripper_open == gripper_at(i + 1)
            and obs.gripper_open == gripper_at(i - 1)
            and gripper_at(i - 2) == gripper_at(i - 1)
        )
    else:
        gripper_unchanged = False

    velocities_near_zero = np.allclose(obs.joint_velocities, 0, atol=delta)
    return velocities_near_zero and (not next_is_final) and gripper_unchanged


def _is_stopped_right(demo, i, obs, delta=0.1):
    """Return whether the right arm is stopped at step ``i`` of ``demo``."""
    return _arm_is_stopped(demo, i, obs, "right", delta)


def _is_stopped_left(demo, i, obs, delta=0.1):
    """Return whether the left arm is stopped at step ``i`` of ``demo``."""
    return _arm_is_stopped(demo, i, obs, "left", delta)
def _keypoint_discovery_unimanual(demo: Demo, stopping_delta=0.1) -> List[int]:
    """Find keypoint indices in a single-arm demo.

    A step other than the first becomes a keypoint when the gripper state
    differs from the previous step, when it is the last step, or when the arm
    is judged stopped by ``_is_stopped``. After a stop is accepted, further
    stops are ignored for the next few steps. If the last two keypoints are
    adjacent indices, the earlier one is dropped.

    Args:
        demo: The demonstration to scan.
        stopping_delta: Velocity tolerance forwarded to ``_is_stopped``.

    Returns:
        The keypoint indices in increasing order.
    """
    keypoints = []
    final_index = len(demo) - 1
    previous_gripper = demo[0].gripper_open
    cooldown = 0
    for index, frame in enumerate(demo):
        arm_stopped = _is_stopped(demo, index, frame, stopping_delta)
        # A stop only counts once the cool-down from the last one has expired.
        arm_stopped = (cooldown <= 0) and arm_stopped
        if arm_stopped:
            cooldown = 4
        else:
            cooldown -= 1

        gripper_toggled = frame.gripper_open != previous_gripper
        is_last = index == final_index
        if index != 0 and (gripper_toggled or is_last or arm_stopped):
            keypoints.append(index)
        previous_gripper = frame.gripper_open

    has_adjacent_tail = len(keypoints) > 1 and (keypoints[-1] - 1) == keypoints[-2]
    if has_adjacent_tail:
        del keypoints[-2]
    print("Found %d keypoints." % len(keypoints), keypoints)
    return keypoints
def find_minimum_difference(lst):
    """Return the smallest difference between consecutive elements of ``lst``.

    Only neighbouring elements ``lst[i] - lst[i - 1]`` are compared, so this
    is the minimum gap between any two elements only when ``lst`` is sorted.

    The minimum used to be seeded with ``lst[-1]``, which is not a
    difference. That returned a wrong answer whenever the last element was
    smaller than every gap, for example with negative values.

    Args:
        lst: Indexable sequence of numbers.

    Returns:
        The minimum consecutive difference. For a single-element sequence the
        element itself is returned, as before.

    Raises:
        IndexError: If ``lst`` is empty.
    """
    if len(lst) < 2:
        # No pair to compare: keep the historical fallback.
        return lst[-1]
    return min(lst[i] - lst[i - 1] for i in range(1, len(lst)))
class FiLMBlock(nn.Module):
    """Feature-wise linear modulation: ``gamma * x + beta`` per channel.

    ``gamma`` and ``beta`` hold one value per ``(batch, channel)`` pair and
    are broadcast over all remaining dimensions of ``x``. This used to be
    hard-coded for 4-D inputs; any input with at least two dimensions now
    works, and 4-D inputs behave exactly as before.
    """

    def __init__(self):
        super(FiLMBlock, self).__init__()

    def forward(self, x, gamma, beta):
        """Modulate ``x`` with per-channel scale ``gamma`` and shift ``beta``.

        Args:
            x: Tensor whose first two dimensions are batch and channel.
            gamma: Scale tensor with ``x.size(0) * x.size(1)`` elements.
            beta: Shift tensor with ``x.size(0) * x.size(1)`` elements.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # One singleton dimension per trailing axis of x, so the per-channel
        # values broadcast over them.
        broadcast_shape = (x.size(0), x.size(1)) + (1,) * (x.dim() - 2)
        beta = beta.view(broadcast_shape)
        gamma = gamma.view(broadcast_shape)
        return gamma * x + beta
class Conv2DFiLMBlock(Conv2DBlock):
    """``Conv2DBlock`` with FiLM modulation between norm and activation."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_sizes,
        strides,
        norm=None,
        activation=None,
        padding_mode="replicate",
    ):
        super(Conv2DFiLMBlock, self).__init__(
            in_channels,
            out_channels,
            kernel_sizes,
            strides,
            norm=norm,
            activation=activation,
            padding_mode=padding_mode,
        )
        self.film = FiLMBlock()

    def forward(self, x, gamma, beta):
        """Apply conv, optional norm, FiLM, then optional activation.

        Args:
            x: Input feature map.
            gamma: FiLM scale, one value per (batch, output channel).
            beta: FiLM shift, one value per (batch, output channel).
        """
        out = self.conv2d(x)
        if self.norm is not None:
            out = self.norm(out)
        out = self.film(out, gamma, beta)
        if self.activation is not None:
            out = self.activation(out)
        return out
class ConvTranspose3DBlock(nn.Module):
    """Transposed 3-D convolution with optional normalisation and activation.

    Two defects are fixed relative to the previous version:

    * Requesting a ``norm`` raised ``NameError`` because the 3-D norm factory
      it called does not exist. Normalisation layers are now built locally,
      using the same names as the 2-D factory (``batch``, ``instance``,
      ``layer``, ``group``).
    * The default padding was computed with ``kernel_sizes // 2``, which
      fails for list or tuple kernel sizes. It is now computed per dimension.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_sizes: Kernel size, either an int or a per-dimension sequence.
        strides: Stride of the transposed convolution.
        norm: ``None`` or one of ``"batch"``, ``"instance"``, ``"layer"``,
            ``"group"``.
        activation: ``None`` or an activation name. It selects the weight
            initialisation and is passed to ``act_layer``.
        padding_mode: Forwarded to ``nn.ConvTranspose3d``.
        padding: Explicit padding; defaults to half the kernel size.

    Raises:
        ValueError: If ``norm`` or ``activation`` is not recognised.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_sizes: Union[int, list],
        strides,
        norm=None,
        activation=None,
        padding_mode="zeros",
        padding=None,
    ):
        super(ConvTranspose3DBlock, self).__init__()
        if padding is None:
            if isinstance(kernel_sizes, int):
                padding = kernel_sizes // 2
            else:
                padding = tuple(k // 2 for k in kernel_sizes)
        self.conv3d = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_sizes,
            strides,
            padding=padding,
            padding_mode=padding_mode,
        )

        # Weight initialisation is matched to the activation that follows.
        weight = self.conv3d.weight
        if activation is None:
            nn.init.xavier_uniform_(weight, gain=nn.init.calculate_gain("linear"))
        elif activation == "tanh":
            nn.init.xavier_uniform_(weight, gain=nn.init.calculate_gain("tanh"))
        elif activation == "lrelu":
            nn.init.kaiming_uniform_(
                weight, a=LRELU_SLOPE, nonlinearity="leaky_relu"
            )
        elif activation == "relu":
            nn.init.kaiming_uniform_(weight, nonlinearity="relu")
        else:
            raise ValueError("%s not recognized." % activation)
        nn.init.zeros_(self.conv3d.bias)

        self.activation = None
        self.norm = None
        if norm is not None:
            self.norm = self._make_norm(norm, out_channels)
        if activation is not None:
            self.activation = act_layer(activation)

    @staticmethod
    def _make_norm(norm, channels):
        """Build the 3-D normalisation layer named by ``norm``."""
        if norm == "batch":
            return nn.BatchNorm3d(channels)
        if norm == "instance":
            return nn.InstanceNorm3d(channels, affine=True)
        if norm == "layer":
            return nn.GroupNorm(1, channels, affine=True)
        if norm == "group":
            return nn.GroupNorm(4, channels, affine=True)
        raise ValueError("%s not recognized." % norm)

    def forward(self, x):
        """Apply the transposed conv, then optional norm and activation."""
        x = self.conv3d(x)
        x = self.norm(x) if self.norm is not None else x
        x = self.activation(x) if self.activation is not None else x
        return x
class Conv3DUpsampleBlock(nn.Module):
    """Conv3D, optional trilinear upsampling, then a second Conv3D.

    Both convolutions use stride 1. The upsampling stage is inserted only
    when ``strides`` is greater than 1 and scales by that factor.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        strides,
        kernel_sizes=3,
        norm=None,
        activation=None,
    ):
        super(Conv3DUpsampleBlock, self).__init__()
        stages = [
            Conv3DBlock(in_channels, out_channels, kernel_sizes, 1, norm, activation)
        ]
        if strides > 1:
            upsample = nn.Upsample(
                scale_factor=strides, mode="trilinear", align_corners=False
            )
            stages.append(upsample)
        stages.append(
            Conv3DBlock(out_channels, out_channels, kernel_sizes, 1, norm, activation)
        )
        self.conv_up = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the assembled stages."""
        return self.conv_up(x)
class SiameseNet(nn.Module):
    """One convolutional stream per input, fused by a 1x1 convolution.

    Call ``build()`` after construction to create the layers.

    Args:
        input_channels: Channel count of each input stream.
        filters: Output channels of each conv layer in a stream.
        kernel_sizes: Kernel size of each conv layer in a stream.
        strides: Stride of each conv layer in a stream.
        norm: Normalisation name forwarded to ``Conv2DBlock``.
        activation: Activation name forwarded to ``Conv2DBlock``.
    """

    def __init__(
        self,
        input_channels: List[int],
        filters: List[int],
        kernel_sizes: List[int],
        strides: List[int],
        norm: str = None,
        activation: str = "relu",
    ):
        super(SiameseNet, self).__init__()
        self._input_channels = input_channels
        self._filters = filters
        self._kernel_sizes = kernel_sizes
        self._strides = strides
        self._norm = norm
        self._activation = activation
        self.output_channels = filters[-1]  # * len(input_channels)

    def build(self):
        """Create the per-stream conv stacks and the fusing 1x1 convolution."""
        self._siamese_blocks = nn.ModuleList()
        layer_specs = list(zip(self._filters, self._kernel_sizes, self._strides))
        for stream_channels in self._input_channels:
            # Each layer consumes the previous layer's output channels. Every
            # layer used to be built with the stream's input channel count,
            # which broke stacks deeper than one layer.
            in_ch = stream_channels
            blocks = []
            for filt, ksize, stride in layer_specs:
                blocks.append(
                    Conv2DBlock(
                        in_ch, filt, ksize, stride, self._norm, self._activation
                    )
                )
                in_ch = filt
            self._siamese_blocks.append(nn.Sequential(*blocks))
        self._fuse = Conv2DBlock(
            self._filters[-1] * len(self._siamese_blocks),
            self._filters[-1],
            1,
            1,
            self._norm,
            self._activation,
        )

    def forward(self, x):
        """Encode each input with its own stream and fuse the results.

        Args:
            x: Sequence of tensors, one per stream.

        Returns:
            The fused feature map. The per-stream outputs are also kept on
            ``self.streams``.

        Raises:
            ValueError: If the number of inputs differs from the number of
                streams.
        """
        if len(x) != len(self._siamese_blocks):
            raise ValueError(
                "Expected a list of tensors of size %d." % len(self._siamese_blocks)
            )
        self.streams = [stream(y) for y, stream in zip(x, self._siamese_blocks)]
        y = self._fuse(torch.cat(self.streams, 1))
        return y
class CNNLangAndFcsNet(nn.Module):
    """Siamese CNN encoder with three FiLM conv stages and a dense head.

    The language embedding is mapped by linear layers to a per-channel FiLM
    scale and shift for each of the three conv stages. Call ``build()``
    after construction to create the layers.

    Notes:
        * Exactly three FiLM conv stages are built, so ``filters``,
          ``kernel_sizes`` and ``strides`` need at least three entries.
        * The FiLM conv stages are built without ``norm`` or ``activation``.
        * The language embedding size is fixed at 1024.

    Args:
        siamese_net: Encoder; a deep copy is stored.
        low_dim_state_len: Length of the low-dimensional state vector that is
            tiled and concatenated onto the encoder output.
        input_resolution: Stored but not used by this class.
        filters: Output channels of the three conv stages.
        kernel_sizes: Kernel sizes of the three conv stages.
        strides: Strides of the three conv stages.
        norm: Stored but not applied to the FiLM conv stages.
        fc_layers: Sizes of the dense head. ``None`` or an empty list now
            yields an identity head; it used to raise ``IndexError`` in
            ``build()``.
        activation: Activation for all dense layers except the last.
    """

    def __init__(
        self,
        siamese_net: SiameseNet,
        low_dim_state_len: int,
        input_resolution: List[int],
        filters: List[int],
        kernel_sizes: List[int],
        strides: List[int],
        norm: str = None,
        fc_layers: List[int] = None,
        activation: str = "relu",
    ):
        super(CNNLangAndFcsNet, self).__init__()
        self._siamese_net = copy.deepcopy(siamese_net)
        self._input_channels = self._siamese_net.output_channels + low_dim_state_len
        self._filters = filters
        self._kernel_sizes = kernel_sizes
        self._strides = strides
        self._norm = norm
        self._activation = activation
        self._fc_layers = [] if fc_layers is None else fc_layers
        self._input_resolution = input_resolution

        self._lang_feat_dim = 1024

    def build(self):
        """Create the encoder, FiLM conv stages, pooling and dense head."""
        self._siamese_net.build()
        channels = self._input_channels

        self.conv1 = Conv2DFiLMBlock(
            channels, self._filters[0], self._kernel_sizes[0], self._strides[0]
        )
        self.gamma1 = nn.Linear(self._lang_feat_dim, self._filters[0])
        self.beta1 = nn.Linear(self._lang_feat_dim, self._filters[0])

        self.conv2 = Conv2DFiLMBlock(
            self._filters[0], self._filters[1], self._kernel_sizes[1], self._strides[1]
        )
        self.gamma2 = nn.Linear(self._lang_feat_dim, self._filters[1])
        self.beta2 = nn.Linear(self._lang_feat_dim, self._filters[1])

        self.conv3 = Conv2DFiLMBlock(
            self._filters[1], self._filters[2], self._kernel_sizes[2], self._strides[2]
        )
        self.gamma3 = nn.Linear(self._lang_feat_dim, self._filters[2])
        self.beta3 = nn.Linear(self._lang_feat_dim, self._filters[2])

        self._maxp = nn.AdaptiveMaxPool2d(1)

        channels = self._filters[-1]
        dense_layers = []
        # With no dense sizes configured the head is an empty Sequential,
        # which passes the pooled features through unchanged.
        if self._fc_layers:
            for n in self._fc_layers[:-1]:
                dense_layers.append(
                    DenseBlock(channels, n, activation=self._activation)
                )
                channels = n
            dense_layers.append(DenseBlock(channels, self._fc_layers[-1]))
        self._fcs = nn.Sequential(*dense_layers)

    def forward(self, observations, low_dim_ins, lang_goal_emb):
        """Encode observations, apply language-conditioned convs, then the head.

        Args:
            observations: Inputs for the siamese encoder.
            low_dim_ins: Low-dimensional state, tiled over the feature map.
            lang_goal_emb: Language embedding fed to the FiLM linear layers.
        """
        x = self._siamese_net(observations)
        _, _, h, w = x.shape
        low_dim_latents = low_dim_ins.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w)
        combined = torch.cat([x, low_dim_latents], dim=1)

        g1 = self.gamma1(lang_goal_emb)
        b1 = self.beta1(lang_goal_emb)
        x = self.conv1(combined, g1, b1)

        g2 = self.gamma2(lang_goal_emb)
        b2 = self.beta2(lang_goal_emb)
        x = self.conv2(x, g2, b2)

        g3 = self.gamma3(lang_goal_emb)
        b3 = self.beta3(lang_goal_emb)
        x = self.conv3(x, g3, b3)

        x = self._maxp(x).squeeze(-1).squeeze(-1)
        return self._fcs(x)
class Transformer(nn.Module):
    """Stack of pre-norm attention and feed-forward layers with residuals.

    Args:
        dim: Token embedding size.
        depth: Number of attention/feed-forward layer pairs.
        heads: Number of attention heads.
        dim_head: Size of each attention head.
        mlp_dim: Hidden size of the feed-forward layers.
        dropout: Dropout rate used in both sub-layers.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(depth):
            attention = PreNorm(
                dim,
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
            )
            feed_forward = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            self.layers.append(nn.ModuleList([attention, feed_forward]))

    def forward(self, x):
        """Apply every layer pair, adding the residual after each sub-layer."""
        for attention, feed_forward in self.layers:
            x = attention(x) + x
            x = feed_forward(x) + x
        return x
class ViTLangAndFcsNet(nn.Module):
    """ViT encoder with three FiLM conv stages and a dense head.

    The language embedding is mapped by linear layers to a per-channel FiLM
    scale and shift for each of the three conv stages. Call ``build()``
    after construction to create the layers.

    Notes:
        * The input channel count is fixed at ``64 + low_dim_state_len``.
          This assumes the ViT output has 64 channels -- TODO confirm against
          the ViT configuration in use.
        * Exactly three FiLM conv stages are built, so ``filters``,
          ``kernel_sizes`` and ``strides`` need at least three entries.
        * The FiLM conv stages are built without ``norm`` or ``activation``.
        * The language embedding size is fixed at 1024.

    Args:
        vit: Encoder; a deep copy is stored.
        low_dim_state_len: Length of the low-dimensional state vector that is
            tiled and concatenated onto the encoder output.
        input_resolution: Stored but not used by this class.
        filters: Output channels of the three conv stages.
        kernel_sizes: Kernel sizes of the three conv stages.
        strides: Strides of the three conv stages.
        norm: Stored but not applied to the FiLM conv stages.
        fc_layers: Sizes of the dense head. ``None`` or an empty list now
            yields an identity head; it used to raise ``IndexError`` in
            ``build()``.
        activation: Activation for all dense layers except the last.
    """

    def __init__(
        self,
        vit: ViT,
        low_dim_state_len: int,
        input_resolution: List[int],
        filters: List[int],
        kernel_sizes: List[int],
        strides: List[int],
        norm: str = None,
        fc_layers: List[int] = None,
        activation: str = "relu",
    ):
        super(ViTLangAndFcsNet, self).__init__()
        self._vit = copy.deepcopy(vit)
        self._input_channels = 64 + low_dim_state_len
        self._filters = filters
        self._kernel_sizes = kernel_sizes
        self._strides = strides
        self._norm = norm
        self._activation = activation
        self._fc_layers = [] if fc_layers is None else fc_layers
        self._input_resolution = input_resolution

        self._lang_feat_dim = 1024

    def build(self):
        """Create the FiLM conv stages, pooling and dense head."""
        channels = self._input_channels

        self.conv1 = Conv2DFiLMBlock(
            channels, self._filters[0], self._kernel_sizes[0], self._strides[0]
        )
        self.gamma1 = nn.Linear(self._lang_feat_dim, self._filters[0])
        self.beta1 = nn.Linear(self._lang_feat_dim, self._filters[0])

        self.conv2 = Conv2DFiLMBlock(
            self._filters[0], self._filters[1], self._kernel_sizes[1], self._strides[1]
        )
        self.gamma2 = nn.Linear(self._lang_feat_dim, self._filters[1])
        self.beta2 = nn.Linear(self._lang_feat_dim, self._filters[1])

        self.conv3 = Conv2DFiLMBlock(
            self._filters[1], self._filters[2], self._kernel_sizes[2], self._strides[2]
        )
        self.gamma3 = nn.Linear(self._lang_feat_dim, self._filters[2])
        self.beta3 = nn.Linear(self._lang_feat_dim, self._filters[2])

        self._maxp = nn.AdaptiveMaxPool2d(1)

        channels = self._filters[-1]
        dense_layers = []
        # With no dense sizes configured the head is an empty Sequential,
        # which passes the pooled features through unchanged.
        if self._fc_layers:
            for n in self._fc_layers[:-1]:
                dense_layers.append(
                    DenseBlock(channels, n, activation=self._activation)
                )
                channels = n
            dense_layers.append(DenseBlock(channels, self._fc_layers[-1]))
        self._fcs = nn.Sequential(*dense_layers)

    def forward(self, observations, low_dim_ins, lang_goal_emb):
        """Encode observations, apply language-conditioned convs, then the head.

        Args:
            observations: Iterable of tensors, concatenated along dim 1
                before being passed to the ViT.
            low_dim_ins: Low-dimensional state, tiled over the feature map.
            lang_goal_emb: Language embedding fed to the FiLM linear layers.
        """
        rgb_depth = torch.cat([*observations], dim=1)
        x = self._vit(rgb_depth)
        _, _, h, w = x.shape
        low_dim_latents = low_dim_ins.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w)
        combined = torch.cat([x, low_dim_latents], dim=1)

        g1 = self.gamma1(lang_goal_emb)
        b1 = self.beta1(lang_goal_emb)
        x = self.conv1(combined, g1, b1)

        g2 = self.gamma2(lang_goal_emb)
        b2 = self.beta2(lang_goal_emb)
        x = self.conv2(x, g2, b2)

        g3 = self.gamma3(lang_goal_emb)
        b3 = self.beta3(lang_goal_emb)
        x = self.conv3(x, g3, b3)

        x = self._maxp(x).squeeze(-1).squeeze(-1)
        return self._fcs(x)
class Conv3DInceptionBlock(nn.Module):
    """Three parallel 3-D conv branches concatenated along the channel axis.

    The branches are a 1x1 conv (half of ``out_channels``), a 1x1 -> 3x3
    path (a quarter) and a 1x1 -> 3x3 -> 3x3 path (a quarter). With
    ``residual=True`` the input is concatenated in front of the branches.

    Args:
        in_channels: Number of input channels.
        out_channels: Total channels produced by the branches; must be
            divisible by 4.
        norm: Forwarded to every ``Conv3DBlock``.
        activation: Forwarded to every ``Conv3DBlock``.
        residual: Whether to prepend the input to the output.
    """

    def __init__(
        self, in_channels, out_channels, norm=None, activation=None, residual=False
    ):
        super(Conv3DInceptionBlock, self).__init__()
        self._residual = residual
        quarter = out_channels // 4
        assert out_channels % 4 == 0
        latent = 32

        def conv(c_in, c_out, kernel):
            return Conv3DBlock(
                c_in,
                c_out,
                kernel_sizes=kernel,
                strides=1,
                norm=norm,
                activation=activation,
            )

        # Creation order is kept so parameter registration is unchanged.
        self._1x1conv = conv(in_channels, quarter * 2, 1)

        self._1x1conv_a = conv(in_channels, latent, 1)
        self._3x3conv = conv(latent, quarter, 3)

        self._1x1conv_b = conv(in_channels, latent, 1)
        self._5x5_via_3x3conv_a = conv(latent, latent, 3)
        self._5x5_via_3x3conv_b = conv(latent, quarter, 3)

        extra = in_channels if residual else 0
        self.out_channels = out_channels + extra

    def forward(self, x):
        """Run all branches on ``x`` and concatenate them on dim 1."""
        branches = [x] if self._residual else []
        branches.append(self._1x1conv(x))
        branches.append(self._3x3conv(self._1x1conv_a(x)))
        stacked = self._5x5_via_3x3conv_a(self._1x1conv_b(x))
        branches.append(self._5x5_via_3x3conv_b(stacked))
        return torch.cat(branches, 1)
norm=norm, + activation=activation, + ) + self._3x3conv = ConvTranspose3DBlock( + out_channels, + out_channels, + kernel_sizes=2, + strides=strides, + norm=norm, + activation=activation, + padding=padding, + ) + self._1x1conv_a = Conv3DBlock( + out_channels, + out_channels, + kernel_sizes=1, + strides=1, + norm=norm, + ) + self.out_channels = out_channels + + def forward(self, x): + x = self._1x1conv(x) + x = self._3x3conv(x) + x = self._1x1conv_a(x) + return x + + +class SpatialSoftmax3D(torch.nn.Module): + def __init__(self, depth, height, width, channel): + super(SpatialSoftmax3D, self).__init__() + self.depth = depth + self.height = height + self.width = width + self.channel = channel + self.temperature = 0.01 + pos_x, pos_y, pos_z = np.meshgrid( + np.linspace(-1.0, 1.0, self.depth), + np.linspace(-1.0, 1.0, self.height), + np.linspace(-1.0, 1.0, self.width), + ) + pos_x = torch.from_numpy( + pos_x.reshape(self.depth * self.height * self.width) + ).float() + pos_y = torch.from_numpy( + pos_y.reshape(self.depth * self.height * self.width) + ).float() + pos_z = torch.from_numpy( + pos_z.reshape(self.depth * self.height * self.width) + ).float() + self.register_buffer("pos_x", pos_x) + self.register_buffer("pos_y", pos_y) + self.register_buffer("pos_z", pos_z) + + def forward(self, feature): + feature = feature.view( + -1, self.height * self.width * self.depth + ) # (B, c*d*h*w) + softmax_attention = F.softmax(feature / self.temperature, dim=-1) + expected_x = torch.sum(self.pos_x * softmax_attention, dim=1, keepdim=True) + expected_y = torch.sum(self.pos_y * softmax_attention, dim=1, keepdim=True) + expected_z = torch.sum(self.pos_z * softmax_attention, dim=1, keepdim=True) + expected_xy = torch.cat([expected_x, expected_y, expected_z], 1) + feature_keypoints = expected_xy.view(-1, self.channel * 3) + return feature_keypoints diff --git a/external/peract_bimanual/helpers/observation_utils.py b/external/peract_bimanual/helpers/observation_utils.py new file mode 
100644 index 0000000000000000000000000000000000000000..ba01fda8d388706bad71aefd14363e142d825db7 --- /dev/null +++ b/external/peract_bimanual/helpers/observation_utils.py @@ -0,0 +1,254 @@ +import numpy as np +from rlbench.backend.observation import Observation + +from rlbench.backend.observation import BimanualObservation +from rlbench import CameraConfig, ObservationConfig +from pyrep.const import RenderMode +from typing import List + +REMOVE_KEYS = [ + "joint_velocities", + "joint_positions", + "joint_forces", + "gripper_open", + "gripper_pose", + "gripper_joint_positions", + "gripper_touch_forces", + "task_low_dim_state", + "misc", +] + + +def extract_obs( + obs: Observation, + cameras, + t: int = 0, + prev_action=None, + channels_last: bool = False, + episode_length: int = 10, + robot_name: str = "", +): + if obs.is_bimanual: + return extract_obs_bimanual( + obs, cameras, t, prev_action, channels_last, episode_length, robot_name + ) + else: + return extract_obs_unimanual( + obs, cameras, t, prev_action, channels_last, episode_length + ) + + +def extract_obs_unimanual( + obs: Observation, + cameras, + t: int = 0, + prev_action=None, + channels_last: bool = False, + episode_length: int = 10, +): + obs.joint_velocities = None + grip_mat = obs.gripper_matrix + grip_pose = obs.gripper_pose + joint_pos = obs.joint_positions + obs.gripper_pose = None + obs.gripper_matrix = None + obs.joint_positions = None + if obs.gripper_joint_positions is not None: + obs.gripper_joint_positions = np.clip(obs.gripper_joint_positions, 0.0, 0.04) + + obs_dict = vars(obs) + obs_dict = {k: v for k, v in obs_dict.items() if v is not None} + robot_state = obs.get_low_dim_data() + # remove low-level proprioception variables that are not needed + obs_dict = {k: v for k, v in obs_dict.items() if k not in REMOVE_KEYS} + + if not channels_last: + # swap channels from last dim to 1st dim + obs_dict = { + k: np.transpose(v, [2, 0, 1]) if v.ndim == 3 else np.expand_dims(v, 0) + for k, v in 
obs.perception_data.items() + if type(v) == np.ndarray or type(v) == list + } + else: + # add extra dim to depth data + obs_dict = { + k: v if v.ndim == 3 else np.expand_dims(v, -1) + for k, v in obs.perception_data.items() + } + obs_dict["low_dim_state"] = np.array(robot_state, dtype=np.float32) + + # binary variable indicating if collisions are allowed or not while planning paths to reach poses + obs_dict["ignore_collisions"] = np.array([obs.ignore_collisions], dtype=np.float32) + for k, v in [(k, v) for k, v in obs_dict.items() if "point_cloud" in k]: + obs_dict[k] = v.astype(np.float32) + + for camera_name in cameras: + obs_dict["%s_camera_extrinsics" % camera_name] = obs.misc[ + "%s_camera_extrinsics" % camera_name + ] + obs_dict["%s_camera_intrinsics" % camera_name] = obs.misc[ + "%s_camera_intrinsics" % camera_name + ] + + # add timestep to low_dim_state + time = (1.0 - (t / float(episode_length - 1))) * 2.0 - 1.0 + obs_dict["low_dim_state"] = np.concatenate( + [obs_dict["low_dim_state"], [time]] + ).astype(np.float32) + + obs.gripper_matrix = grip_mat + obs.joint_positions = joint_pos + obs.gripper_pose = grip_pose + + return obs_dict + + +def extract_obs_bimanual( + obs: Observation, + cameras, + t: int = 0, + prev_action=None, + channels_last: bool = False, + episode_length: int = 10, + robot_name: str = "", +): + obs.right.joint_velocities = None + right_grip_mat = obs.right.gripper_matrix + right_grip_pose = obs.right.gripper_pose + right_joint_pos = obs.right.joint_positions + obs.right.gripper_pose = None + obs.right.gripper_matrix = None + obs.right.joint_positions = None + + obs.left.joint_velocities = None + left_grip_mat = obs.left.gripper_matrix + left_grip_pose = obs.left.gripper_pose + left_joint_pos = obs.left.joint_positions + obs.left.gripper_pose = None + obs.left.gripper_matrix = None + obs.left.joint_positions = None + + if obs.right.gripper_joint_positions is not None: + obs.right.gripper_joint_positions = np.clip( + 
obs.right.gripper_joint_positions, 0.0, 0.04 + ) + obs.left.gripper_joint_positions = np.clip( + obs.left.gripper_joint_positions, 0.0, 0.04 + ) + + # fixme:: + obs_dict = vars(obs) + obs_dict = {k: v for k, v in obs_dict.items() if v is not None} + + right_robot_state = obs.get_low_dim_data(obs.right) + left_robot_state = obs.get_low_dim_data(obs.left) + + # remove low-level proprioception variables that are not needed + obs_dict = {k: v for k, v in obs_dict.items() if k not in REMOVE_KEYS} + + if not channels_last: + # swap channels from last dim to 1st dim + obs_dict = { + k: np.transpose(v, [2, 0, 1]) if v.ndim == 3 else np.expand_dims(v, 0) + for k, v in obs.perception_data.items() + if type(v) == np.ndarray or type(v) == list + } + else: + # add extra dim to depth data + obs_dict = { + k: v if v.ndim == 3 else np.expand_dims(v, -1) + for k, v in obs.perception_data.items() + } + + if robot_name == "right": + obs_dict["low_dim_state"] = right_robot_state.astype(np.float32) + # binary variable indicating if collisions are allowed or not while planning paths to reach poses + obs_dict["ignore_collisions"] = np.array( + [obs.right.ignore_collisions], dtype=np.float32 + ) + elif robot_name == "left": + obs_dict["low_dim_state"] = left_robot_state.astype(np.float32) + obs_dict["ignore_collisions"] = np.array( + [obs.left.ignore_collisions], dtype=np.float32 + ) + elif robot_name == "bimanual": + obs_dict["right_low_dim_state"] = right_robot_state.astype(np.float32) + obs_dict["left_low_dim_state"] = left_robot_state.astype(np.float32) + obs_dict["right_ignore_collisions"] = np.array( + [obs.right.ignore_collisions], dtype=np.float32 + ) + obs_dict["left_ignore_collisions"] = np.array( + [obs.left.ignore_collisions], dtype=np.float32 + ) + + for k, v in [(k, v) for k, v in obs_dict.items() if "point_cloud" in k]: + # ..TODO:: switch to np.float16 + obs_dict[k] = v.astype(np.float32) + + for camera_name in cameras: + obs_dict["%s_camera_extrinsics" % camera_name] = 
obs.misc[ + "%s_camera_extrinsics" % camera_name + ] + obs_dict["%s_camera_intrinsics" % camera_name] = obs.misc[ + "%s_camera_intrinsics" % camera_name + ] + + # add timestep to low_dim_state + time = (1.0 - (t / float(episode_length - 1))) * 2.0 - 1.0 + + if "low_dim_state" in obs_dict: + obs_dict["low_dim_state"] = np.concatenate( + [obs_dict["low_dim_state"], [time]] + ).astype(np.float32) + else: + obs_dict["right_low_dim_state"] = np.concatenate( + [obs_dict["right_low_dim_state"], [time]] + ).astype(np.float32) + obs_dict["left_low_dim_state"] = np.concatenate( + [obs_dict["left_low_dim_state"], [time]] + ).astype(np.float32) + + obs.right.gripper_matrix = right_grip_mat + obs.right.joint_positions = right_joint_pos + obs.right.gripper_pose = right_grip_pose + obs.left.gripper_matrix = left_grip_mat + obs.left.joint_positions = left_joint_pos + obs.left.gripper_pose = left_grip_pose + + return obs_dict + + +def create_obs_config( + camera_names: List[str], + camera_resolution: List[int], + method_name: str, + robot_name: str = "bimanual", +): + unused_cams = CameraConfig() + unused_cams.set_all(False) + used_cams = CameraConfig( + rgb=True, + point_cloud=True, + mask=False, + depth=False, + image_size=camera_resolution, + render_mode=RenderMode.OPENGL, + ) + + camera_configs = {camera_name: used_cams for camera_name in camera_names} + + # Some of these obs are only used for keypoint detection. 
+ obs_config = ObservationConfig( + camera_configs=camera_configs, + joint_forces=False, + joint_positions=True, + joint_velocities=True, + task_low_dim_state=False, + gripper_touch_forces=False, + gripper_pose=True, + gripper_open=True, + gripper_matrix=True, + gripper_joint_positions=True, + robot_name=robot_name, + ) + return obs_config diff --git a/external/peract_bimanual/helpers/optim/__init__.py b/external/peract_bimanual/helpers/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/external/peract_bimanual/helpers/optim/lamb.py b/external/peract_bimanual/helpers/optim/lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..30cf6fc4a98c7b2fd5478a36fbb569e4fc505952 --- /dev/null +++ b/external/peract_bimanual/helpers/optim/lamb.py @@ -0,0 +1,129 @@ +"""Lamb optimizer.""" + +# LAMB optimizer used as is. +# Source: https://github.com/cybertronai/pytorch-lamb +# License: https://github.com/cybertronai/pytorch-lamb/blob/master/LICENSE + +import collections +import math + +import torch +from torch.optim import Optimizer + + +# def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int): +# """Log a histogram of trust ratio scalars in across layers.""" +# results = collections.defaultdict(list) +# for group in optimizer.param_groups: +# for p in group['params']: +# state = optimizer.state[p] +# for i in ('weight_norm', 'adam_norm', 'trust_ratio'): +# if i in state: +# results[i].append(state[i]) +# +# for k, v in results.items(): +# event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count) + + +class Lamb(Optimizer): + r"""Implements Lamb algorithm. + It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 
+ Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + adam (bool, optional): always use trust ratio = 1, which turns this into + Adam. Useful for comparison purposes. + .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + """ + + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0, adam=False + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.adam = adam + super(Lamb, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "Lamb does not support sparse gradients, consider SparseAdam instad." 
+ ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + # m_t + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + # v_t + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Paper v3 does not use debiasing. + # bias_correction1 = 1 - beta1 ** state['step'] + # bias_correction2 = 1 - beta2 ** state['step'] + # Apply bias to lr to avoid broadcast. + step_size = group[ + "lr" + ] # * math.sqrt(bias_correction2) / bias_correction1 + + weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) + + adam_step = exp_avg / exp_avg_sq.sqrt().add(group["eps"]) + if group["weight_decay"] != 0: + adam_step.add_(p.data, alpha=group["weight_decay"]) + + adam_norm = adam_step.pow(2).sum().sqrt() + if weight_norm == 0 or adam_norm == 0: + trust_ratio = 1 + else: + trust_ratio = weight_norm / adam_norm + state["weight_norm"] = weight_norm + state["adam_norm"] = adam_norm + state["trust_ratio"] = trust_ratio + if self.adam: + trust_ratio = 1 + + p.data.add_(adam_step, alpha=-step_size * trust_ratio) + + return loss diff --git a/external/peract_bimanual/helpers/preprocess_agent.py b/external/peract_bimanual/helpers/preprocess_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..031f198d5454556e0991553c7bb587c58903825a --- /dev/null +++ b/external/peract_bimanual/helpers/preprocess_agent.py @@ -0,0 +1,136 @@ +from typing import List + +import torch +import torchvision.transforms as transforms + +from yarr.agents.agent import ( + Agent, + Summary, + ActResult, + ScalarSummary, + HistogramSummary, + ImageSummary, +) + + +class 
PreprocessAgent(Agent): + def __init__( + self, pose_agent: Agent, norm_rgb: bool = True, norm_type: str = "zero_mean" + ): + self._pose_agent = pose_agent + self._norm_rgb = norm_rgb + self._norm_type = norm_type + + def build(self, training: bool, device: torch.device = None): + self._pose_agent.build(training, device) + + def _norm_rgb_(self, x): + if self._norm_type == "zero_mean": + return (x.float() / 255.0) * 2.0 - 1.0 + elif self._norm_type == "imagenet": + # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + # std=[0.229, 0.224, 0.225]) + # return normalize(x) + return x.float() / 255.0 + else: + raise NotImplementedError + + def update(self, step: int, replay_sample: dict) -> dict: + # Samples are (B, N, ...) where N is number of buffers/tasks. This is a single task setup, so 0 index. + replay_sample = { + k: v[:, 0] if len(v.shape) > 2 and v.shape[1] == 1 else v + for k, v in replay_sample.items() + } + for k, v in replay_sample.items(): + if self._norm_rgb and "rgb" in k: + replay_sample[k] = self._norm_rgb_(v) + else: + replay_sample[k] = v.float() + self._replay_sample = replay_sample + return self._pose_agent.update(step, replay_sample) + + def act(self, step: int, observation: dict, deterministic=False) -> ActResult: + # observation = {k: torch.tensor(v) for k, v in observation.items()} + for k, v in observation.items(): + if self._norm_rgb and "rgb" in k: + observation[k] = self._norm_rgb_(v) + else: + observation[k] = v.float() + act_res = self._pose_agent.act(step, observation, deterministic) + act_res.replay_elements.update({"demo": False}) + return act_res + + def update_summaries(self) -> List[Summary]: + prefix = "inputs" + demo_f = self._replay_sample["demo"].float() + demo_proportion = demo_f.mean() + tile = lambda x: torch.squeeze(torch.cat(x.split(1, dim=1), dim=-1), dim=1) + sums = [ + ScalarSummary("%s/demo_proportion" % prefix, demo_proportion), + ScalarSummary( + "%s/timeouts" % prefix, 
self._replay_sample["timeout"].float().mean() + ), + ] + + for robot_prefix in ["", "right_", "left_"]: + if not f"{robot_prefix}low_dim_state" in self._replay_sample.keys(): + continue + + sums.extend( + [ + HistogramSummary( + f"{prefix}/{robot_prefix}low_dim_state", + self._replay_sample[f"{robot_prefix}low_dim_state"], + ), + HistogramSummary( + f"{prefix}/{robot_prefix}low_dim_state_tp1", + self._replay_sample[f"{robot_prefix}low_dim_state_tp1"], + ), + ScalarSummary( + f"{prefix}/{robot_prefix}low_dim_state_mean", + self._replay_sample[f"{robot_prefix}low_dim_state"].mean(), + ), + ScalarSummary( + f"{prefix}/{robot_prefix}low_dim_state_min", + self._replay_sample[f"{robot_prefix}low_dim_state"].min(), + ), + ScalarSummary( + f"{prefix}/{robot_prefix}low_dim_state_max", + self._replay_sample[f"{robot_prefix}low_dim_state"].max(), + ), + ] + ) + + for k, v in self._replay_sample.items(): + if "rgb" in k or "point_cloud" in k: + if "rgb" in k: + # Convert back to 0 - 1 + v = (v + 1.0) / 2.0 + sums.append( + ImageSummary( + "%s/%s" % (prefix, k), tile(v) if len(v.shape) > 4 else v + ) + ) + + if "sampling_probabilities" in self._replay_sample: + sums.extend( + [ + HistogramSummary( + "replay/priority", self._replay_sample["sampling_probabilities"] + ), + ] + ) + sums.extend(self._pose_agent.update_summaries()) + return sums + + def act_summaries(self) -> List[Summary]: + return self._pose_agent.act_summaries() + + def load_weights(self, savedir: str): + self._pose_agent.load_weights(savedir) + + def save_weights(self, savedir: str): + self._pose_agent.save_weights(savedir) + + def reset(self) -> None: + self._pose_agent.reset() diff --git a/external/peract_bimanual/helpers/utils.py b/external/peract_bimanual/helpers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1c786f7766eff42287bbe73bf1bfe34a57c92b62 --- /dev/null +++ b/external/peract_bimanual/helpers/utils.py @@ -0,0 +1,346 @@ +import numpy as np +import pyrender +import torch 
+import trimesh +from pyrender.trackball import Trackball +from rlbench.backend.const import DEPTH_SCALE +from scipy.spatial.transform import Rotation +from rlbench.backend.observation import Observation +from rlbench import CameraConfig, ObservationConfig +from pyrep.const import RenderMode +from typing import List + + +SCALE_FACTOR = DEPTH_SCALE +DEFAULT_SCENE_SCALE = 2.0 + + +def loss_weights(replay_sample, beta=1.0): + loss_weights = 1.0 + if "sampling_probabilities" in replay_sample: + probs = replay_sample["sampling_probabilities"] + loss_weights = 1.0 / torch.sqrt(probs + 1e-10) + loss_weights = (loss_weights / torch.max(loss_weights)) ** beta + return loss_weights + + +def soft_updates(net, target_net, tau): + for param, target_param in zip(net.parameters(), target_net.parameters()): + target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) + + +def stack_on_channel(x): + # expect (B, T, C, ...) + return torch.cat(torch.split(x, 1, dim=1), dim=2).squeeze(1) + + +def normalize_quaternion(quat): + return np.array(quat) / np.linalg.norm(quat, axis=-1, keepdims=True) + + +def correct_rotation_instability(disc, resolution): + # q1 = discrete_euler_to_quaternion(disc, resolution) + # q2 = discrete_euler_to_quaternion(quaternion_to_discrete_euler(q1, resolution), resolution) + # + # d2 = quaternion_to_discrete_euler(q2, resolution) + # + # # choose the smallest change + # if np.any(disc != d2): + # if np.sum(disc) < np.sum(d2): + # return disc + # else: + # return d2 + return disc + + +def check_gimbal_lock(pred_rot_and_grip, gt_rot_and_grip, resolution): + pred_rot_and_grip_np = pred_rot_and_grip.detach().cpu().numpy() + gt_rot_and_grip_np = gt_rot_and_grip.detach().cpu().numpy() + + pred_rot = discrete_euler_to_quaternion(pred_rot_and_grip_np[:, :3], resolution) + gt_rot = discrete_euler_to_quaternion(gt_rot_and_grip_np[:, :3], resolution) + gimbal_lock_matches = [ + np.all(np.abs(pred_rot[i] - gt_rot[i]) < 1e-10) + and 
np.any(pred_rot_and_grip_np[i, :3] != gt_rot_and_grip_np[i, :3]) + for i in range(pred_rot.shape[0]) + ] + return 0 + + +def quaternion_to_discrete_euler(quaternion, resolution): + euler = Rotation.from_quat(quaternion).as_euler("xyz", degrees=True) + 180 + assert np.min(euler) >= 0 and np.max(euler) <= 360 + disc = np.around((euler / resolution)).astype(int) + disc[disc == int(360 / resolution)] = 0 + return disc + + +def discrete_euler_to_quaternion(discrete_euler, resolution): + euluer = (discrete_euler * resolution) - 180 + return Rotation.from_euler("xyz", euluer, degrees=True).as_quat() + + +def point_to_voxel_index( + point: np.ndarray, voxel_size: np.ndarray, coord_bounds: np.ndarray +): + bb_mins = np.array(coord_bounds[0:3]) + bb_maxs = np.array(coord_bounds[3:]) + dims_m_one = np.array([voxel_size] * 3) - 1 + bb_ranges = bb_maxs - bb_mins + res = bb_ranges / (np.array([voxel_size] * 3) + 1e-12) + voxel_indicy = np.minimum( + np.floor((point - bb_mins) / (res + 1e-12)).astype(np.int32), dims_m_one + ) + return voxel_indicy + + +def voxel_index_to_point( + voxel_index: torch.Tensor, voxel_size: int, coord_bounds: np.ndarray +): + res = (coord_bounds[:, 3:] - coord_bounds[:, :3]) / voxel_size + points = (voxel_index * res) + coord_bounds[:, :3] + return points + + +def point_to_pixel_index( + point: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray +): + point = np.array([point[0], point[1], point[2], 1]) + world_to_cam = np.linalg.inv(extrinsics) + point_in_cam_frame = world_to_cam.dot(point) + px, py, pz = point_in_cam_frame[:3] + px = 2 * intrinsics[0, 2] - int(-intrinsics[0, 0] * (px / pz) + intrinsics[0, 2]) + py = 2 * intrinsics[1, 2] - int(-intrinsics[1, 1] * (py / pz) + intrinsics[1, 2]) + return px, py + + +def _compute_initial_camera_pose(scene): + # Adapted from: + # https://github.com/mmatl/pyrender/blob/master/pyrender/viewer.py#L1032 + centroid = scene.centroid + scale = scene.scale + if scale == 0.0: + scale = DEFAULT_SCENE_SCALE + 
s2 = 1.0 / np.sqrt(2.0) + cp = np.eye(4) + cp[:3, :3] = np.array([[0.0, -s2, s2], [1.0, 0.0, 0.0], [0.0, s2, s2]]) + hfov = np.pi / 6.0 + dist = scale / (2.0 * np.tan(hfov)) + cp[:3, 3] = dist * np.array([1.0, 0.0, 1.0]) + centroid + return cp + + +def _from_trimesh_scene(trimesh_scene, bg_color=None, ambient_light=None): + # convert trimesh geometries to pyrender geometries + geometries = { + name: pyrender.Mesh.from_trimesh(geom, smooth=False) + for name, geom in trimesh_scene.geometry.items() + } + # create the pyrender scene object + scene_pr = pyrender.Scene(bg_color=bg_color, ambient_light=ambient_light) + # add every node with geometry to the pyrender scene + for node in trimesh_scene.graph.nodes_geometry: + pose, geom_name = trimesh_scene.graph[node] + scene_pr.add(geometries[geom_name], pose=pose) + return scene_pr + + +def _create_bounding_box(scene, voxel_size, res): + l = voxel_size * res + T = np.eye(4) + w = 0.01 + for trans in [[0, 0, l / 2], [0, l, l / 2], [l, l, l / 2], [l, 0, l / 2]]: + T[:3, 3] = np.array(trans) - voxel_size / 2 + scene.add_geometry( + trimesh.creation.box([w, w, l], T, face_colors=[0, 0, 0, 255]) + ) + for trans in [[l / 2, 0, 0], [l / 2, 0, l], [l / 2, l, 0], [l / 2, l, l]]: + T[:3, 3] = np.array(trans) - voxel_size / 2 + scene.add_geometry( + trimesh.creation.box([l, w, w], T, face_colors=[0, 0, 0, 255]) + ) + for trans in [[0, l / 2, 0], [0, l / 2, l], [l, l / 2, 0], [l, l / 2, l]]: + T[:3, 3] = np.array(trans) - voxel_size / 2 + scene.add_geometry( + trimesh.creation.box([w, l, w], T, face_colors=[0, 0, 0, 255]) + ) + + +def create_voxel_scene( + voxel_grid: np.ndarray, + q_attention: np.ndarray = None, + highlight_coordinate: np.ndarray = None, + highlight_gt_coordinate: np.ndarray = None, + highlight_alpha: float = 1.0, + voxel_size: float = 0.1, + show_bb: bool = False, + alpha: float = 0.5, +): + _, d, h, w = voxel_grid.shape + v = voxel_grid.transpose((1, 2, 3, 0)) + occupancy = v[:, :, :, -1] != 0 + alpha = 
np.expand_dims(np.full_like(occupancy, alpha, dtype=np.float32), -1) + rgb = np.concatenate([(v[:, :, :, 3:6] + 1) / 2.0, alpha], axis=-1) + + if q_attention is not None: + q = np.max(q_attention, 0) + q = q / np.max(q) + show_q = q > 0.75 + occupancy = (show_q + occupancy).astype(bool) + q = np.expand_dims(q - 0.5, -1) # Max q can be is 0.9 + q_rgb = np.concatenate( + [q, np.zeros_like(q), np.zeros_like(q), np.clip(q, 0, 1)], axis=-1 + ) + rgb = np.where(np.expand_dims(show_q, -1), q_rgb, rgb) + + if highlight_coordinate is not None: + x, y, z = highlight_coordinate + occupancy[x, y, z] = True + rgb[x, y, z] = [1.0, 0.0, 0.0, highlight_alpha] + + if highlight_gt_coordinate is not None: + x, y, z = highlight_gt_coordinate + occupancy[x, y, z] = True + rgb[x, y, z] = [0.0, 0.0, 1.0, highlight_alpha] + + transform = trimesh.transformations.scale_and_translate( + scale=voxel_size, translate=(0.0, 0.0, 0.0) + ) + trimesh_voxel_grid = trimesh.voxel.VoxelGrid( + encoding=occupancy, transform=transform + ) + geometry = trimesh_voxel_grid.as_boxes(colors=rgb) + scene = trimesh.Scene() + scene.add_geometry(geometry) + if show_bb: + assert d == h == w + _create_bounding_box(scene, voxel_size, d) + return scene + + +def visualise_voxel( + voxel_grid: np.ndarray, + q_attention: np.ndarray = None, + highlight_coordinate: np.ndarray = None, + highlight_gt_coordinate: np.ndarray = None, + highlight_alpha: float = 1.0, + rotation_amount: float = 0.0, + show: bool = False, + voxel_size: float = 0.1, + offscreen_renderer: pyrender.OffscreenRenderer = None, + show_bb: bool = False, + alpha: float = 0.5, +): + scene = create_voxel_scene( + voxel_grid, + q_attention, + highlight_coordinate, + highlight_gt_coordinate, + highlight_alpha, + voxel_size, + show_bb, + alpha, + ) + if show: + scene.show() + else: + r = offscreen_renderer or pyrender.OffscreenRenderer( + viewport_width=640, viewport_height=480, point_size=1.0 + ) + s = _from_trimesh_scene( + scene, ambient_light=[0.8, 0.8, 
0.8], bg_color=[1.0, 1.0, 1.0] + ) + cam = pyrender.PerspectiveCamera( + yfov=np.pi / 4.0, aspectRatio=r.viewport_width / r.viewport_height + ) + p = _compute_initial_camera_pose(s) + t = Trackball(p, (r.viewport_width, r.viewport_height), s.scale, s.centroid) + t.rotate(rotation_amount, np.array([0.0, 0.0, 1.0])) + s.add(cam, pose=t.pose) + color, depth = r.render(s) + return color.copy() + + +def preprocess(img, dist="transporter"): + """Pre-process input (subtract mean, divide by std).""" + + transporter_color_mean = [0.18877631, 0.18877631, 0.18877631] + transporter_color_std = [0.07276466, 0.07276466, 0.07276466] + transporter_depth_mean = 0.00509261 + transporter_depth_std = 0.00903967 + + franka_color_mean = [0.622291933, 0.628313992, 0.623031488] + franka_color_std = [0.168154213, 0.17626014, 0.184527364] + franka_depth_mean = 0.872146842 + franka_depth_std = 0.195743116 + + clip_color_mean = [0.48145466, 0.4578275, 0.40821073] + clip_color_std = [0.26862954, 0.26130258, 0.27577711] + + # choose distribution + if dist == "clip": + color_mean = clip_color_mean + color_std = clip_color_std + elif dist == "franka": + color_mean = franka_color_mean + color_std = franka_color_std + else: + color_mean = transporter_color_mean + color_std = transporter_color_std + + if dist == "franka": + depth_mean = franka_depth_mean + depth_std = franka_depth_std + else: + depth_mean = transporter_depth_mean + depth_std = transporter_depth_std + + # convert to pytorch tensor (if required) + if type(img) == torch.Tensor: + + def cast_shape(stat, img): + tensor = torch.from_numpy(np.array(stat)).to( + device=img.device, dtype=img.dtype + ) + tensor = tensor.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + tensor = tensor.repeat(img.shape[0], 1, img.shape[-2], img.shape[-1]) + return tensor + + color_mean = cast_shape(color_mean, img) + color_std = cast_shape(color_std, img) + depth_mean = cast_shape(depth_mean, img) + depth_std = cast_shape(depth_std, img) + + # normalize + img = 
img.clone() + img[:, :3, :, :] = (img[:, :3, :, :] / 255 - color_mean) / color_std + img[:, 3:, :, :] = (img[:, 3:, :, :] - depth_mean) / depth_std + else: + # normalize + img[:, :, :3] = (img[:, :, :3] / 255 - color_mean) / color_std + img[:, :, 3:] = (img[:, :, 3:] - depth_mean) / depth_std + return img + + +def rand_dist(size, min=-1.0, max=1.0): + return (max - min) * torch.rand(size) + min + + +def rand_discrete(size, min=0, max=1): + if min == max: + return torch.zeros(size) + return torch.randint(min, max + 1, size) + + +def split_list(lst, n): + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +def get_device(gpu): + if gpu is not None and gpu >= 0 and torch.cuda.is_available(): + device = torch.device("cuda:%d" % gpu) + torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True + else: + device = torch.device("cpu") + return device diff --git a/external/peract_bimanual/scripts/install_conda.sh b/external/peract_bimanual/scripts/install_conda.sh new file mode 100644 index 0000000000000000000000000000000000000000..db474f2f7839248ef593261d0aa65899620b6f32 --- /dev/null +++ b/external/peract_bimanual/scripts/install_conda.sh @@ -0,0 +1,22 @@ +#!/bin/bash -exu + +# install conda + +sudo apt install curl + + +TEMP_DIR=$(mktemp --tmpdir -d miniconda_XXXXXXXXXX) +cd $TEMP_DIR + +curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh +chmod +x Miniconda3-latest-Linux-x86_64.sh +./Miniconda3-latest-Linux-x86_64.sh + + +SHELL_NAME=`basename $SHELL` +eval "$($HOME/miniconda3/bin/conda shell.${SHELL_NAME} hook)" + +conda init ${SHELL_NAME} +conda install mamba -c conda-forge +conda config --set auto_activate_base false + diff --git a/external/peract_bimanual/scripts/install_dependencies.sh b/external/peract_bimanual/scripts/install_dependencies.sh new file mode 100644 index 0000000000000000000000000000000000000000..964dceb5a874f46485277708c69c46867742bae8 --- /dev/null +++ 
b/external/peract_bimanual/scripts/install_dependencies.sh @@ -0,0 +1,87 @@ +#!/bin/bash + + +# edit this line if you want to install the dependencies to another directory + +WORKSPACE_DIR=${HOME}/code +ENVIRONMENT_NAME=rlbench + +basedir=$(dirname $0) +basedir=$(readlink -f $basedir) + + +if ! [ -x "$(command -v curl)" ]; then + echo "Unable to find curl. installing." + sudo apt install curl +fi + +if ! [ -x "$(command -v git)" ]; then + echo "Unable to find git. installing." + sudo apt install git +fi + +if ! [ -x "$(command -v conda)" ]; then + echo "Unable to find conda" + exit 1 +fi + +conda create -n ${ENVIRONMENT_NAME} python=3.8 +mamba install -n ${ENVIRONMENT_NAME} pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia + + +export COPPELIASIM_ROOT=${WORKSPACE_DIR}/coppelia_sim +mkdir -p $COPPELIASIM_ROOT + +TEMP_DIR=$(mktemp --tmpdir -d coppelia_XXXXXXXXXX) +cd $TEMP_DIR + +curl -L -O https://www.coppeliarobotics.com/files/V4_1_0/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz +tar -xvf CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz -C $COPPELIASIM_ROOT --strip-components 1 +rm -rf CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz + +CONDA_PREFIX=$(conda info --envs | grep -e "^${ENVIRONMENT_NAME}\ " | awk '{print $2}') +mkdir -p ${CONDA_PREFIX}/etc/conda/activate.d/ +cat > ${CONDA_PREFIX}/etc/conda/activate.d/coppelia_sim.sh < 100: + raise Exception("Failing to perturb action and keep it within bounds.") + + # sample translation perturbation with specified range + trans_range = (bounds[:, 3:] - bounds[:, :3]) * trans_aug_range.to( + device=device + ) + trans_shift = trans_range * utils.rand_dist((bs, 3)).to(device=device) + trans_shift_4x4 = identity_4x4.detach().clone() + trans_shift_4x4[:, 0:3, 3] = trans_shift + + # sample rotation perturbation at specified resolution and range + roll_aug_steps = int(rot_aug_range[0] // rot_aug_resolution) + pitch_aug_steps = int(rot_aug_range[1] // rot_aug_resolution) + yaw_aug_steps = int(rot_aug_range[2] // 
rot_aug_resolution) + + roll = utils.rand_discrete( + (bs, 1), min=-roll_aug_steps, max=roll_aug_steps + ) * np.deg2rad(rot_aug_resolution) + pitch = utils.rand_discrete( + (bs, 1), min=-pitch_aug_steps, max=pitch_aug_steps + ) * np.deg2rad(rot_aug_resolution) + yaw = utils.rand_discrete( + (bs, 1), min=-yaw_aug_steps, max=yaw_aug_steps + ) * np.deg2rad(rot_aug_resolution) + rot_shift_3x3 = torch3d_tf.euler_angles_to_matrix( + torch.cat((roll, pitch, yaw), dim=1), "XYZ" + ) + rot_shift_4x4 = identity_4x4.detach().clone() + rot_shift_4x4[:, :3, :3] = rot_shift_3x3 + + # rotate then translate the 4x4 keyframe action + right_perturbed_action_gripper_4x4 = torch.bmm( + right_action_gripper_4x4, rot_shift_4x4 + ) + right_perturbed_action_gripper_4x4[:, 0:3, 3] += trans_shift + + # convert transformation matrix to translation + quaternion + right_perturbed_action_trans = ( + right_perturbed_action_gripper_4x4[:, 0:3, 3].cpu().numpy() + ) + right_perturbed_action_quat_wxyz = torch3d_tf.matrix_to_quaternion( + right_perturbed_action_gripper_4x4[:, :3, :3] + ) + right_perturbed_action_quat_xyzw = ( + torch.cat( + [ + right_perturbed_action_quat_wxyz[:, 1:], + right_perturbed_action_quat_wxyz[:, 0].unsqueeze(1), + ], + dim=1, + ) + .cpu() + .numpy() + ) + + # rotate then translate the 4x4 keyframe action + left_perturbed_action_gripper_4x4 = torch.bmm( + left_action_gripper_4x4, rot_shift_4x4 + ) + left_perturbed_action_gripper_4x4[:, 0:3, 3] += trans_shift + + # convert transformation matrix to translation + quaternion + left_perturbed_action_trans = ( + left_perturbed_action_gripper_4x4[:, 0:3, 3].cpu().numpy() + ) + left_perturbed_action_quat_wxyz = torch3d_tf.matrix_to_quaternion( + left_perturbed_action_gripper_4x4[:, :3, :3] + ) + left_perturbed_action_quat_xyzw = ( + torch.cat( + [ + left_perturbed_action_quat_wxyz[:, 1:], + left_perturbed_action_quat_wxyz[:, 0].unsqueeze(1), + ], + dim=1, + ) + .cpu() + .numpy() + ) + + # discretize perturbed translation and rotation 
+ # TODO(mohit): do this in torch without any numpy. + right_trans_indicies, right_rot_grip_indicies = [], [] + left_trans_indicies, left_rot_grip_indicies = [], [] + for b in range(bs): + bounds_idx = b if layer > 0 else 0 + bounds_np = bounds[bounds_idx].cpu().numpy() + + right_trans_idx = utils.point_to_voxel_index( + right_perturbed_action_trans[b], voxel_size, bounds_np + ) + right_trans_indicies.append(right_trans_idx.tolist()) + + right_quat = right_perturbed_action_quat_xyzw[b] + right_quat = utils.normalize_quaternion(right_perturbed_action_quat_xyzw[b]) + if right_quat[-1] < 0: + right_quat = -right_quat + right_disc_rot = utils.quaternion_to_discrete_euler( + right_quat, rot_resolution + ) + right_rot_grip_indicies.append( + right_disc_rot.tolist() + + [int(right_action_rot_grip[b, 3].cpu().numpy())] + ) + + left_trans_idx = utils.point_to_voxel_index( + left_perturbed_action_trans[b], voxel_size, bounds_np + ) + left_trans_indicies.append(left_trans_idx.tolist()) + + left_quat = left_perturbed_action_quat_xyzw[b] + left_quat = utils.normalize_quaternion(left_perturbed_action_quat_xyzw[b]) + if left_quat[-1] < 0: + left_quat = -left_quat + left_disc_rot = utils.quaternion_to_discrete_euler( + left_quat, rot_resolution + ) + left_rot_grip_indicies.append( + left_disc_rot.tolist() + [int(left_action_rot_grip[b, 3].cpu().numpy())] + ) + + # if the perturbed action is out of bounds, + # the discretized perturb_trans should have invalid indicies + right_perturbed_trans = torch.from_numpy(np.array(right_trans_indicies)).to( + device=device + ) + right_perturbed_rot_grip = torch.from_numpy( + np.array(right_rot_grip_indicies) + ).to(device=device) + + left_perturbed_trans = torch.from_numpy(np.array(left_trans_indicies)).to( + device=device + ) + left_perturbed_rot_grip = torch.from_numpy(np.array(left_rot_grip_indicies)).to( + device=device + ) + + right_action_trans = right_perturbed_trans + right_action_rot_grip = right_perturbed_rot_grip + + 
def apply_se3_augmentation(
    pcd,
    action_gripper_pose,
    action_trans,
    action_rot_grip,
    bounds,
    layer,
    trans_aug_range,
    rot_aug_range,
    rot_aug_resolution,
    voxel_size,
    rot_resolution,
    device,
):
    """Apply SE3 augmentation to point clouds and actions.

    A random translation and a random discretized rotation are sampled and
    applied to the keyframe gripper pose. Sampling is repeated until no
    discretized translation index is negative; the accepted perturbation is
    then applied to the point clouds through ``perturb_se3``.

    :param pcd: list of point clouds [[bs, 3, H, W], ...] for N cameras
    :param action_gripper_pose: 6-DoF pose of keyframe action [bs, 7],
        laid out as translation (3) followed by an xyzw quaternion (4)
    :param action_trans: discretized translation action [bs, 3]
    :param action_rot_grip: discretized rotation and gripper action [bs, 4];
        column 3 is copied unchanged into the returned rotation/gripper action
    :param bounds: metric scene bounds of voxel grid [bs, 6]
    :param layer: voxelization layer (original note: always 1 for PerAct);
        when > 0 each batch item is discretized with its own row of
        ``bounds``, otherwise row 0 is used for every item
    :param trans_aug_range: range of translation augmentation [x_range, y_range, z_range],
        as a fraction of the bounds extent along each axis
    :param rot_aug_range: range of rotation augmentation [x_range, y_range, z_range]
    :param rot_aug_resolution: degree increments for discretized augmentation rotations
    :param voxel_size: voxelization resolution
    :param rot_resolution: degree increments for discretized rotations
    :param device: torch device
    :return: perturbed action_trans, action_rot_grip, pcd
    :raises Exception: if no in-bounds perturbation is found in 100 attempts
    """

    # batch size
    bs = pcd[0].shape[0]

    # identity matrix
    identity_4x4 = torch.eye(4).unsqueeze(0).repeat(bs, 1, 1).to(device=device)

    # 4x4 matrix of keyframe action gripper pose
    # (the pose stores the quaternion as xyzw; reorder it to wxyz first)
    action_gripper_trans = action_gripper_pose[:, :3]
    action_gripper_quat_wxyz = torch.cat(
        (action_gripper_pose[:, 6].unsqueeze(1), action_gripper_pose[:, 3:6]), dim=1
    )
    action_gripper_rot = torch3d_tf.quaternion_to_matrix(action_gripper_quat_wxyz)
    action_gripper_4x4 = identity_4x4.detach().clone()
    action_gripper_4x4[:, :3, :3] = action_gripper_rot
    action_gripper_4x4[:, 0:3, 3] = action_gripper_trans

    # start from an invalid (negative) state so the sampling loop is entered
    perturbed_trans = torch.full_like(action_trans, -1.0)
    perturbed_rot_grip = torch.full_like(action_rot_grip, -1.0)

    # quantities that do not change between attempts are computed once
    trans_range = (bounds[:, 3:] - bounds[:, :3]) * trans_aug_range.to(
        device=device
    )
    roll_aug_steps = int(rot_aug_range[0] // rot_aug_resolution)
    pitch_aug_steps = int(rot_aug_range[1] // rot_aug_resolution)
    yaw_aug_steps = int(rot_aug_range[2] // rot_aug_resolution)
    rot_step_rad = np.deg2rad(rot_aug_resolution)

    # single host copies instead of one device-to-host transfer per batch item
    bounds_np_all = bounds.cpu().numpy()
    grip_np = action_rot_grip[:, 3].cpu().numpy()

    # perturb the action, check if it is within bounds, if not, try another perturbation
    perturb_attempts = 0
    while torch.any(perturbed_trans < 0):
        # might take some repeated attempts to find a perturbation that doesn't go out of bounds
        perturb_attempts += 1
        if perturb_attempts > 100:
            raise Exception("Failing to perturb action and keep it within bounds.")

        # sample translation perturbation with specified range
        trans_shift = trans_range * utils.rand_dist((bs, 3)).to(device=device)
        trans_shift_4x4 = identity_4x4.detach().clone()
        trans_shift_4x4[:, 0:3, 3] = trans_shift

        # sample rotation perturbation at specified resolution and range
        roll = (
            utils.rand_discrete((bs, 1), min=-roll_aug_steps, max=roll_aug_steps)
            * rot_step_rad
        )
        pitch = (
            utils.rand_discrete((bs, 1), min=-pitch_aug_steps, max=pitch_aug_steps)
            * rot_step_rad
        )
        yaw = (
            utils.rand_discrete((bs, 1), min=-yaw_aug_steps, max=yaw_aug_steps)
            * rot_step_rad
        )
        rot_shift_3x3 = torch3d_tf.euler_angles_to_matrix(
            torch.cat((roll, pitch, yaw), dim=1), "XYZ"
        )
        rot_shift_4x4 = identity_4x4.detach().clone()
        rot_shift_4x4[:, :3, :3] = rot_shift_3x3

        # rotate then translate the 4x4 keyframe action
        perturbed_action_gripper_4x4 = torch.bmm(action_gripper_4x4, rot_shift_4x4)
        perturbed_action_gripper_4x4[:, 0:3, 3] += trans_shift

        # convert transformation matrix to translation + quaternion
        perturbed_action_trans = perturbed_action_gripper_4x4[:, 0:3, 3].cpu().numpy()
        perturbed_action_quat_wxyz = torch3d_tf.matrix_to_quaternion(
            perturbed_action_gripper_4x4[:, :3, :3]
        )
        perturbed_action_quat_xyzw = (
            torch.cat(
                [
                    perturbed_action_quat_wxyz[:, 1:],
                    perturbed_action_quat_wxyz[:, 0].unsqueeze(1),
                ],
                dim=1,
            )
            .cpu()
            .numpy()
        )

        # discretize perturbed translation and rotation
        # TODO(mohit): do this in torch without any numpy.
        trans_indicies, rot_grip_indicies = [], []
        for b in range(bs):
            bounds_np = bounds_np_all[b if layer > 0 else 0]

            trans_idx = utils.point_to_voxel_index(
                perturbed_action_trans[b], voxel_size, bounds_np
            )
            trans_indicies.append(trans_idx.tolist())

            quat = utils.normalize_quaternion(perturbed_action_quat_xyzw[b])
            # q and -q encode the same rotation; keep the one with w >= 0
            if quat[-1] < 0:
                quat = -quat
            disc_rot = utils.quaternion_to_discrete_euler(quat, rot_resolution)
            rot_grip_indicies.append(disc_rot.tolist() + [int(grip_np[b])])

        # if the perturbed action is out of bounds,
        # the discretized perturb_trans should have invalid indicies
        perturbed_trans = torch.from_numpy(np.array(trans_indicies)).to(device=device)
        perturbed_rot_grip = torch.from_numpy(np.array(rot_grip_indicies)).to(
            device=device
        )

    action_trans = perturbed_trans
    action_rot_grip = perturbed_rot_grip

    # apply perturbation to pointclouds
    pcd = perturb_se3(pcd, trans_shift_4x4, rot_shift_4x4, action_gripper_4x4, bounds)

    return action_trans, action_rot_grip, pcd