Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- external/peract_bimanual/.gitignore +160 -0
- external/peract_bimanual/ARM_LICENSE +196 -0
- external/peract_bimanual/Dockerfile +68 -0
- external/peract_bimanual/INSTALLATION.md +87 -0
- external/peract_bimanual/agents/__init__.py +0 -0
- external/peract_bimanual/agents/act_bc_lang/__init__.py +1 -0
- external/peract_bimanual/agents/act_bc_lang/act_bc_lang_agent.py +381 -0
- external/peract_bimanual/agents/act_bc_lang/act_policy.py +135 -0
- external/peract_bimanual/agents/act_bc_lang/detr/__init__.py +0 -0
- external/peract_bimanual/agents/act_bc_lang/detr/build.py +41 -0
- external/peract_bimanual/agents/act_bc_lang/detr/util/__init__.py +1 -0
- external/peract_bimanual/agents/act_bc_lang/launch_utils.py +456 -0
- external/peract_bimanual/agents/agent_factory.py +111 -0
- external/peract_bimanual/agents/arm/launch_utils.py +441 -0
- external/peract_bimanual/agents/arm/next_best_pose_agent.py +526 -0
- external/peract_bimanual/agents/arm/qattention_agent.py +247 -0
- external/peract_bimanual/agents/baselines/__init__.py +0 -0
- external/peract_bimanual/agents/baselines/bc_lang/__init__.py +1 -0
- external/peract_bimanual/agents/baselines/bc_lang/bc_lang_agent.py +148 -0
- external/peract_bimanual/agents/baselines/bc_lang/launch_utils.py +368 -0
- external/peract_bimanual/agents/baselines/vit_bc_lang/__init__.py +1 -0
- external/peract_bimanual/agents/baselines/vit_bc_lang/launch_utils.py +372 -0
- external/peract_bimanual/agents/baselines/vit_bc_lang/vit_bc_lang_agent.py +148 -0
- external/peract_bimanual/agents/bimanual_peract/__init__.py +1 -0
- external/peract_bimanual/agents/bimanual_peract/launch_utils.py +93 -0
- external/peract_bimanual/agents/bimanual_peract/perceiver_lang_io.py +549 -0
- external/peract_bimanual/agents/bimanual_peract/qattention_peract_bc_agent.py +1063 -0
- external/peract_bimanual/agents/bimanual_peract/qattention_stack_agent.py +202 -0
- external/peract_bimanual/agents/c2farm_lingunet_bc/__init__.py +1 -0
- external/peract_bimanual/agents/c2farm_lingunet_bc/launch_utils.py +519 -0
- external/peract_bimanual/agents/c2farm_lingunet_bc/networks.py +301 -0
- external/peract_bimanual/agents/c2farm_lingunet_bc/qattention_lingunet_bc_agent.py +790 -0
- external/peract_bimanual/agents/c2farm_lingunet_bc/qattention_stack_agent.py +136 -0
- external/peract_bimanual/agents/peract_bc/__init__.py +1 -0
- external/peract_bimanual/agents/peract_bc/launch_utils.py +94 -0
- external/peract_bimanual/agents/peract_bc/perceiver_lang_io.py +426 -0
- external/peract_bimanual/agents/peract_bc/qattention_peract_bc_agent.py +808 -0
- external/peract_bimanual/agents/peract_bc/qattention_stack_agent.py +132 -0
- external/peract_bimanual/agents/replay_utils.py +643 -0
- external/peract_bimanual/agents/rvt/__init__.py +1 -0
- external/peract_bimanual/agents/rvt/launch_utils.py +168 -0
- external/peract_bimanual/conf/config.yaml +52 -0
- external/peract_bimanual/conf/eval.yaml +39 -0
- external/peract_bimanual/conf/hydra/job_logging/custom.yaml +12 -0
- external/peract_bimanual/conf/method/ACT_BC_LANG.yaml +51 -0
- external/peract_bimanual/conf/method/ARM.yaml +24 -0
- external/peract_bimanual/conf/method/BC_LANG.yaml +9 -0
- external/peract_bimanual/conf/method/BIMANUAL_PERACT.yaml +70 -0
- external/peract_bimanual/conf/method/C2FARM_LINGUNET_BC.yaml +40 -0
- external/peract_bimanual/conf/method/PERACT_BC.yaml +68 -0
external/peract_bimanual/.gitignore
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 110 |
+
.pdm.toml
|
| 111 |
+
|
| 112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 113 |
+
__pypackages__/
|
| 114 |
+
|
| 115 |
+
# Celery stuff
|
| 116 |
+
celerybeat-schedule
|
| 117 |
+
celerybeat.pid
|
| 118 |
+
|
| 119 |
+
# SageMath parsed files
|
| 120 |
+
*.sage.py
|
| 121 |
+
|
| 122 |
+
# Environments
|
| 123 |
+
.env
|
| 124 |
+
.venv
|
| 125 |
+
env/
|
| 126 |
+
venv/
|
| 127 |
+
ENV/
|
| 128 |
+
env.bak/
|
| 129 |
+
venv.bak/
|
| 130 |
+
|
| 131 |
+
# Spyder project settings
|
| 132 |
+
.spyderproject
|
| 133 |
+
.spyproject
|
| 134 |
+
|
| 135 |
+
# Rope project settings
|
| 136 |
+
.ropeproject
|
| 137 |
+
|
| 138 |
+
# mkdocs documentation
|
| 139 |
+
/site
|
| 140 |
+
|
| 141 |
+
# mypy
|
| 142 |
+
.mypy_cache/
|
| 143 |
+
.dmypy.json
|
| 144 |
+
dmypy.json
|
| 145 |
+
|
| 146 |
+
# Pyre type checker
|
| 147 |
+
.pyre/
|
| 148 |
+
|
| 149 |
+
# pytype static type analyzer
|
| 150 |
+
.pytype/
|
| 151 |
+
|
| 152 |
+
# Cython debug symbols
|
| 153 |
+
cython_debug/
|
| 154 |
+
|
| 155 |
+
# PyCharm
|
| 156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 160 |
+
#.idea/
|
external/peract_bimanual/ARM_LICENSE
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Q-attention: Enabling Efficient Learning for Vision-based Robotic Manipulation
|
| 2 |
+
|
| 3 |
+
LICENCE AGREEMENT
|
| 4 |
+
|
| 5 |
+
WE (Imperial College of Science, Technology and Medicine, (“Imperial College London”))
|
| 6 |
+
ARE WILLING TO LICENSE THIS SOFTWARE TO YOU (a licensee “You”) ONLY ON THE
|
| 7 |
+
CONDITION THAT YOU ACCEPT ALL OF THE TERMS CONTAINED IN THE
|
| 8 |
+
FOLLOWING AGREEMENT. PLEASE READ THE AGREEMENT CAREFULLY BEFORE
|
| 9 |
+
DOWNLOADING THE SOFTWARE. BY EXERCISING THE OPTION TO DOWNLOAD
|
| 10 |
+
THE SOFTWARE YOU AGREE TO BE BOUND BY THE TERMS OF THE AGREEMENT.
|
| 11 |
+
SOFTWARE LICENCE AGREEMENT (EXCLUDING BSD COMPONENTS)
|
| 12 |
+
|
| 13 |
+
1.This Agreement pertains to a worldwide, non-exclusive, temporary, fully paid-up, royalty
|
| 14 |
+
free, non-transferable, non-sub- licensable licence (the “Licence”) to use the Q-attention
|
| 15 |
+
source code, including any modification, part or derivative (the “Software”).
|
| 16 |
+
Ownership and Licence. Your rights to use and download the Software onto your computer,
|
| 17 |
+
and all other copies that You are authorised to make, are specified in this Agreement.
|
| 18 |
+
However, we (or our licensors) retain all rights, including but not limited to all copyright and
|
| 19 |
+
other intellectual property rights anywhere in the world, in the Software not expressly
|
| 20 |
+
granted to You in this Agreement.
|
| 21 |
+
|
| 22 |
+
2. Permitted use of the Licence:
|
| 23 |
+
|
| 24 |
+
(a) You may download and install the Software onto one computer or server for use in
|
| 25 |
+
accordance with Clause 2(b) of this Agreement provided that You ensure that the Software is
|
| 26 |
+
not accessible by other users unless they have themselves accepted the terms of this licence
|
| 27 |
+
agreement.
|
| 28 |
+
|
| 29 |
+
(b) You may use the Software solely for non-commercial, internal or academic research
|
| 30 |
+
purposes and only in accordance with the terms of this Agreement. You may not use the
|
| 31 |
+
Software for commercial purposes, including but not limited to (1) integration of all or part of
|
| 32 |
+
the source code or the Software into a product for sale or licence by or on behalf of You to
|
| 33 |
+
third parties or (2) use of the Software or any derivative of it for research to develop software
|
| 34 |
+
products for sale or licence to a third party or (3) use of the Software or any derivative of it
|
| 35 |
+
for research to develop non-software products for sale or licence to a third party, or (4) use of
|
| 36 |
+
the Software to provide any service to an external organisation for which payment is
|
| 37 |
+
received.
|
| 38 |
+
|
| 39 |
+
Should You wish to use the Software for commercial purposes, You shall
|
| 40 |
+
email researchcontracts.engineering@imperial.ac.uk .
|
| 41 |
+
|
| 42 |
+
(c) Right to Copy. You may copy the Software for back-up and archival purposes, provided
|
| 43 |
+
that each copy is kept in your possession and provided You reproduce our copyright notice
|
| 44 |
+
(set out in Schedule 1) on each copy.
|
| 45 |
+
|
| 46 |
+
(d) Transfer and sub-licensing. You may not rent, lend, or lease the Software and You may
|
| 47 |
+
not transmit, transfer or sub-license this licence to use the Software or any of your rights or
|
| 48 |
+
obligations under this Agreement to another party.
|
| 49 |
+
|
| 50 |
+
(e) Identity of Licensee. The licence granted herein is personal to You. You shall not permit
|
| 51 |
+
any third party to access, modify or otherwise use the Software nor shall You access modify
|
| 52 |
+
or otherwise use the Software on behalf of any third party. If You wish to obtain a licence for
|
| 53 |
+
mutiple users or a site licence for the Software please contact us
|
| 54 |
+
at researchcontracts.engineering@imperial.ac.uk .
|
| 55 |
+
|
| 56 |
+
(f) Publications and presentations. You may make public, results or data obtained from,
|
| 57 |
+
dependent on or arising from research carried out using the Software, provided that any such
|
| 58 |
+
presentation or publication identifies the Software as the source of the results or the data,
|
| 59 |
+
including the Copyright Notice given in each element of the Software, and stating that the
|
| 60 |
+
Software has been made available for use by You under licence from Imperial College London
|
| 61 |
+
and You provide a copy of any such publication to Imperial College London.
|
| 62 |
+
|
| 63 |
+
3. Prohibited Uses. You may not, without written permission from us
|
| 64 |
+
at researchcontracts.engineering@imperial.ac.uk :
|
| 65 |
+
|
| 66 |
+
(a) Use, copy, modify, merge, or transfer copies of the Software or any documentation
|
| 67 |
+
provided by us which relates to the Software except as provided in this Agreement;
|
| 68 |
+
|
| 69 |
+
(b) Use any back-up or archival copies of the Software (or allow anyone else to use such
|
| 70 |
+
copies) for any purpose other than to replace the original copy in the event it is destroyed or
|
| 71 |
+
becomes defective; or
|
| 72 |
+
|
| 73 |
+
(c) Disassemble, decompile or "unlock", reverse translate, or in any manner decode the
|
| 74 |
+
Software for any reason.
|
| 75 |
+
|
| 76 |
+
4. Warranty Disclaimer
|
| 77 |
+
|
| 78 |
+
(a) Disclaimer. The Software has been developed for research purposes only. You
|
| 79 |
+
acknowledge that we are providing the Software to You under this licence agreement free of
|
| 80 |
+
charge and on condition that the disclaimer set out below shall apply. We do not represent or
|
| 81 |
+
warrant that the Software as to: (i) the quality, accuracy or reliability of the Software; (ii) the
|
| 82 |
+
suitability of the Software for any particular use or for use under any specific conditions; and
|
| 83 |
+
(iii) whether use of the Software will infringe third-party rights.
|
| 84 |
+
You acknowledge that You have reviewed and evaluated the Software to determine that it
|
| 85 |
+
meets your needs and that You assume all responsibility and liability for determining the
|
| 86 |
+
suitability of the Software as fit for your particular purposes and requirements. Subject to
|
| 87 |
+
Clause 4(b), we exclude and expressly disclaim all express and implied representations,
|
| 88 |
+
warranties, conditions and terms not stated herein (including the implied conditions or
|
| 89 |
+
warranties of satisfactory quality, merchantable quality, merchantability and fitness for
|
| 90 |
+
purpose).
|
| 91 |
+
|
| 92 |
+
(b) Savings. Some jurisdictions may imply warranties, conditions or terms or impose
|
| 93 |
+
obligations upon us which cannot, in whole or in part, be excluded, restricted or modified or
|
| 94 |
+
otherwise do not allow the exclusion of implied warranties, conditions or terms, in which
|
| 95 |
+
case the above warranty disclaimer and exclusion will only apply to You to the extent
|
| 96 |
+
permitted in the relevant jurisdiction and does not in any event exclude any implied
|
| 97 |
+
warranties, conditions or terms which may not under applicable law be excluded.
|
| 98 |
+
|
| 99 |
+
(c) Imperial College London disclaims all responsibility for the use which is made of the
|
| 100 |
+
Software and any liability for the outcomes arising from using the Software.
|
| 101 |
+
|
| 102 |
+
5. Limitation of Liability
|
| 103 |
+
|
| 104 |
+
(a) You acknowledge that we are providing the Software to You under this licence agreement
|
| 105 |
+
free of charge and on condition that the limitation of liability set out below shall apply.
|
| 106 |
+
Accordingly, subject to Clause 5(b), we exclude all liability whether in contract, tort,
|
| 107 |
+
negligence or otherwise, in respect of the Software and/or any related documentation
|
| 108 |
+
provided to You by us including, but not limited to, liability for loss or corruption of data,
|
| 109 |
+
loss of contracts, loss of income, loss of profits, loss of cover and any consequential or indirect
|
| 110 |
+
loss or damage of any kind arising out of or in connection with this licence agreement,
|
| 111 |
+
however caused. This exclusion shall apply even if we have been advised of the possibility of
|
| 112 |
+
such loss or damage.
|
| 113 |
+
|
| 114 |
+
(b) You agree to indemnify Imperial College London and hold it harmless from and against
|
| 115 |
+
any and all claims, damages and liabilities asserted by third parties (including claims for
|
| 116 |
+
negligence) which arise directly or indirectly from the use of the Software or any derivative
|
| 117 |
+
of it or the sale of any products based on the Software. You undertake to make no liability
|
| 118 |
+
claim against any employee, student, agent or appointee of Imperial College London, in
|
| 119 |
+
connection with this Licence or the Software.
|
| 120 |
+
|
| 121 |
+
(c) Nothing in this Agreement shall have the effect of excluding or limiting our statutory
|
| 122 |
+
liability.
|
| 123 |
+
|
| 124 |
+
(d) Some jurisdictions do not allow these limitations or exclusions either wholly or in part,
|
| 125 |
+
and, to that extent, they may not apply to you. Nothing in this licence agreement will affect
|
| 126 |
+
your statutory rights or other relevant statutory provisions which cannot be excluded,
|
| 127 |
+
restricted or modified, and its terms and conditions must be read and construed subject to any
|
| 128 |
+
such statutory rights and/or provisions.
|
| 129 |
+
|
| 130 |
+
6. Confidentiality. You agree not to disclose any confidential information provided to You by
|
| 131 |
+
us pursuant to this Agreement to any third party without our prior written consent. The
|
| 132 |
+
obligations in this Clause 6 shall survive the termination of this Agreement for any reason.
|
| 133 |
+
|
| 134 |
+
7. Termination.
|
| 135 |
+
|
| 136 |
+
(a) We may terminate this licence agreement and your right to use the Software at any time
|
| 137 |
+
with immediate effect upon written notice to You.
|
| 138 |
+
|
| 139 |
+
(b) This licence agreement and your right to use the Software automatically terminate if You:
|
| 140 |
+
(i) fail to comply with any provisions of this Agreement; or
|
| 141 |
+
(ii) destroy the copies of the Software in your possession, or voluntarily return the Software
|
| 142 |
+
to us.
|
| 143 |
+
|
| 144 |
+
(c) Upon termination You will destroy all copies of the Software.
|
| 145 |
+
|
| 146 |
+
(d) Otherwise, the restrictions on your rights to use the Software will expire 10 (ten) years
|
| 147 |
+
after first use of the Software under this licence agreement.
|
| 148 |
+
|
| 149 |
+
8. Miscellaneous Provisions.
|
| 150 |
+
|
| 151 |
+
(a) This Agreement will be governed by and construed in accordance with the substantive
|
| 152 |
+
laws of England and Wales whose courts shall have exclusive jurisdiction over all disputes
|
| 153 |
+
which may arise between us.
|
| 154 |
+
|
| 155 |
+
(b) This is the entire agreement between us relating to the Software, and supersedes any prior
|
| 156 |
+
purchase order, communications, advertising or representations concerning the Software.
|
| 157 |
+
|
| 158 |
+
(c) No change or modification of this Agreement will be valid unless it is in writing, and is
|
| 159 |
+
signed by us.
|
| 160 |
+
|
| 161 |
+
(d) The unenforceability or invalidity of any part of this Agreement will not affect the
|
| 162 |
+
enforceability or validity of the remaining parts.
|
| 163 |
+
|
| 164 |
+
BSD Elements of the Software
|
| 165 |
+
|
| 166 |
+
For BSD elements of the Software, the following terms shall apply:
|
| 167 |
+
|
| 168 |
+
Copyright as indicated in the header of the individual element of the Software.
|
| 169 |
+
|
| 170 |
+
All rights reserved.
|
| 171 |
+
|
| 172 |
+
Redistribution and use in source and binary forms, with or without modification, are
|
| 173 |
+
permitted provided that the following conditions are met:
|
| 174 |
+
|
| 175 |
+
1. Redistributions of source code must retain the above copyright notice, this list of
|
| 176 |
+
conditions and the following disclaimer.
|
| 177 |
+
|
| 178 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of
|
| 179 |
+
conditions and the following disclaimer in the documentation and/or other materials
|
| 180 |
+
provided with the distribution.
|
| 181 |
+
|
| 182 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to
|
| 183 |
+
endorse or promote products derived from this software without specific prior written
|
| 184 |
+
permission.
|
| 185 |
+
|
| 186 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
| 187 |
+
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
| 188 |
+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
| 189 |
+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
| 190 |
+
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
| 191 |
+
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
| 192 |
+
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
| 193 |
+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
| 194 |
+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
| 195 |
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 196 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
external/peract_bimanual/Dockerfile
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use the NVIDIA base image for CUDA
|
| 2 |
+
FROM nvcr.io/nvidia/cuda:12.3.2-cudnn9-devel-ubuntu20.04
|
| 3 |
+
|
| 4 |
+
# Set environment variables
|
| 5 |
+
ENV COPPELIASIM_ROOT=${HOME}/code/coppelia_sim
|
| 6 |
+
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT
|
| 7 |
+
ENV QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT
|
| 8 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 9 |
+
ENV TZ=America/Los_Angeles
|
| 10 |
+
ENV CONDA_ALWAYS_YES=true
|
| 11 |
+
ENV FORCE_CUDA=1
|
| 12 |
+
ENV TORCH_CUDA_ARCH_LIST="5.0;5.2;5.3;6.0;6.1;6.2;7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0+PTX"
|
| 13 |
+
|
| 14 |
+
# Create necessary directories
|
| 15 |
+
RUN mkdir -p ${HOME}/code
|
| 16 |
+
|
| 17 |
+
# Install dependencies and essential tools
|
| 18 |
+
RUN apt-get update && apt-get install -y \
|
| 19 |
+
tzdata sudo curl git vim htop tar bzip2 pigz rsync less mlocate \
|
| 20 |
+
build-essential gdb ca-certificates stress sysstat itop \
|
| 21 |
+
xauth xvfb mesa-utils mesa-utils-extra x11-apps \
|
| 22 |
+
xorg xserver-xorg-core libxv1 x11-xserver-utils libxcb-randr0-dev \
|
| 23 |
+
libxrender-dev libxkbcommon-dev libxkbcommon-x11-0 libavcodec-dev \
|
| 24 |
+
libavformat-dev libswscale-dev '^libxcb.*-dev' libx11-xcb-dev \
|
| 25 |
+
libglu1-mesa-dev libxrender-dev libxi-dev libxkbcommon-dev \
|
| 26 |
+
libxkbcommon-x11-dev libegl1-mesa libarchive-dev libarchive13 \
|
| 27 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 28 |
+
|
| 29 |
+
# Install VirtualGL
|
| 30 |
+
RUN TEMP_DIR=$(mktemp -d -p /) && cd $TEMP_DIR && \
|
| 31 |
+
curl -L -o virtualgl.deb https://sourceforge.net/projects/virtualgl/files/3.1/virtualgl_3.1_amd64.deb/download && \
|
| 32 |
+
dpkg -i virtualgl.deb && \
|
| 33 |
+
/opt/VirtualGL/bin/vglserver_config +glx +egl +s +f +t && \
|
| 34 |
+
rm -rf $TEMP_DIR
|
| 35 |
+
|
| 36 |
+
RUN mkdir ${HOME}/.ssh && chmod -R 700 ${HOME}/.ssh
|
| 37 |
+
|
| 38 |
+
RUN ssh-keyscan github.com >> ${HOME}/.ssh/known_hosts
|
| 39 |
+
|
| 40 |
+
RUN curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
|
| 41 |
+
RUN bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda
|
| 42 |
+
RUN export PATH=/opt/conda/bin:${PATH}
|
| 43 |
+
|
| 44 |
+
# Install code and dependencies
|
| 45 |
+
|
| 46 |
+
WORKDIR ${HOME}/code
|
| 47 |
+
|
| 48 |
+
RUN eval "$(/opt/conda/bin/conda shell.bash hook)" && conda init bash
|
| 49 |
+
RUN eval "$(/opt/conda/bin/conda shell.bash hook)" && conda install mamba -c conda-forge
|
| 50 |
+
#RUN conda config --set auto_activate_base false
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
RUN git clone https://github.com/markusgrotz/peract_bimanual.git ${HOME}/code/peract_bimanual
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
RUN eval "$(/opt/conda/bin/conda shell.bash hook)" && ${HOME}/code/peract_bimanual/scripts/install_dependencies.sh
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Activate the environment by default
|
| 60 |
+
RUN echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
|
| 61 |
+
echo "conda activate rlbench" >> ~/.bashrc
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
WORKDIR /root/code/peract_bimanual
|
| 65 |
+
|
| 66 |
+
# Default command
|
| 67 |
+
CMD ["/bin/bash"]
|
| 68 |
+
|
external/peract_bimanual/INSTALLATION.md
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# INSTALLATION
|
| 2 |
+
|
| 3 |
+
To install the dependencies execute the `scripts/install_dependencies.sh`
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
scripts/install_conda.sh # Skip this step if you already have conda installed.
|
| 7 |
+
scripts/install_dependencies.sh
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
Please see the [README](README.md) for a quick start instruction.
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
Alternatively, you can follow the detailed instructions to setup the software from scratch
|
| 14 |
+
|
| 15 |
+
#### 2. PyRep and Coppelia Simulator
|
| 16 |
+
|
| 17 |
+
Follow instructions from my [PyRep fork](https://github.com/markusgrotz/PyRep); reproduced here for convenience:
|
| 18 |
+
|
| 19 |
+
PyRep requires version **4.1** of CoppeliaSim. Download:
|
| 20 |
+
- [Ubuntu 20.04](https://www.coppeliarobotics.com/files/V4_1_0/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz)
|
| 21 |
+
|
| 22 |
+
Once you have downloaded CoppeliaSim, you can pull PyRep from git:
|
| 23 |
+
|
| 24 |
+
```bash
|
| 25 |
+
cd <install_dir>
|
| 26 |
+
git clone https://github.com/markusgrotz/PyRep.git
|
| 27 |
+
cd PyRep
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
Add the following to your *~/.bashrc* file: (__NOTE__: the 'EDIT ME' in the first line)
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
export COPPELIASIM_ROOT=<EDIT ME>/PATH/TO/COPPELIASIM/INSTALL/DIR
|
| 34 |
+
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT
|
| 35 |
+
export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
Remember to source your bashrc (`source ~/.bashrc`) or
|
| 39 |
+
zshrc (`source ~/.zshrc`) after this.
|
| 40 |
+
|
| 41 |
+
**Warning**: CoppeliaSim might cause conflicts with ROS workspaces.
|
| 42 |
+
|
| 43 |
+
Finally install the python library:
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
pip install -e .
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
You should be good to go!
|
| 50 |
+
You could try running one of the examples in the *examples/* folder.
|
| 51 |
+
|
| 52 |
+
#### 3. RLBench
|
| 53 |
+
|
| 54 |
+
PerAct uses my [RLBench fork](https://github.com/markusgrotz/RLBench/tree/peract).
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
cd <install_dir>
|
| 58 |
+
git clone https://github.com/markusgrotz/RLBench.git
|
| 59 |
+
|
| 60 |
+
cd RLBench
|
| 61 |
+
pip install -e .
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
For [running in headless mode](https://github.com/MohitShridhar/RLBench/tree/peract#running-headless), tasks setups, and other issues, please refer to the [official repo](https://github.com/stepjam/RLBench).
|
| 65 |
+
|
| 66 |
+
#### 4. YARR
|
| 67 |
+
|
| 68 |
+
PerAct uses my [YARR fork](https://github.com/markusgrotz/YARR/).
|
| 69 |
+
|
| 70 |
+
```bash
|
| 71 |
+
cd <install_dir>
|
| 72 |
+
git clone https://github.com/markusgrotz/YARR.git
|
| 73 |
+
|
| 74 |
+
cd YARR
|
| 75 |
+
pip install -e .
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
#### RVT baseline
|
| 81 |
+
|
| 82 |
+
pip install git+https://github.com/NVlabs/RVT.git
|
| 83 |
+
pip install -e .
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
external/peract_bimanual/agents/__init__.py
ADDED
|
File without changes
|
external/peract_bimanual/agents/act_bc_lang/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
import agents.act_bc_lang.launch_utils
|
external/peract_bimanual/agents/act_bc_lang/act_bc_lang_agent.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
import pickle
|
| 5 |
+
import os
|
| 6 |
+
from typing import List
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from yarr.agents.agent import Agent, Summary, ActResult, ScalarSummary, HistogramSummary
|
| 13 |
+
|
| 14 |
+
from helpers import utils
|
| 15 |
+
from helpers.utils import stack_on_channel
|
| 16 |
+
|
| 17 |
+
from helpers.clip.core.clip import build_model, load_clip
|
| 18 |
+
|
| 19 |
+
NAME = "ActBCLangAgent"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ActBCLangAgent(Agent):
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
actor_network: nn.Module,
|
| 26 |
+
camera_names: List[str],
|
| 27 |
+
lr: float = 0.01,
|
| 28 |
+
weight_decay: float = 1e-5,
|
| 29 |
+
grad_clip: float = 20.0,
|
| 30 |
+
episode_length: int = 400,
|
| 31 |
+
train_demo_path=None,
|
| 32 |
+
task_name=None,
|
| 33 |
+
):
|
| 34 |
+
self._camera_names = camera_names
|
| 35 |
+
self._actor = actor_network
|
| 36 |
+
self._lr = lr
|
| 37 |
+
self._weight_decay = weight_decay
|
| 38 |
+
self._grad_clip = grad_clip
|
| 39 |
+
self._episode_length = episode_length
|
| 40 |
+
self.train_demo_path = train_demo_path
|
| 41 |
+
self.task_name = task_name
|
| 42 |
+
|
| 43 |
+
def build(self, training: bool, device: torch.device = None):
|
| 44 |
+
if device is None:
|
| 45 |
+
device = torch.device("cpu")
|
| 46 |
+
self._actor = self._actor.to(device).train(training)
|
| 47 |
+
self._actor_optimizer = self._actor.configure_optimizers()
|
| 48 |
+
|
| 49 |
+
self._device = device
|
| 50 |
+
|
| 51 |
+
def reset(self):
|
| 52 |
+
super(ActBCLangAgent, self).reset()
|
| 53 |
+
|
| 54 |
+
self._timestep = 0
|
| 55 |
+
# .. input_dim = input_dim * 2 for bimanual
|
| 56 |
+
self._all_time_actions = torch.zeros(
|
| 57 |
+
[
|
| 58 |
+
self._episode_length,
|
| 59 |
+
self._episode_length + self._actor.model.num_queries,
|
| 60 |
+
self._actor.model.input_dim,
|
| 61 |
+
]
|
| 62 |
+
).to(self._device)
|
| 63 |
+
self._all_actions = None
|
| 64 |
+
|
| 65 |
+
def _grad_step(self, loss, opt, model_params=None, clip=None):
|
| 66 |
+
opt.zero_grad()
|
| 67 |
+
loss.backward()
|
| 68 |
+
if clip is not None and model_params is not None:
|
| 69 |
+
nn.utils.clip_grad_value_(model_params, clip)
|
| 70 |
+
opt.step()
|
| 71 |
+
|
| 72 |
+
@lru_cache()
|
| 73 |
+
def train_stats(self):
|
| 74 |
+
right_joint_positions = []
|
| 75 |
+
left_joint_positions = []
|
| 76 |
+
|
| 77 |
+
right_gripper_positions = []
|
| 78 |
+
left_gripper_positions = []
|
| 79 |
+
|
| 80 |
+
episodes_dir = (
|
| 81 |
+
f"{self.train_demo_path}/{self.task_name}/all_variations/episodes/"
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
for episode in os.listdir(episodes_dir):
|
| 85 |
+
with open(
|
| 86 |
+
os.path.join(episodes_dir, episode, "low_dim_obs.pkl"), "br"
|
| 87 |
+
) as f:
|
| 88 |
+
d = pickle.load(f)
|
| 89 |
+
|
| 90 |
+
for o in d:
|
| 91 |
+
right_joint_positions.append(o.right.joint_positions)
|
| 92 |
+
left_joint_positions.append(o.left.joint_positions)
|
| 93 |
+
|
| 94 |
+
right_gripper_positions.append([o.right.gripper_joint_positions[0]])
|
| 95 |
+
left_gripper_positions.append([o.left.gripper_joint_positions[0]])
|
| 96 |
+
|
| 97 |
+
right_joint_positions = np.asarray(right_joint_positions, dtype=np.float32)
|
| 98 |
+
left_joint_positions = np.asarray(left_joint_positions, dtype=np.float32)
|
| 99 |
+
|
| 100 |
+
right_gripper_positions = np.asarray(right_gripper_positions, dtype=np.float32)
|
| 101 |
+
left_gripper_positions = np.asarray(left_gripper_positions, dtype=np.float32)
|
| 102 |
+
|
| 103 |
+
stats = {
|
| 104 |
+
"right_joints_mean": right_joint_positions.mean(axis=0),
|
| 105 |
+
"right_joints_std": right_joint_positions.std(axis=0),
|
| 106 |
+
"left_joints_mean": left_joint_positions.mean(axis=0),
|
| 107 |
+
"left_joints_std": left_joint_positions.std(axis=0),
|
| 108 |
+
"right_gripper_mean": right_gripper_positions.mean(axis=0),
|
| 109 |
+
"right_gripper_std": right_gripper_positions.std(axis=0),
|
| 110 |
+
"left_gripper_mean": left_gripper_positions.mean(axis=0),
|
| 111 |
+
"left_gripper_std": left_gripper_positions.std(axis=0),
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
return {k: torch.from_numpy(v).to(self._device) for k, v in stats.items()}
|
| 115 |
+
|
| 116 |
+
def normalize_z(self, data, mean, std):
|
| 117 |
+
return (data - mean) / std
|
| 118 |
+
|
| 119 |
+
def unnormalize_z(self, data, mean, std):
|
| 120 |
+
return data * std + mean
|
| 121 |
+
|
| 122 |
+
def preprocess_qpos(self, observation: dict):
|
| 123 |
+
stats = self.train_stats()
|
| 124 |
+
|
| 125 |
+
right_qrev = self.normalize_z(
|
| 126 |
+
observation["right_joint_positions"][:, 0],
|
| 127 |
+
stats["right_joints_mean"],
|
| 128 |
+
stats["right_joints_std"],
|
| 129 |
+
)
|
| 130 |
+
right_qgripper = self.normalize_z(
|
| 131 |
+
observation["right_gripper_joint_positions"][:, 0],
|
| 132 |
+
stats["right_gripper_mean"],
|
| 133 |
+
stats["right_gripper_std"],
|
| 134 |
+
)
|
| 135 |
+
left_qrev = self.normalize_z(
|
| 136 |
+
observation["left_joint_positions"][:, 0],
|
| 137 |
+
stats["left_joints_mean"],
|
| 138 |
+
stats["left_joints_std"],
|
| 139 |
+
)
|
| 140 |
+
left_qgripper = self.normalize_z(
|
| 141 |
+
observation["left_gripper_joint_positions"][:, 0],
|
| 142 |
+
stats["left_gripper_mean"],
|
| 143 |
+
stats["left_gripper_std"],
|
| 144 |
+
)
|
| 145 |
+
qpos = torch.cat(
|
| 146 |
+
[
|
| 147 |
+
right_qrev,
|
| 148 |
+
right_qgripper[:, 0].unsqueeze(-1),
|
| 149 |
+
left_qrev,
|
| 150 |
+
left_qgripper[:, 0].unsqueeze(-1),
|
| 151 |
+
],
|
| 152 |
+
dim=-1,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
return qpos
|
| 156 |
+
|
| 157 |
+
def preprocess_action(self, replay_sample: dict):
|
| 158 |
+
stats = self.train_stats()
|
| 159 |
+
|
| 160 |
+
right_qrev = self.normalize_z(
|
| 161 |
+
replay_sample["right_prev_joint_positions"][:, 0],
|
| 162 |
+
stats["right_joints_mean"],
|
| 163 |
+
stats["right_joints_std"],
|
| 164 |
+
)
|
| 165 |
+
right_qgripper = self.normalize_z(
|
| 166 |
+
replay_sample["right_prev_gripper_joint_positions"][:, 0],
|
| 167 |
+
stats["right_gripper_mean"],
|
| 168 |
+
stats["right_gripper_std"],
|
| 169 |
+
)
|
| 170 |
+
left_qrev = self.normalize_z(
|
| 171 |
+
replay_sample["left_prev_joint_positions"][:, 0],
|
| 172 |
+
stats["left_joints_mean"],
|
| 173 |
+
stats["left_joints_std"],
|
| 174 |
+
)
|
| 175 |
+
left_qgripper = self.normalize_z(
|
| 176 |
+
replay_sample["left_prev_gripper_joint_positions"][:, 0],
|
| 177 |
+
stats["left_gripper_mean"],
|
| 178 |
+
stats["left_gripper_std"],
|
| 179 |
+
)
|
| 180 |
+
qpos = torch.cat(
|
| 181 |
+
[
|
| 182 |
+
right_qrev,
|
| 183 |
+
right_qgripper[:, 0].unsqueeze(-1),
|
| 184 |
+
left_qrev,
|
| 185 |
+
left_qgripper[:, 0].unsqueeze(-1),
|
| 186 |
+
],
|
| 187 |
+
dim=-1,
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
right_action_rev = self.normalize_z(
|
| 191 |
+
replay_sample["right_next_joint_positions"],
|
| 192 |
+
stats["right_joints_mean"],
|
| 193 |
+
stats["right_joints_std"],
|
| 194 |
+
)
|
| 195 |
+
right_action_gripper = self.normalize_z(
|
| 196 |
+
replay_sample["right_next_gripper_joint_positions"],
|
| 197 |
+
stats["right_gripper_mean"],
|
| 198 |
+
stats["right_gripper_std"],
|
| 199 |
+
)
|
| 200 |
+
left_action_rev = self.normalize_z(
|
| 201 |
+
replay_sample["left_next_joint_positions"],
|
| 202 |
+
stats["left_joints_mean"],
|
| 203 |
+
stats["left_joints_std"],
|
| 204 |
+
)
|
| 205 |
+
left_action_gripper = self.normalize_z(
|
| 206 |
+
replay_sample["left_next_gripper_joint_positions"],
|
| 207 |
+
stats["left_gripper_mean"],
|
| 208 |
+
stats["left_gripper_std"],
|
| 209 |
+
)
|
| 210 |
+
action_seq = torch.cat(
|
| 211 |
+
[
|
| 212 |
+
right_action_rev,
|
| 213 |
+
right_action_gripper[:, :, 0].unsqueeze(-1),
|
| 214 |
+
left_action_rev,
|
| 215 |
+
left_action_gripper[:, :, 0].unsqueeze(-1),
|
| 216 |
+
],
|
| 217 |
+
dim=-1,
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
return qpos, action_seq
|
| 221 |
+
|
| 222 |
+
def preprocess_images(self, replay_sample: dict):
|
| 223 |
+
stacked_rgb = []
|
| 224 |
+
stacked_point_cloud = []
|
| 225 |
+
|
| 226 |
+
for camera in self._camera_names:
|
| 227 |
+
rgb = replay_sample["%s_rgb" % camera]
|
| 228 |
+
rgb = rgb if rgb.dim() == 4 else rgb[:, 0]
|
| 229 |
+
stacked_rgb.append(rgb)
|
| 230 |
+
|
| 231 |
+
point_cloud = replay_sample["%s_point_cloud" % camera]
|
| 232 |
+
point_cloud = point_cloud if point_cloud.dim() == 4 else point_cloud[:, 0]
|
| 233 |
+
stacked_point_cloud.append(point_cloud)
|
| 234 |
+
|
| 235 |
+
stacked_rgb = torch.stack(stacked_rgb, dim=1)
|
| 236 |
+
stacked_point_cloud = torch.stack(stacked_point_cloud, dim=1)
|
| 237 |
+
|
| 238 |
+
return stacked_rgb, stacked_point_cloud
|
| 239 |
+
|
| 240 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 241 |
+
lang_goal_emb = replay_sample["lang_goal_emb"] # TODO use language
|
| 242 |
+
robot_state = replay_sample["low_dim_state"]
|
| 243 |
+
|
| 244 |
+
# preprocess input
|
| 245 |
+
qpos, action_seq = self.preprocess_action(replay_sample)
|
| 246 |
+
stacked_rgb, stacked_point_cloud = self.preprocess_images(replay_sample)
|
| 247 |
+
is_pad = replay_sample["is_pad"].bool()
|
| 248 |
+
|
| 249 |
+
# forward pass
|
| 250 |
+
loss_dict = self._actor(qpos, stacked_rgb, action_seq, is_pad)
|
| 251 |
+
|
| 252 |
+
# gradient step
|
| 253 |
+
loss = loss_dict["total_losses"]
|
| 254 |
+
loss.backward()
|
| 255 |
+
self._actor_optimizer.step()
|
| 256 |
+
self._actor_optimizer.zero_grad()
|
| 257 |
+
|
| 258 |
+
self._summaries = {
|
| 259 |
+
"loss": loss_dict["total_losses"],
|
| 260 |
+
"l1": loss_dict["l1"],
|
| 261 |
+
"right_l1": loss_dict["right_l1"],
|
| 262 |
+
"left_l1": loss_dict["left_l1"],
|
| 263 |
+
"kl": loss_dict["kl"],
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
return loss_dict
|
| 267 |
+
|
| 268 |
+
def _normalize_quat(self, x):
|
| 269 |
+
return x / x.square().sum(dim=1).sqrt().unsqueeze(-1)
|
| 270 |
+
|
| 271 |
+
def _normalize_revolute_joints(self, x):
|
| 272 |
+
# normalize joint angles
|
| 273 |
+
# input ranges from -pi to pi
|
| 274 |
+
# out ranges from 0 to 1
|
| 275 |
+
return (x + np.pi) / (2 * np.pi)
|
| 276 |
+
|
| 277 |
+
def _unnormalize_revolute_joints(self, x):
|
| 278 |
+
# map input with range 0 to 1 to -pi to pi
|
| 279 |
+
x = (x - 0.5) * 2.0 * np.pi
|
| 280 |
+
x = torch.clamp(x, -np.pi, np.pi)
|
| 281 |
+
return x
|
| 282 |
+
|
| 283 |
+
def _normalize_gripper_joints(self, x):
|
| 284 |
+
gripper_min = 0
|
| 285 |
+
gripper_max = 0.04
|
| 286 |
+
# normalize gripper joint angles between 0 and 1, the input ranges from 0 to 0.04
|
| 287 |
+
return (x - gripper_min) / (gripper_max - gripper_min)
|
| 288 |
+
|
| 289 |
+
def _unnormalize_gripper_joints(self, x):
|
| 290 |
+
gripper_min = 0
|
| 291 |
+
gripper_max = 0.04
|
| 292 |
+
|
| 293 |
+
x = x * (gripper_max - gripper_min) + gripper_min
|
| 294 |
+
x = torch.clamp(x, gripper_min, gripper_max)
|
| 295 |
+
return torch.unsqueeze(x, dim=0)
|
| 296 |
+
|
| 297 |
+
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
|
| 298 |
+
# lang_goal_tokens = observation.get('lang_goal_tokens', None).long()
|
| 299 |
+
# with torch.no_grad():
|
| 300 |
+
# lang_goal_tokens = lang_goal_tokens.to(device=self._device)
|
| 301 |
+
# lang_goal_emb, _ = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0])
|
| 302 |
+
# lang_goal_emb = lang_goal_emb.to(device=self._device)
|
| 303 |
+
|
| 304 |
+
action_horizon = self._actor.model.num_queries
|
| 305 |
+
query_freq = 1
|
| 306 |
+
|
| 307 |
+
stats = self.train_stats()
|
| 308 |
+
|
| 309 |
+
if self._timestep % query_freq == 0:
|
| 310 |
+
with torch.no_grad():
|
| 311 |
+
# preprocess input
|
| 312 |
+
qpos = self.preprocess_qpos(observation)
|
| 313 |
+
stacked_rgb, stacked_point_cloud = self.preprocess_images(observation)
|
| 314 |
+
|
| 315 |
+
# forward pass
|
| 316 |
+
self._all_actions = self._actor(
|
| 317 |
+
qpos, stacked_rgb, actions=None, is_pad=None
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
# temporal aggregation
|
| 321 |
+
t = self._timestep
|
| 322 |
+
|
| 323 |
+
self._all_time_actions[[t], t : t + action_horizon] = self._all_actions
|
| 324 |
+
actions_for_curr_step = self._all_time_actions[:, t]
|
| 325 |
+
actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
|
| 326 |
+
actions_for_curr_step = actions_for_curr_step[actions_populated]
|
| 327 |
+
k = 0.01
|
| 328 |
+
exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
|
| 329 |
+
exp_weights = exp_weights / exp_weights.sum()
|
| 330 |
+
exp_weights = torch.from_numpy(exp_weights).to(self._device).unsqueeze(dim=1)
|
| 331 |
+
raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
|
| 332 |
+
raw_action = raw_action[0]
|
| 333 |
+
|
| 334 |
+
right_a_rev = self.unnormalize_z(
|
| 335 |
+
raw_action[0:7], stats["right_joints_mean"], stats["right_joints_std"]
|
| 336 |
+
)
|
| 337 |
+
right_a_gripper = self.unnormalize_z(
|
| 338 |
+
raw_action[7], stats["right_gripper_mean"], stats["right_gripper_std"]
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
left_a_rev = self.unnormalize_z(
|
| 342 |
+
raw_action[8:15], stats["left_joints_mean"], stats["left_joints_std"]
|
| 343 |
+
)
|
| 344 |
+
left_a_gripper = self.unnormalize_z(
|
| 345 |
+
raw_action[15], stats["left_gripper_mean"], stats["left_gripper_std"]
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
raw_action = torch.cat(
|
| 349 |
+
[right_a_rev, right_a_gripper, left_a_rev, left_a_gripper], dim=-1
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
self._timestep += 1
|
| 353 |
+
|
| 354 |
+
return ActResult(raw_action.detach().cpu().numpy())
|
| 355 |
+
|
| 356 |
+
def update_summaries(self) -> List[Summary]:
|
| 357 |
+
summaries = []
|
| 358 |
+
for n, v in self._summaries.items():
|
| 359 |
+
summaries.append(ScalarSummary("%s/%s" % (NAME, n), v))
|
| 360 |
+
|
| 361 |
+
# for tag, param in self._actor.named_parameters():
|
| 362 |
+
# summaries.append(
|
| 363 |
+
#
|
| 364 |
+
# summaries.append(
|
| 365 |
+
# HistogramSummary('%s/weight/%s' % (NAME, tag), param.data))
|
| 366 |
+
|
| 367 |
+
return summaries
|
| 368 |
+
|
| 369 |
+
def act_summaries(self) -> List[Summary]:
|
| 370 |
+
return []
|
| 371 |
+
|
| 372 |
+
def load_weights(self, savedir: str):
|
| 373 |
+
self._actor.load_state_dict(
|
| 374 |
+
torch.load(
|
| 375 |
+
os.path.join(savedir, "bc_actor.pt"), map_location=torch.device("cpu")
|
| 376 |
+
)
|
| 377 |
+
)
|
| 378 |
+
print("Loaded weights from %s" % savedir)
|
| 379 |
+
|
| 380 |
+
def save_weights(self, savedir: str):
|
| 381 |
+
torch.save(self._actor.state_dict(), os.path.join(savedir, "bc_actor.pt"))
|
external/peract_bimanual/agents/act_bc_lang/act_policy.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
import torchvision.transforms as transforms
|
| 5 |
+
|
| 6 |
+
from agents.act_bc_lang.detr.build import (
|
| 7 |
+
build_ACT_model_and_optimizer,
|
| 8 |
+
build_CNNMLP_model_and_optimizer,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ACTPolicy(nn.Module):
|
| 13 |
+
def __init__(self, args):
|
| 14 |
+
super().__init__()
|
| 15 |
+
model, optimizer = build_ACT_model_and_optimizer(args)
|
| 16 |
+
self.model = model # CVAE decoder
|
| 17 |
+
self.optimizer = optimizer
|
| 18 |
+
self.kl_weight = args.kl_weight
|
| 19 |
+
print(f"KL Weight {self.kl_weight}")
|
| 20 |
+
|
| 21 |
+
def forward(self, qpos, image, actions=None, is_pad=None):
|
| 22 |
+
env_state = None
|
| 23 |
+
|
| 24 |
+
if actions is not None: # training time
|
| 25 |
+
actions = actions[:, : self.model.num_queries]
|
| 26 |
+
is_pad = is_pad[:, : self.model.num_queries]
|
| 27 |
+
|
| 28 |
+
a_hat, is_pad_hat, (mu, logvar) = self.model(
|
| 29 |
+
qpos, image, env_state, actions, is_pad
|
| 30 |
+
)
|
| 31 |
+
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
| 32 |
+
loss_dict = dict()
|
| 33 |
+
|
| 34 |
+
right_actions_joints, right_a_hat_joints = (
|
| 35 |
+
actions[:, :, 0:8],
|
| 36 |
+
a_hat[:, :, 0:8],
|
| 37 |
+
)
|
| 38 |
+
right_actions_gripper, right_a_hat_gripper = (
|
| 39 |
+
actions[:, :, 7],
|
| 40 |
+
a_hat[:, :, 7],
|
| 41 |
+
)
|
| 42 |
+
left_actions_joints, left_a_hat_joints = (
|
| 43 |
+
actions[:, :, 8:16],
|
| 44 |
+
a_hat[:, :, 8:16],
|
| 45 |
+
)
|
| 46 |
+
left_actions_gripper, left_a_hat_gripper = (
|
| 47 |
+
actions[:, :, 15],
|
| 48 |
+
a_hat[:, :, 15],
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# use L1 loss for joints
|
| 52 |
+
right_l1_loss = F.l1_loss(
|
| 53 |
+
right_a_hat_joints, right_actions_joints, reduction="none"
|
| 54 |
+
)
|
| 55 |
+
right_l1 = (right_l1_loss * ~is_pad.unsqueeze(-1)).mean()
|
| 56 |
+
|
| 57 |
+
left_l1_loss = F.l1_loss(
|
| 58 |
+
left_a_hat_joints, left_actions_joints, reduction="none"
|
| 59 |
+
)
|
| 60 |
+
left_l1 = (left_l1_loss * ~is_pad.unsqueeze(-1)).mean()
|
| 61 |
+
|
| 62 |
+
l1 = right_l1 + left_l1
|
| 63 |
+
|
| 64 |
+
right_gripper_l1_loss = F.l1_loss(
|
| 65 |
+
right_a_hat_gripper, right_actions_gripper, reduction="none"
|
| 66 |
+
)
|
| 67 |
+
right_gripper_l1_loss = (right_gripper_l1_loss * ~is_pad).mean()
|
| 68 |
+
|
| 69 |
+
left_gripper_l1_loss = F.l1_loss(
|
| 70 |
+
left_a_hat_gripper, left_actions_gripper, reduction="none"
|
| 71 |
+
)
|
| 72 |
+
left_gripper_l1_loss = (left_gripper_l1_loss * ~is_pad).mean()
|
| 73 |
+
|
| 74 |
+
gripper_l1 = right_gripper_l1_loss + left_gripper_l1_loss
|
| 75 |
+
loss_dict["right_l1"] = right_l1
|
| 76 |
+
loss_dict["left_l1"] = left_l1
|
| 77 |
+
|
| 78 |
+
loss_dict["l1"] = l1
|
| 79 |
+
loss_dict["gripper_l1"] = gripper_l1
|
| 80 |
+
|
| 81 |
+
loss_dict["kl"] = total_kld[0]
|
| 82 |
+
loss_dict["total_losses"] = (
|
| 83 |
+
loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
| 84 |
+
)
|
| 85 |
+
return loss_dict
|
| 86 |
+
else: # inference time
|
| 87 |
+
a_hat, _, (_, _) = self.model(
|
| 88 |
+
qpos, image, env_state
|
| 89 |
+
) # no action, sample from prior
|
| 90 |
+
return a_hat
|
| 91 |
+
|
| 92 |
+
def configure_optimizers(self):
|
| 93 |
+
return self.optimizer
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class CNNMLPPolicy(nn.Module):
|
| 97 |
+
def __init__(self, args):
|
| 98 |
+
super().__init__()
|
| 99 |
+
model, optimizer = build_CNNMLP_model_and_optimizer(args)
|
| 100 |
+
self.model = model # decoder
|
| 101 |
+
self.optimizer = optimizer
|
| 102 |
+
|
| 103 |
+
def forward(self, qpos, image, actions=None, is_pad=None):
|
| 104 |
+
env_state = None # TODO
|
| 105 |
+
|
| 106 |
+
if actions is not None: # training time
|
| 107 |
+
actions = actions[:, 0]
|
| 108 |
+
a_hat = self.model(qpos, image, env_state, actions)
|
| 109 |
+
mse = F.mse_loss(actions, a_hat)
|
| 110 |
+
loss_dict = dict()
|
| 111 |
+
loss_dict["mse"] = mse
|
| 112 |
+
loss_dict["loss"] = loss_dict["mse"]
|
| 113 |
+
return loss_dict
|
| 114 |
+
else: # inference time
|
| 115 |
+
a_hat = self.model(qpos, image, env_state) # no action, sample from prior
|
| 116 |
+
return a_hat
|
| 117 |
+
|
| 118 |
+
def configure_optimizers(self):
|
| 119 |
+
return self.optimizer
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def kl_divergence(mu, logvar):
|
| 123 |
+
batch_size = mu.size(0)
|
| 124 |
+
assert batch_size != 0
|
| 125 |
+
if mu.data.ndimension() == 4:
|
| 126 |
+
mu = mu.view(mu.size(0), mu.size(1))
|
| 127 |
+
if logvar.data.ndimension() == 4:
|
| 128 |
+
logvar = logvar.view(logvar.size(0), logvar.size(1))
|
| 129 |
+
|
| 130 |
+
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
|
| 131 |
+
total_kld = klds.sum(1).mean(0, True)
|
| 132 |
+
dimension_wise_kld = klds.mean(0)
|
| 133 |
+
mean_kld = klds.mean(1).mean(0, True)
|
| 134 |
+
|
| 135 |
+
return total_kld, dimension_wise_kld, mean_kld
|
external/peract_bimanual/agents/act_bc_lang/detr/__init__.py
ADDED
|
File without changes
|
external/peract_bimanual/agents/act_bc_lang/detr/build.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
| 2 |
+
import argparse
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from .models import build_ACT_model, build_CNNMLP_model
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def build_ACT_model_and_optimizer(args):
|
| 12 |
+
model = build_ACT_model(args)
|
| 13 |
+
|
| 14 |
+
param_dicts = [
|
| 15 |
+
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
| 16 |
+
{
|
| 17 |
+
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
|
| 18 |
+
"lr": args.lr_backbone,
|
| 19 |
+
},
|
| 20 |
+
]
|
| 21 |
+
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
|
| 22 |
+
weight_decay=args.weight_decay)
|
| 23 |
+
|
| 24 |
+
return model, optimizer
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def build_CNNMLP_model_and_optimizer(args):
|
| 28 |
+
model = build_CNNMLP_model(args)
|
| 29 |
+
|
| 30 |
+
param_dicts = [
|
| 31 |
+
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
| 32 |
+
{
|
| 33 |
+
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
|
| 34 |
+
"lr": args.lr_backbone,
|
| 35 |
+
},
|
| 36 |
+
]
|
| 37 |
+
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
|
| 38 |
+
weight_decay=args.weight_decay)
|
| 39 |
+
|
| 40 |
+
return model, optimizer
|
| 41 |
+
|
external/peract_bimanual/agents/act_bc_lang/detr/util/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
external/peract_bimanual/agents/act_bc_lang/launch_utils.py
ADDED
|
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from ARM
|
| 2 |
+
# Source: https://github.com/stepjam/ARM
|
| 3 |
+
# License: https://github.com/stepjam/ARM/LICENSE
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from omegaconf import DictConfig
|
| 10 |
+
from rlbench.backend.observation import Observation
|
| 11 |
+
from rlbench.observation_config import ObservationConfig
|
| 12 |
+
import rlbench.utils as rlbench_utils
|
| 13 |
+
from rlbench.demo import Demo
|
| 14 |
+
from yarr.replay_buffer.prioritized_replay_buffer import (
|
| 15 |
+
PrioritizedReplayBuffer,
|
| 16 |
+
ObservationElement,
|
| 17 |
+
)
|
| 18 |
+
from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer
|
| 19 |
+
from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
|
| 20 |
+
from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer
|
| 21 |
+
|
| 22 |
+
from helpers import utils
|
| 23 |
+
from helpers import observation_utils
|
| 24 |
+
from agents.act_bc_lang.act_bc_lang_agent import ActBCLangAgent
|
| 25 |
+
from helpers.custom_rlbench_env import CustomRLBenchEnv
|
| 26 |
+
from helpers.preprocess_agent import PreprocessAgent
|
| 27 |
+
from agents.act_bc_lang.act_policy import ACTPolicy, CNNMLPPolicy
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
from torch.multiprocessing import Process, Value, Manager
|
| 31 |
+
from helpers.clip.core.clip import build_model, load_clip, tokenize
|
| 32 |
+
|
| 33 |
+
LOW_DIM_SIZE = 8
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def create_replay(
|
| 37 |
+
batch_size: int,
|
| 38 |
+
timesteps: int,
|
| 39 |
+
prioritisation: bool,
|
| 40 |
+
task_uniform: bool,
|
| 41 |
+
save_dir: str,
|
| 42 |
+
cameras: list,
|
| 43 |
+
image_size=[128, 128],
|
| 44 |
+
replay_size=3e5,
|
| 45 |
+
prev_action_horizon: int = 1,
|
| 46 |
+
next_action_horizon: int = 1,
|
| 47 |
+
):
|
| 48 |
+
lang_feat_dim = 1024
|
| 49 |
+
|
| 50 |
+
# low_dim_state
|
| 51 |
+
observation_elements = []
|
| 52 |
+
observation_elements.append(
|
| 53 |
+
ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
# action sequences
|
| 57 |
+
action_seq_sizes = {
|
| 58 |
+
"right_prev_joint_positions": 7,
|
| 59 |
+
"right_prev_gripper_joint_positions": 2,
|
| 60 |
+
"right_prev_gripper_poses": 7,
|
| 61 |
+
"right_next_joint_positions": 7,
|
| 62 |
+
"right_next_gripper_joint_positions": 2,
|
| 63 |
+
"right_next_gripper_poses": 7,
|
| 64 |
+
"left_prev_joint_positions": 7,
|
| 65 |
+
"left_prev_gripper_joint_positions": 2,
|
| 66 |
+
"left_prev_gripper_poses": 7,
|
| 67 |
+
"left_next_joint_positions": 7,
|
| 68 |
+
"left_next_gripper_joint_positions": 2,
|
| 69 |
+
"left_next_gripper_poses": 7,
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
for seq_name, seq_size in action_seq_sizes.items():
|
| 73 |
+
horizon = prev_action_horizon if "prev" in seq_name else next_action_horizon
|
| 74 |
+
observation_elements.append(
|
| 75 |
+
ObservationElement(
|
| 76 |
+
seq_name,
|
| 77 |
+
(
|
| 78 |
+
horizon,
|
| 79 |
+
seq_size,
|
| 80 |
+
),
|
| 81 |
+
np.float32,
|
| 82 |
+
)
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# action is_pad
|
| 86 |
+
observation_elements.append(
|
| 87 |
+
ObservationElement("is_pad", (next_action_horizon,), np.int32)
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# rgb, depth, point cloud, intrinsics, extrinsics
|
| 91 |
+
for cname in cameras:
|
| 92 |
+
observation_elements.append(
|
| 93 |
+
ObservationElement(
|
| 94 |
+
"%s_rgb" % cname,
|
| 95 |
+
(
|
| 96 |
+
3,
|
| 97 |
+
*image_size,
|
| 98 |
+
),
|
| 99 |
+
np.float32,
|
| 100 |
+
)
|
| 101 |
+
)
|
| 102 |
+
observation_elements.append(
|
| 103 |
+
ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32)
|
| 104 |
+
) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
|
| 105 |
+
observation_elements.append(
|
| 106 |
+
ObservationElement(
|
| 107 |
+
"%s_camera_extrinsics" % cname,
|
| 108 |
+
(
|
| 109 |
+
4,
|
| 110 |
+
4,
|
| 111 |
+
),
|
| 112 |
+
np.float32,
|
| 113 |
+
)
|
| 114 |
+
)
|
| 115 |
+
observation_elements.append(
|
| 116 |
+
ObservationElement(
|
| 117 |
+
"%s_camera_intrinsics" % cname,
|
| 118 |
+
(
|
| 119 |
+
3,
|
| 120 |
+
3,
|
| 121 |
+
),
|
| 122 |
+
np.float32,
|
| 123 |
+
)
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
observation_elements.extend(
|
| 127 |
+
[
|
| 128 |
+
ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
|
| 129 |
+
ReplayElement("task", (), str),
|
| 130 |
+
ReplayElement(
|
| 131 |
+
"lang_goal", (1,), object
|
| 132 |
+
), # language goal string for debugging and visualization
|
| 133 |
+
]
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
extra_replay_elements = [
|
| 137 |
+
ReplayElement("demo", (), bool),
|
| 138 |
+
]
|
| 139 |
+
|
| 140 |
+
replay_buffer = TaskUniformReplayBuffer(
|
| 141 |
+
save_dir=save_dir,
|
| 142 |
+
batch_size=batch_size,
|
| 143 |
+
timesteps=timesteps,
|
| 144 |
+
replay_capacity=int(replay_size),
|
| 145 |
+
action_shape=(8 * 2,),
|
| 146 |
+
action_dtype=np.float32,
|
| 147 |
+
reward_shape=(),
|
| 148 |
+
reward_dtype=np.float32,
|
| 149 |
+
update_horizon=1,
|
| 150 |
+
observation_elements=observation_elements,
|
| 151 |
+
extra_replay_elements=extra_replay_elements,
|
| 152 |
+
)
|
| 153 |
+
return replay_buffer
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _get_action(obs_tp1: Observation):
|
| 157 |
+
quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
|
| 158 |
+
if quat[-1] < 0:
|
| 159 |
+
quat = -quat
|
| 160 |
+
return np.concatenate(
|
| 161 |
+
[obs_tp1.gripper_pose[:3], quat, [float(obs_tp1.gripper_open)]]
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _get_action_seq(
|
| 166 |
+
demo: Demo,
|
| 167 |
+
timestep: int,
|
| 168 |
+
prev_action_horizon: int,
|
| 169 |
+
next_action_horizon: int,
|
| 170 |
+
robot_name: str,
|
| 171 |
+
):
|
| 172 |
+
action_seq = {
|
| 173 |
+
"right_prev_joint_positions": [],
|
| 174 |
+
"right_prev_gripper_joint_positions": [],
|
| 175 |
+
"right_prev_gripper_poses": [],
|
| 176 |
+
"left_prev_joint_positions": [],
|
| 177 |
+
"left_prev_gripper_joint_positions": [],
|
| 178 |
+
"left_prev_gripper_poses": [],
|
| 179 |
+
"right_next_joint_positions": [],
|
| 180 |
+
"right_next_gripper_joint_positions": [],
|
| 181 |
+
"right_next_gripper_poses": [],
|
| 182 |
+
"left_next_joint_positions": [],
|
| 183 |
+
"left_next_gripper_joint_positions": [],
|
| 184 |
+
"left_next_gripper_poses": [],
|
| 185 |
+
"is_pad": [],
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
for prev_t in list(reversed(range(prev_action_horizon))):
|
| 189 |
+
t = timestep - prev_t
|
| 190 |
+
t = max(0, t)
|
| 191 |
+
obs = demo[t]
|
| 192 |
+
|
| 193 |
+
action_seq["right_prev_joint_positions"].append(obs.right.joint_positions)
|
| 194 |
+
action_seq["right_prev_gripper_joint_positions"].append(
|
| 195 |
+
obs.right.gripper_joint_positions
|
| 196 |
+
)
|
| 197 |
+
action_seq["right_prev_gripper_poses"].append(obs.right.gripper_pose)
|
| 198 |
+
action_seq["left_prev_joint_positions"].append(obs.left.joint_positions)
|
| 199 |
+
action_seq["left_prev_gripper_joint_positions"].append(
|
| 200 |
+
obs.left.gripper_joint_positions
|
| 201 |
+
)
|
| 202 |
+
action_seq["left_prev_gripper_poses"].append(obs.left.gripper_pose)
|
| 203 |
+
|
| 204 |
+
action_seq["is_pad"] = np.zeros(next_action_horizon)
|
| 205 |
+
for idx, next_t in enumerate(range(0, next_action_horizon)):
|
| 206 |
+
t = timestep + next_t
|
| 207 |
+
t = min(t, len(demo) - 1)
|
| 208 |
+
obs = demo[t]
|
| 209 |
+
|
| 210 |
+
if timestep + next_t > len(demo) - 1:
|
| 211 |
+
action_seq["is_pad"][idx] = 1
|
| 212 |
+
|
| 213 |
+
action_seq["right_next_joint_positions"].append(obs.right.joint_positions)
|
| 214 |
+
action_seq["right_next_gripper_joint_positions"].append(
|
| 215 |
+
obs.right.gripper_joint_positions
|
| 216 |
+
)
|
| 217 |
+
action_seq["right_next_gripper_poses"].append(obs.right.gripper_pose)
|
| 218 |
+
action_seq["left_next_joint_positions"].append(obs.left.joint_positions)
|
| 219 |
+
action_seq["left_next_gripper_joint_positions"].append(
|
| 220 |
+
obs.left.gripper_joint_positions
|
| 221 |
+
)
|
| 222 |
+
action_seq["left_next_gripper_poses"].append(obs.left.gripper_pose)
|
| 223 |
+
|
| 224 |
+
# convert to numpy arrays
|
| 225 |
+
return {k: np.array(v) for k, v in action_seq.items()}
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def _add_keypoints_to_replay(
|
| 229 |
+
step: int,
|
| 230 |
+
cfg: DictConfig,
|
| 231 |
+
task: str,
|
| 232 |
+
replay: ReplayBuffer,
|
| 233 |
+
inital_obs: Observation,
|
| 234 |
+
demo: Demo,
|
| 235 |
+
description: str = "",
|
| 236 |
+
clip_model=None,
|
| 237 |
+
device="cpu",
|
| 238 |
+
):
|
| 239 |
+
cameras = cfg.rlbench.cameras
|
| 240 |
+
robot_name = cfg.method.robot_name
|
| 241 |
+
|
| 242 |
+
prev_action = None
|
| 243 |
+
obs = inital_obs
|
| 244 |
+
all_actions = []
|
| 245 |
+
k = step
|
| 246 |
+
k_tp1 = min(k + 1, len(demo) - 1)
|
| 247 |
+
obs_tp1 = demo[k_tp1]
|
| 248 |
+
|
| 249 |
+
if obs_tp1.is_bimanual and robot_name == "bimanual":
|
| 250 |
+
right_action = _get_action(obs_tp1.right)
|
| 251 |
+
left_action = _get_action(obs_tp1.left)
|
| 252 |
+
action = np.append(right_action, left_action)
|
| 253 |
+
elif robot_name == "unimanual":
|
| 254 |
+
action = _get_action(obs_tp1)
|
| 255 |
+
elif obs_tp1.is_bimanual and robot_name == "right":
|
| 256 |
+
action = _get_action(obs_tp1.right)
|
| 257 |
+
elif obs_tp1.is_bimanual and robot_name == "left":
|
| 258 |
+
action = _get_action(obs_tp1.left)
|
| 259 |
+
else:
|
| 260 |
+
logging.error("Invalid robot name %s", cfg.method.robot_name)
|
| 261 |
+
raise Exception("Invalid robot name.")
|
| 262 |
+
|
| 263 |
+
all_actions.append(action)
|
| 264 |
+
|
| 265 |
+
terminal = k == len(demo) - 1
|
| 266 |
+
reward = float(terminal) if terminal else 0
|
| 267 |
+
|
| 268 |
+
obs_dict = observation_utils.extract_obs(
|
| 269 |
+
obs,
|
| 270 |
+
t=k,
|
| 271 |
+
prev_action=prev_action,
|
| 272 |
+
cameras=cameras,
|
| 273 |
+
episode_length=cfg.rlbench.episode_length,
|
| 274 |
+
robot_name=robot_name,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
if obs_tp1.is_bimanual and robot_name == "bimanual":
|
| 278 |
+
obs_dict["low_dim_state"] = np.concatenate(
|
| 279 |
+
[obs_dict["right_low_dim_state"], obs_dict["left_low_dim_state"]]
|
| 280 |
+
)
|
| 281 |
+
del obs_dict["right_low_dim_state"]
|
| 282 |
+
del obs_dict["left_low_dim_state"]
|
| 283 |
+
del obs_dict["right_ignore_collisions"]
|
| 284 |
+
del obs_dict["left_ignore_collisions"]
|
| 285 |
+
else:
|
| 286 |
+
del obs_dict["ignore_collisions"]
|
| 287 |
+
|
| 288 |
+
tokens = tokenize([description]).numpy()
|
| 289 |
+
token_tensor = torch.from_numpy(tokens).to(device)
|
| 290 |
+
lang_feats, lang_embs = clip_model.encode_text_with_embeddings(token_tensor)
|
| 291 |
+
obs_dict["lang_goal_emb"] = lang_feats[0].float().detach().cpu().numpy()
|
| 292 |
+
|
| 293 |
+
final_obs = {
|
| 294 |
+
"task": task,
|
| 295 |
+
"lang_goal": np.array([description], dtype=object),
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
action_seq = _get_action_seq(
|
| 299 |
+
demo,
|
| 300 |
+
step,
|
| 301 |
+
cfg.method.prev_action_horizon,
|
| 302 |
+
cfg.method.next_action_horizon,
|
| 303 |
+
robot_name,
|
| 304 |
+
)
|
| 305 |
+
obs_dict.update(action_seq)
|
| 306 |
+
|
| 307 |
+
prev_action = np.copy(action)
|
| 308 |
+
others = {"demo": True}
|
| 309 |
+
others.update(final_obs)
|
| 310 |
+
others.update(obs_dict)
|
| 311 |
+
timeout = False
|
| 312 |
+
replay.add(action, reward, terminal, timeout, **others)
|
| 313 |
+
|
| 314 |
+
return all_actions
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def fill_replay(
|
| 318 |
+
cfg: DictConfig,
|
| 319 |
+
obs_config: ObservationConfig,
|
| 320 |
+
rank: int,
|
| 321 |
+
replay: ReplayBuffer,
|
| 322 |
+
task: str,
|
| 323 |
+
num_demos: int,
|
| 324 |
+
demo_augmentation: bool,
|
| 325 |
+
demo_augmentation_every_n: int,
|
| 326 |
+
cameras: List[str],
|
| 327 |
+
clip_model=None,
|
| 328 |
+
device="cpu",
|
| 329 |
+
):
|
| 330 |
+
if clip_model is None:
|
| 331 |
+
model, _ = load_clip("RN50", jit=False, device=device)
|
| 332 |
+
clip_model = build_model(model.state_dict())
|
| 333 |
+
clip_model.to(device)
|
| 334 |
+
del model
|
| 335 |
+
|
| 336 |
+
logging.debug("Filling %s replay ..." % task)
|
| 337 |
+
all_actions = []
|
| 338 |
+
for d_idx in range(num_demos):
|
| 339 |
+
# load demo from disk
|
| 340 |
+
demo = rlbench_utils.get_stored_demos(
|
| 341 |
+
amount=1,
|
| 342 |
+
image_paths=False,
|
| 343 |
+
dataset_root=cfg.rlbench.demo_path,
|
| 344 |
+
variation_number=-1,
|
| 345 |
+
task_name=task,
|
| 346 |
+
obs_config=obs_config,
|
| 347 |
+
random_selection=False,
|
| 348 |
+
from_episode_number=d_idx,
|
| 349 |
+
)[0]
|
| 350 |
+
|
| 351 |
+
descs = demo._observations[0].misc["descriptions"]
|
| 352 |
+
|
| 353 |
+
if rank == 0:
|
| 354 |
+
logging.info(f"Loading Demo({d_idx})")
|
| 355 |
+
|
| 356 |
+
for i in range(len(demo) - 1):
|
| 357 |
+
obs = demo[i]
|
| 358 |
+
desc = descs[0]
|
| 359 |
+
|
| 360 |
+
# stopped = np.allclose(obs.joint_velocities, 0, atol=0.1)
|
| 361 |
+
# if stopped:
|
| 362 |
+
# continue
|
| 363 |
+
|
| 364 |
+
all_actions.extend(
|
| 365 |
+
_add_keypoints_to_replay(
|
| 366 |
+
i,
|
| 367 |
+
cfg,
|
| 368 |
+
task,
|
| 369 |
+
replay,
|
| 370 |
+
obs,
|
| 371 |
+
demo,
|
| 372 |
+
description=desc,
|
| 373 |
+
clip_model=clip_model,
|
| 374 |
+
device=device,
|
| 375 |
+
)
|
| 376 |
+
)
|
| 377 |
+
logging.debug("Replay filled with demos.")
|
| 378 |
+
return all_actions
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def fill_multi_task_replay(
|
| 382 |
+
cfg: DictConfig,
|
| 383 |
+
obs_config: ObservationConfig,
|
| 384 |
+
rank: int,
|
| 385 |
+
replay: ReplayBuffer,
|
| 386 |
+
tasks: List[str],
|
| 387 |
+
num_demos: int,
|
| 388 |
+
demo_augmentation: bool,
|
| 389 |
+
demo_augmentation_every_n: int,
|
| 390 |
+
cameras: List[str],
|
| 391 |
+
clip_model=None,
|
| 392 |
+
):
|
| 393 |
+
manager = Manager()
|
| 394 |
+
store = manager.dict()
|
| 395 |
+
|
| 396 |
+
# create a MP dict for storing indicies
|
| 397 |
+
# TODO(mohit): this shouldn't be initialized here
|
| 398 |
+
del replay._task_idxs
|
| 399 |
+
task_idxs = manager.dict()
|
| 400 |
+
replay._task_idxs = task_idxs
|
| 401 |
+
replay._create_storage(store)
|
| 402 |
+
replay.add_count = Value("i", 0)
|
| 403 |
+
|
| 404 |
+
# fill replay buffer in parallel across tasks
|
| 405 |
+
max_parallel_processes = cfg.replay.max_parallel_processes
|
| 406 |
+
processes = []
|
| 407 |
+
n = np.arange(len(tasks))
|
| 408 |
+
split_n = utils.split_list(n, max_parallel_processes)
|
| 409 |
+
for split in split_n:
|
| 410 |
+
for e_idx, task_idx in enumerate(split):
|
| 411 |
+
task = tasks[int(task_idx)]
|
| 412 |
+
model_device = torch.device(
|
| 413 |
+
"cuda:%s" % (e_idx % torch.cuda.device_count())
|
| 414 |
+
if torch.cuda.is_available()
|
| 415 |
+
else "cpu"
|
| 416 |
+
)
|
| 417 |
+
p = Process(
|
| 418 |
+
target=fill_replay,
|
| 419 |
+
args=(
|
| 420 |
+
cfg,
|
| 421 |
+
obs_config,
|
| 422 |
+
rank,
|
| 423 |
+
replay,
|
| 424 |
+
task,
|
| 425 |
+
num_demos,
|
| 426 |
+
demo_augmentation,
|
| 427 |
+
demo_augmentation_every_n,
|
| 428 |
+
cameras,
|
| 429 |
+
clip_model,
|
| 430 |
+
model_device,
|
| 431 |
+
),
|
| 432 |
+
)
|
| 433 |
+
p.start()
|
| 434 |
+
processes.append(p)
|
| 435 |
+
|
| 436 |
+
for p in processes:
|
| 437 |
+
p.join()
|
| 438 |
+
|
| 439 |
+
logging.debug("Replay filled with multi demos.")
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def create_agent(cfg: DictConfig):
|
| 443 |
+
actor_net = ACTPolicy(cfg.method)
|
| 444 |
+
|
| 445 |
+
bc_agent = ActBCLangAgent(
|
| 446 |
+
actor_network=actor_net,
|
| 447 |
+
camera_names=cfg.rlbench.cameras,
|
| 448 |
+
lr=cfg.method.lr,
|
| 449 |
+
weight_decay=cfg.method.weight_decay,
|
| 450 |
+
grad_clip=cfg.method.grad_clip,
|
| 451 |
+
episode_length=cfg.rlbench.episode_length,
|
| 452 |
+
train_demo_path=cfg.method.train_demo_path,
|
| 453 |
+
task_name=cfg.rlbench.tasks[0],
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
return PreprocessAgent(pose_agent=bc_agent, norm_type="imagenet")
|
external/peract_bimanual/agents/agent_factory.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
from omegaconf import DictConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
from yarr.agents.agent import BimanualAgent
|
| 8 |
+
from yarr.agents.agent import LeaderFollowerAgent
|
| 9 |
+
from yarr.agents.agent import Agent
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
supported_agents = {
|
| 13 |
+
"leader_follower": ("PERACT_BC", "RVT"),
|
| 14 |
+
"independent": ("PERACT_BC", "RVT"),
|
| 15 |
+
"bimanual": ("BIMANUAL_PERACT", "ACT_BC_LANG"),
|
| 16 |
+
"unimanual": (),
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def create_agent(cfg: DictConfig) -> Agent:
|
| 21 |
+
method_name = cfg.method.name
|
| 22 |
+
agent_type = cfg.method.agent_type
|
| 23 |
+
|
| 24 |
+
logging.info("Using method %s with type %s", method_name, agent_type)
|
| 25 |
+
|
| 26 |
+
assert method_name in supported_agents[agent_type]
|
| 27 |
+
|
| 28 |
+
agent_fn = agent_fn_by_name(method_name)
|
| 29 |
+
|
| 30 |
+
if agent_type == "leader_follower":
|
| 31 |
+
checkpoint_name_prefix = cfg.framework.checkpoint_name_prefix
|
| 32 |
+
cfg.method.robot_name = "right"
|
| 33 |
+
cfg.framework.checkpoint_name_prefix = (
|
| 34 |
+
f"{checkpoint_name_prefix}_{method_name.lower()}_leader"
|
| 35 |
+
)
|
| 36 |
+
leader_agent = agent_fn(cfg)
|
| 37 |
+
|
| 38 |
+
cfg.method.robot_name = "left"
|
| 39 |
+
cfg.framework.checkpoint_name_prefix = (
|
| 40 |
+
f"{checkpoint_name_prefix}_{method_name.lower()}_follower"
|
| 41 |
+
)
|
| 42 |
+
cfg.method.low_dim_size = (
|
| 43 |
+
cfg.method.low_dim_size + 8
|
| 44 |
+
) # also add the action size
|
| 45 |
+
follower_agent = agent_fn(cfg)
|
| 46 |
+
|
| 47 |
+
cfg.method.robot_name = "bimanual"
|
| 48 |
+
|
| 49 |
+
return LeaderFollowerAgent(leader_agent, follower_agent)
|
| 50 |
+
|
| 51 |
+
elif agent_type == "independent":
|
| 52 |
+
checkpoint_name_prefix = cfg.framework.checkpoint_name_prefix
|
| 53 |
+
cfg.method.robot_name = "right"
|
| 54 |
+
cfg.framework.checkpoint_name_prefix = (
|
| 55 |
+
f"{checkpoint_name_prefix}_{method_name.lower()}_right"
|
| 56 |
+
)
|
| 57 |
+
right_agent = agent_fn(cfg)
|
| 58 |
+
|
| 59 |
+
cfg.method.robot_name = "left"
|
| 60 |
+
cfg.framework.checkpoint_name_prefix = (
|
| 61 |
+
f"{checkpoint_name_prefix}_{method_name.lower()}_left"
|
| 62 |
+
)
|
| 63 |
+
left_agent = agent_fn(cfg)
|
| 64 |
+
|
| 65 |
+
cfg.method.robot_name = "bimanual"
|
| 66 |
+
|
| 67 |
+
return BimanualAgent(right_agent, left_agent)
|
| 68 |
+
elif agent_type == "bimanual" or agent_type == "unimanual":
|
| 69 |
+
return agent_fn(cfg)
|
| 70 |
+
else:
|
| 71 |
+
raise Exception("invalid agent type")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def agent_fn_by_name(method_name: str) -> Agent:
|
| 75 |
+
if method_name == "ARM":
|
| 76 |
+
from agents import arm
|
| 77 |
+
|
| 78 |
+
raise NotImplementedError("ARM not yet supported for eval.py")
|
| 79 |
+
elif method_name == "BC_LANG":
|
| 80 |
+
from agents.baselines import bc_lang
|
| 81 |
+
|
| 82 |
+
return bc_lang.launch_utils.create_agent
|
| 83 |
+
elif method_name == "VIT_BC_LANG":
|
| 84 |
+
from agents.baselines import vit_bc_lang
|
| 85 |
+
|
| 86 |
+
return vit_bc_lang.launch_utils.create_agent
|
| 87 |
+
elif method_name == "C2FARM_LINGUNET_BC":
|
| 88 |
+
from agents import c2farm_lingunet_bc
|
| 89 |
+
|
| 90 |
+
return c2farm_lingunet_bc.launch_utils.create_agent
|
| 91 |
+
elif method_name.startswith("PERACT_BC"):
|
| 92 |
+
from agents import peract_bc
|
| 93 |
+
|
| 94 |
+
return peract_bc.launch_utils.create_agent
|
| 95 |
+
elif method_name.startswith("BIMANUAL_PERACT"):
|
| 96 |
+
from agents import bimanual_peract
|
| 97 |
+
|
| 98 |
+
return bimanual_peract.launch_utils.create_agent
|
| 99 |
+
elif method_name.startswith("RVT"):
|
| 100 |
+
from agents import rvt
|
| 101 |
+
|
| 102 |
+
return rvt.launch_utils.create_agent
|
| 103 |
+
elif method_name.startswith("ACT_BC_LANG"):
|
| 104 |
+
from agents import act_bc_lang
|
| 105 |
+
|
| 106 |
+
return act_bc_lang.launch_utils.create_agent
|
| 107 |
+
elif method_name == "PERACT_RL":
|
| 108 |
+
raise NotImplementedError("PERACT_RL not yet supported for eval.py")
|
| 109 |
+
|
| 110 |
+
else:
|
| 111 |
+
raise ValueError("Method %s does not exists." % method_name)
|
external/peract_bimanual/agents/arm/launch_utils.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from rlbench.backend.observation import Observation
|
| 9 |
+
from rlbench.demo import Demo
|
| 10 |
+
from yarr.replay_buffer.prioritized_replay_buffer import (
|
| 11 |
+
PrioritizedReplayBuffer,
|
| 12 |
+
ObservationElement,
|
| 13 |
+
)
|
| 14 |
+
from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer
|
| 15 |
+
from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
|
| 16 |
+
|
| 17 |
+
from helpers import demo_loading_utils, utils
|
| 18 |
+
from helpers.custom_rlbench_env import CustomRLBenchEnv
|
| 19 |
+
from helpers.network_utils import (
|
| 20 |
+
SiameseNet,
|
| 21 |
+
DenseBlock,
|
| 22 |
+
Conv2DBlock,
|
| 23 |
+
Conv2DUpsampleBlock,
|
| 24 |
+
)
|
| 25 |
+
from helpers.preprocess_agent import PreprocessAgent
|
| 26 |
+
from agents.arm.next_best_pose_agent import NextBestPoseAgent
|
| 27 |
+
from agents.arm.qattention_agent import QAttentionAgent
|
| 28 |
+
|
| 29 |
+
REWARD_SCALE = 100.0
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def create_replay(
|
| 33 |
+
batch_size: int,
|
| 34 |
+
timesteps: int,
|
| 35 |
+
prioritisation: bool,
|
| 36 |
+
save_dir: str,
|
| 37 |
+
cameras: list,
|
| 38 |
+
env: CustomRLBenchEnv,
|
| 39 |
+
):
|
| 40 |
+
observation_elements = env.observation_elements
|
| 41 |
+
for cname in cameras:
|
| 42 |
+
observation_elements.extend(
|
| 43 |
+
[
|
| 44 |
+
ObservationElement("%s_pixel_coord" % cname, (2,), np.int32),
|
| 45 |
+
]
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
replay_class = UniformReplayBuffer
|
| 49 |
+
if prioritisation:
|
| 50 |
+
replay_class = PrioritizedReplayBuffer
|
| 51 |
+
replay_buffer = replay_class(
|
| 52 |
+
save_dir=save_dir,
|
| 53 |
+
batch_size=batch_size,
|
| 54 |
+
timesteps=timesteps,
|
| 55 |
+
replay_capacity=int(1e5),
|
| 56 |
+
action_shape=(8,),
|
| 57 |
+
action_dtype=np.float32,
|
| 58 |
+
reward_shape=(),
|
| 59 |
+
reward_dtype=np.float32,
|
| 60 |
+
update_horizon=1,
|
| 61 |
+
observation_elements=observation_elements,
|
| 62 |
+
extra_replay_elements=[ReplayElement("demo", (), np.bool)],
|
| 63 |
+
)
|
| 64 |
+
return replay_buffer
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _point_to_pixel_index(
|
| 68 |
+
point: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray
|
| 69 |
+
):
|
| 70 |
+
point = np.array([point[0], point[1], point[2], 1])
|
| 71 |
+
world_to_cam = np.linalg.inv(extrinsics)
|
| 72 |
+
point_in_cam_frame = world_to_cam.dot(point)
|
| 73 |
+
px, py, pz = point_in_cam_frame[:3]
|
| 74 |
+
px = 2 * intrinsics[0, 2] - int(-intrinsics[0, 0] * (px / pz) + intrinsics[0, 2])
|
| 75 |
+
py = 2 * intrinsics[1, 2] - int(-intrinsics[1, 1] * (py / pz) + intrinsics[1, 2])
|
| 76 |
+
return px, py
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _get_action(obs_tp1: Observation):
|
| 80 |
+
quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
|
| 81 |
+
if quat[-1] < 0:
|
| 82 |
+
quat = -quat
|
| 83 |
+
return np.concatenate(
|
| 84 |
+
[obs_tp1.gripper_pose[:3], quat, [float(obs_tp1.gripper_open)]]
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _add_keypoints_to_replay(
|
| 89 |
+
replay: ReplayBuffer,
|
| 90 |
+
inital_obs: Observation,
|
| 91 |
+
demo: Demo,
|
| 92 |
+
env: CustomRLBenchEnv,
|
| 93 |
+
episode_keypoints: List[int],
|
| 94 |
+
cameras: List[str],
|
| 95 |
+
):
|
| 96 |
+
prev_action = None
|
| 97 |
+
obs = inital_obs
|
| 98 |
+
all_actions = []
|
| 99 |
+
for k, keypoint in enumerate(episode_keypoints):
|
| 100 |
+
obs_tp1 = demo[keypoint]
|
| 101 |
+
action = _get_action(obs_tp1)
|
| 102 |
+
all_actions.append(action)
|
| 103 |
+
terminal = k == len(episode_keypoints) - 1
|
| 104 |
+
reward = float(terminal) * REWARD_SCALE if terminal else 0
|
| 105 |
+
obs_dict = env.extract_obs(obs, t=k, prev_action=prev_action)
|
| 106 |
+
prev_action = np.copy(action)
|
| 107 |
+
others = {"demo": True}
|
| 108 |
+
final_obs = {}
|
| 109 |
+
for name in cameras:
|
| 110 |
+
px, py = _point_to_pixel_index(
|
| 111 |
+
obs_tp1.gripper_pose[:3],
|
| 112 |
+
obs_tp1.misc["%s_camera_extrinsics" % name],
|
| 113 |
+
obs_tp1.misc["%s_camera_intrinsics" % name],
|
| 114 |
+
)
|
| 115 |
+
final_obs["%s_pixel_coord" % name] = [py, px]
|
| 116 |
+
others.update(final_obs)
|
| 117 |
+
others.update(obs_dict)
|
| 118 |
+
timeout = False
|
| 119 |
+
replay.add(action, reward, terminal, timeout, **others)
|
| 120 |
+
obs = obs_tp1 # Set the next obs
|
| 121 |
+
# Final step
|
| 122 |
+
obs_dict_tp1 = env.extract_obs(obs_tp1, t=k + 1, prev_action=prev_action)
|
| 123 |
+
obs_dict_tp1.update(final_obs)
|
| 124 |
+
replay.add_final(**obs_dict_tp1)
|
| 125 |
+
return all_actions
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def fill_replay(
|
| 129 |
+
replay: ReplayBuffer,
|
| 130 |
+
task: str,
|
| 131 |
+
env: CustomRLBenchEnv,
|
| 132 |
+
num_demos: int,
|
| 133 |
+
demo_augmentation: bool,
|
| 134 |
+
demo_augmentation_every_n: int,
|
| 135 |
+
cameras: List[str],
|
| 136 |
+
):
|
| 137 |
+
logging.info("Filling replay with demos...")
|
| 138 |
+
all_actions = []
|
| 139 |
+
for d_idx in range(num_demos):
|
| 140 |
+
demo = env.env.get_demos(
|
| 141 |
+
task,
|
| 142 |
+
1,
|
| 143 |
+
variation_number=0,
|
| 144 |
+
random_selection=False,
|
| 145 |
+
from_episode_number=d_idx,
|
| 146 |
+
)[0]
|
| 147 |
+
episode_keypoints = demo_loading_utils.keypoint_discovery(demo)
|
| 148 |
+
|
| 149 |
+
for i in range(len(demo) - 1):
|
| 150 |
+
if not demo_augmentation and i > 0:
|
| 151 |
+
break
|
| 152 |
+
if i % demo_augmentation_every_n != 0:
|
| 153 |
+
continue
|
| 154 |
+
obs = demo[i]
|
| 155 |
+
# If our starting point is past one of the keypoints, then remove it
|
| 156 |
+
while len(episode_keypoints) > 0 and i >= episode_keypoints[0]:
|
| 157 |
+
episode_keypoints = episode_keypoints[1:]
|
| 158 |
+
if len(episode_keypoints) == 0:
|
| 159 |
+
break
|
| 160 |
+
all_actions.extend(
|
| 161 |
+
_add_keypoints_to_replay(
|
| 162 |
+
replay, obs, demo, env, episode_keypoints, cameras
|
| 163 |
+
)
|
| 164 |
+
)
|
| 165 |
+
logging.info("Replay filled with demos.")
|
| 166 |
+
return all_actions
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class SharedNet(nn.Module):
|
| 170 |
+
def __init__(self, activation: str, norm: str = None):
|
| 171 |
+
super(SharedNet, self).__init__()
|
| 172 |
+
self._activation = activation
|
| 173 |
+
self._norm = norm
|
| 174 |
+
|
| 175 |
+
def build(self):
|
| 176 |
+
self._rgb_pre = nn.Sequential(
|
| 177 |
+
Conv2DBlock(3, 32, 3, 1, activation=self._activation, norm=self._norm),
|
| 178 |
+
)
|
| 179 |
+
self._pcd_pre = nn.Sequential(
|
| 180 |
+
Conv2DBlock(3, 32, 3, 1, activation=self._activation, norm=self._norm),
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
def forward(self, observations):
|
| 184 |
+
x_rgb, x_pcd = self._rgb_pre(observations[0]), self._pcd_pre(observations[1])
|
| 185 |
+
x = torch.cat([x_rgb, x_pcd], dim=1)
|
| 186 |
+
return x
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class ActorNet(nn.Module):
|
| 190 |
+
def __init__(self, activation: str, low_dim_size: int, norm: str = None):
|
| 191 |
+
super(ActorNet, self).__init__()
|
| 192 |
+
self._activation = activation
|
| 193 |
+
self._low_dim_size = low_dim_size
|
| 194 |
+
self._norm = norm
|
| 195 |
+
|
| 196 |
+
def build(self):
|
| 197 |
+
self._convs = nn.Sequential(
|
| 198 |
+
Conv2DBlock(
|
| 199 |
+
64 + self._low_dim_size,
|
| 200 |
+
64,
|
| 201 |
+
1,
|
| 202 |
+
1,
|
| 203 |
+
activation=self._activation,
|
| 204 |
+
norm=self._norm,
|
| 205 |
+
),
|
| 206 |
+
Conv2DBlock(64, 64, 3, 1, activation=self._activation, norm=self._norm),
|
| 207 |
+
)
|
| 208 |
+
self._fcs = nn.Sequential(
|
| 209 |
+
DenseBlock(64, 64, activation=self._activation),
|
| 210 |
+
DenseBlock(64, 64, activation=self._activation),
|
| 211 |
+
DenseBlock(64, 8 * 2),
|
| 212 |
+
)
|
| 213 |
+
self._maxp = nn.AdaptiveMaxPool2d(1)
|
| 214 |
+
|
| 215 |
+
def forward(self, observation_feats, low_dim_ins):
|
| 216 |
+
low_dim_feats = low_dim_ins
|
| 217 |
+
_, _, h, w = observation_feats.shape
|
| 218 |
+
low_dim_feats = low_dim_feats.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w)
|
| 219 |
+
x = torch.cat([observation_feats, low_dim_feats], dim=1)
|
| 220 |
+
x = self._convs(x)
|
| 221 |
+
x = self._maxp(x).squeeze(-1).squeeze(-1)
|
| 222 |
+
x = self._fcs(x)
|
| 223 |
+
return x
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class CriticNet(nn.Module):
|
| 227 |
+
def __init__(
|
| 228 |
+
self, activation: str, low_dim_size: int, norm: str = None, q_conf: bool = True
|
| 229 |
+
):
|
| 230 |
+
super(CriticNet, self).__init__()
|
| 231 |
+
self._activation = activation
|
| 232 |
+
self._low_dim_size = low_dim_size
|
| 233 |
+
self._norm = norm
|
| 234 |
+
self._q_conf = q_conf
|
| 235 |
+
|
| 236 |
+
def build(self):
|
| 237 |
+
self._convs = nn.Sequential(
|
| 238 |
+
Conv2DBlock(
|
| 239 |
+
64 + self._low_dim_size, 128, 3, 1, self._norm, self._activation
|
| 240 |
+
),
|
| 241 |
+
Conv2DBlock(128, 128, 3, 1, self._norm, self._activation),
|
| 242 |
+
Conv2DBlock(128, 128, 3, 1, self._norm, self._activation),
|
| 243 |
+
Conv2DBlock(128, 128, 3, 1, self._norm, self._activation),
|
| 244 |
+
)
|
| 245 |
+
if self._q_conf:
|
| 246 |
+
self._final_conv = Conv2DBlock(128, 2, 3, 1)
|
| 247 |
+
else:
|
| 248 |
+
self._maxp = nn.AdaptiveMaxPool2d(1)
|
| 249 |
+
self._fcs = nn.Sequential(
|
| 250 |
+
DenseBlock(128, 64, activation=self._activation),
|
| 251 |
+
DenseBlock(64, 1),
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
def forward(self, observation_feats, low_dim_ins):
|
| 255 |
+
low_dim_feats = low_dim_ins
|
| 256 |
+
_, _, h, w = observation_feats.shape
|
| 257 |
+
low_dim_feats = low_dim_feats.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w)
|
| 258 |
+
x = torch.cat([observation_feats, low_dim_feats], dim=1)
|
| 259 |
+
x = self._convs(x)
|
| 260 |
+
if self._q_conf:
|
| 261 |
+
x = self._final_conv(x)
|
| 262 |
+
x[:, 1] = torch.sigmoid(x[:, 1])
|
| 263 |
+
else:
|
| 264 |
+
x = self._maxp(x).squeeze(-1).squeeze(-1)
|
| 265 |
+
x = self._fcs(x)
|
| 266 |
+
return x
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class Qattention2DNet(nn.Module):
|
| 270 |
+
def __init__(
|
| 271 |
+
self,
|
| 272 |
+
siamese_net: SiameseNet,
|
| 273 |
+
filters: List[int],
|
| 274 |
+
kernel_sizes: List[int],
|
| 275 |
+
strides: List[int],
|
| 276 |
+
low_dim_state_len: int,
|
| 277 |
+
norm: str = None,
|
| 278 |
+
activation: str = "relu",
|
| 279 |
+
output_channels: int = 1,
|
| 280 |
+
skip_connections: bool = True,
|
| 281 |
+
):
|
| 282 |
+
super(Qattention2DNet, self).__init__()
|
| 283 |
+
self._siamese_net = copy.deepcopy(siamese_net)
|
| 284 |
+
self._input_channels = self._siamese_net.output_channels + low_dim_state_len
|
| 285 |
+
self._filters = filters
|
| 286 |
+
self._kernel_sizes = kernel_sizes
|
| 287 |
+
self._strides = strides
|
| 288 |
+
self._norm = norm
|
| 289 |
+
self._activation = activation
|
| 290 |
+
self._output_channels = output_channels
|
| 291 |
+
self._skip_connections = skip_connections
|
| 292 |
+
self._build_calls = 0
|
| 293 |
+
|
| 294 |
+
def build(self):
|
| 295 |
+
self._build_calls += 1
|
| 296 |
+
if self._build_calls != 1:
|
| 297 |
+
raise RuntimeError("Build needs to be called once.")
|
| 298 |
+
self._siamese_net.build()
|
| 299 |
+
self._down = []
|
| 300 |
+
ch = self._input_channels
|
| 301 |
+
for filt, ksize, stride in zip(
|
| 302 |
+
self._filters, self._kernel_sizes, self._strides
|
| 303 |
+
):
|
| 304 |
+
conv_block = Conv2DBlock(
|
| 305 |
+
ch,
|
| 306 |
+
filt,
|
| 307 |
+
ksize,
|
| 308 |
+
stride,
|
| 309 |
+
self._norm,
|
| 310 |
+
self._activation,
|
| 311 |
+
padding_mode="replicate",
|
| 312 |
+
)
|
| 313 |
+
ch = filt
|
| 314 |
+
self._down.append(conv_block)
|
| 315 |
+
self._down = nn.ModuleList(self._down)
|
| 316 |
+
|
| 317 |
+
reverse_conv_data = list(zip(self._filters, self._kernel_sizes, self._strides))
|
| 318 |
+
reverse_conv_data.reverse()
|
| 319 |
+
|
| 320 |
+
self._up = []
|
| 321 |
+
for i, (filt, ksize, stride) in enumerate(reverse_conv_data):
|
| 322 |
+
if i > 0 and self._skip_connections:
|
| 323 |
+
ch += reverse_conv_data[-i - 1][0]
|
| 324 |
+
convt_block = Conv2DUpsampleBlock(
|
| 325 |
+
ch, filt, ksize, stride, self._norm, self._activation
|
| 326 |
+
)
|
| 327 |
+
ch = filt
|
| 328 |
+
self._up.append(convt_block)
|
| 329 |
+
self._up = nn.ModuleList(self._up)
|
| 330 |
+
|
| 331 |
+
self._final_conv = Conv2DBlock(
|
| 332 |
+
ch, self._output_channels, 3, 1, padding_mode="replicate"
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
def forward(self, observations, low_dim_ins):
|
| 336 |
+
x = self._siamese_net(observations)
|
| 337 |
+
_, _, h, w = x.shape
|
| 338 |
+
if low_dim_ins is not None:
|
| 339 |
+
low_dim_latents = low_dim_ins.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w)
|
| 340 |
+
x = torch.cat([x, low_dim_latents], dim=1)
|
| 341 |
+
self.ups = []
|
| 342 |
+
self.downs = []
|
| 343 |
+
layers_for_skip = []
|
| 344 |
+
for l in self._down:
|
| 345 |
+
x = l(x)
|
| 346 |
+
layers_for_skip.append(x)
|
| 347 |
+
self.downs.append(x)
|
| 348 |
+
self.latent = x
|
| 349 |
+
layers_for_skip.reverse()
|
| 350 |
+
for i, l in enumerate(self._up):
|
| 351 |
+
if i > 0 and self._skip_connections:
|
| 352 |
+
# Skip connections. Skip the first up layer.
|
| 353 |
+
x = torch.cat([layers_for_skip[i], x], 1)
|
| 354 |
+
x = l(x)
|
| 355 |
+
self.ups.append(x)
|
| 356 |
+
x = self._final_conv(x)
|
| 357 |
+
return x
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def create_agent(
|
| 361 |
+
camera_name: str,
|
| 362 |
+
activation: str,
|
| 363 |
+
q_conf: bool,
|
| 364 |
+
action_min_max,
|
| 365 |
+
alpha,
|
| 366 |
+
alpha_lr,
|
| 367 |
+
alpha_auto_tune,
|
| 368 |
+
critic_lr,
|
| 369 |
+
actor_lr,
|
| 370 |
+
next_best_pose_critic_weight_decay,
|
| 371 |
+
next_best_pose_actor_weight_decay,
|
| 372 |
+
crop_shape,
|
| 373 |
+
next_best_pose_tau,
|
| 374 |
+
next_best_pose_critic_grad_clip,
|
| 375 |
+
next_best_pose_actor_grad_clip,
|
| 376 |
+
qattention_tau,
|
| 377 |
+
qattention_lr,
|
| 378 |
+
qattention_weight_decay,
|
| 379 |
+
qattention_lambda_qreg,
|
| 380 |
+
low_dim_state_len,
|
| 381 |
+
qattention_grad_clip,
|
| 382 |
+
):
|
| 383 |
+
siamese_net = SiameseNet(
|
| 384 |
+
input_channels=[3, 3],
|
| 385 |
+
filters=[8],
|
| 386 |
+
kernel_sizes=[5],
|
| 387 |
+
strides=[1],
|
| 388 |
+
activation=activation,
|
| 389 |
+
norm=None,
|
| 390 |
+
)
|
| 391 |
+
qattention_net = Qattention2DNet(
|
| 392 |
+
siamese_net=siamese_net,
|
| 393 |
+
filters=[16, 16],
|
| 394 |
+
kernel_sizes=[5, 5],
|
| 395 |
+
strides=[2, 2],
|
| 396 |
+
output_channels=1,
|
| 397 |
+
norm=None,
|
| 398 |
+
activation=activation,
|
| 399 |
+
skip_connections=True,
|
| 400 |
+
low_dim_state_len=0,
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
qattention_agent = QAttentionAgent(
|
| 404 |
+
pixel_unet=qattention_net,
|
| 405 |
+
tau=qattention_tau,
|
| 406 |
+
camera_name=camera_name,
|
| 407 |
+
lr=qattention_lr,
|
| 408 |
+
weight_decay=qattention_weight_decay,
|
| 409 |
+
lambda_qreg=qattention_lambda_qreg,
|
| 410 |
+
include_low_dim_state=False,
|
| 411 |
+
grad_clip=qattention_grad_clip,
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
shared_net = SharedNet(activation, norm="layer")
|
| 415 |
+
critic_net = CriticNet(
|
| 416 |
+
activation, low_dim_state_len + 8, norm="layer", q_conf=q_conf
|
| 417 |
+
)
|
| 418 |
+
actor_net = ActorNet(activation, low_dim_state_len)
|
| 419 |
+
|
| 420 |
+
next_best_pose_agent = NextBestPoseAgent(
|
| 421 |
+
qattention_agent=qattention_agent,
|
| 422 |
+
shared_network=shared_net,
|
| 423 |
+
critic_network=critic_net,
|
| 424 |
+
actor_network=actor_net,
|
| 425 |
+
action_min_max=action_min_max,
|
| 426 |
+
camera_name=camera_name,
|
| 427 |
+
alpha=alpha,
|
| 428 |
+
alpha_lr=alpha_lr,
|
| 429 |
+
alpha_auto_tune=alpha_auto_tune,
|
| 430 |
+
critic_lr=critic_lr,
|
| 431 |
+
actor_lr=actor_lr,
|
| 432 |
+
critic_weight_decay=next_best_pose_critic_weight_decay,
|
| 433 |
+
actor_weight_decay=next_best_pose_actor_weight_decay,
|
| 434 |
+
crop_shape=crop_shape,
|
| 435 |
+
critic_tau=next_best_pose_tau,
|
| 436 |
+
critic_grad_clip=next_best_pose_critic_grad_clip,
|
| 437 |
+
actor_grad_clip=next_best_pose_actor_grad_clip,
|
| 438 |
+
q_conf=q_conf,
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
return PreprocessAgent(pose_agent=next_best_pose_agent)
|
external/peract_bimanual/agents/arm/next_best_pose_agent.py
ADDED
|
@@ -0,0 +1,526 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from yarr.agents.agent import (
|
| 11 |
+
Agent,
|
| 12 |
+
Summary,
|
| 13 |
+
ActResult,
|
| 14 |
+
ScalarSummary,
|
| 15 |
+
ImageSummary,
|
| 16 |
+
HistogramSummary,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
from helpers import utils
|
| 20 |
+
from helpers.utils import stack_on_channel
|
| 21 |
+
from agents.arm.qattention_agent import QAttentionAgent
|
| 22 |
+
|
| 23 |
+
NAME = "NextBestPoseAgent"
|
| 24 |
+
LOG_STD_MAX = 4
|
| 25 |
+
LOG_STD_MIN = -40
|
| 26 |
+
REPLAY_ALPHA = 0.7
|
| 27 |
+
REPLAY_BETA = 0.5
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class QFunction(nn.Module):
|
| 31 |
+
def __init__(self, critic: nn.Module, shared: nn.Module, q_conf: bool):
|
| 32 |
+
super(QFunction, self).__init__()
|
| 33 |
+
self._q_conf = q_conf
|
| 34 |
+
self._q1 = copy.deepcopy(critic)
|
| 35 |
+
self._q2 = copy.deepcopy(critic)
|
| 36 |
+
self.shared = copy.deepcopy(shared)
|
| 37 |
+
self._q1.build()
|
| 38 |
+
self._q2.build()
|
| 39 |
+
self.shared.build()
|
| 40 |
+
|
| 41 |
+
def forward(self, observations, robot_state, action):
|
| 42 |
+
obs_feats = self.shared(observations)
|
| 43 |
+
combined = torch.cat([robot_state, action.float()], dim=1)
|
| 44 |
+
q1 = self._q1(obs_feats, combined)
|
| 45 |
+
q2 = self._q2(obs_feats, combined)
|
| 46 |
+
if self._q_conf:
|
| 47 |
+
b = q1.shape[0]
|
| 48 |
+
q1 = q1.view(b, 2, -1)
|
| 49 |
+
q2 = q2.view(b, 2, -1)
|
| 50 |
+
q1v, q1c = q1[:, 0], q1[:, 1]
|
| 51 |
+
q1_best = q1v.gather(1, q1c.argmax(dim=1).unsqueeze(-1))
|
| 52 |
+
q2v, q2c = q2[:, 0], q2[:, 1]
|
| 53 |
+
q2_best = q2v.gather(1, q2c.argmax(dim=1).unsqueeze(-1))
|
| 54 |
+
return q1, q2, q1_best, q2_best
|
| 55 |
+
else:
|
| 56 |
+
q1, q2 = q1.unsqueeze(1), q2.unsqueeze(1)
|
| 57 |
+
return q1, q2, q1, q2
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Actor(nn.Module):
|
| 61 |
+
def __init__(self, actor_network: nn.Module, action_min_max: torch.tensor):
|
| 62 |
+
super(Actor, self).__init__()
|
| 63 |
+
self._action_min_max = action_min_max
|
| 64 |
+
self._actor_network = copy.deepcopy(actor_network)
|
| 65 |
+
self._actor_network.build()
|
| 66 |
+
|
| 67 |
+
def _rescale_actions(self, x):
|
| 68 |
+
return (
|
| 69 |
+
0.5 * (x + 1.0) * (self._action_min_max[1] - self._action_min_max[0])
|
| 70 |
+
+ self._action_min_max[0]
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
def _normalize(self, x):
|
| 74 |
+
return x / x.square().sum(dim=1).sqrt().unsqueeze(-1)
|
| 75 |
+
|
| 76 |
+
def _gaussian_logprob(self, noise, log_std):
|
| 77 |
+
residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True)
|
| 78 |
+
return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1)
|
| 79 |
+
|
| 80 |
+
def forward(self, observations, robot_state):
|
| 81 |
+
mu_and_logstd = self._actor_network(observations, robot_state)
|
| 82 |
+
mu, log_std = torch.split(mu_and_logstd, 8, dim=1)
|
| 83 |
+
log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
|
| 84 |
+
|
| 85 |
+
std = log_std.exp()
|
| 86 |
+
noise = torch.randn_like(mu)
|
| 87 |
+
pi = mu + noise * std
|
| 88 |
+
log_pi = self._gaussian_logprob(noise, log_std)
|
| 89 |
+
mu = torch.tanh(mu)
|
| 90 |
+
pi = torch.tanh(pi)
|
| 91 |
+
log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
|
| 92 |
+
|
| 93 |
+
pi = self._rescale_actions(pi)
|
| 94 |
+
mu = self._rescale_actions(mu)
|
| 95 |
+
|
| 96 |
+
pi = torch.cat([pi[:, :3], self._normalize(pi[:, 3:7]), pi[:, 7:]], dim=-1)
|
| 97 |
+
mu = torch.cat([mu[:, :3], self._normalize(mu[:, 3:7]), mu[:, 7:]], dim=-1)
|
| 98 |
+
return mu, pi, log_pi, log_std
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class NextBestPoseAgent(Agent):
|
| 102 |
+
def __init__(
|
| 103 |
+
self,
|
| 104 |
+
qattention_agent: QAttentionAgent,
|
| 105 |
+
shared_network: nn.Module,
|
| 106 |
+
critic_network: nn.Module,
|
| 107 |
+
actor_network: nn.Module,
|
| 108 |
+
action_min_max: tuple,
|
| 109 |
+
camera_name: str,
|
| 110 |
+
alpha: float = 0.2,
|
| 111 |
+
alpha_auto_tune: bool = True,
|
| 112 |
+
alpha_lr: float = 0.001,
|
| 113 |
+
critic_lr: float = 0.01,
|
| 114 |
+
actor_lr: float = 0.01,
|
| 115 |
+
critic_weight_decay: float = 1e-5,
|
| 116 |
+
actor_weight_decay: float = 1e-5,
|
| 117 |
+
crop_shape: tuple = (16, 16),
|
| 118 |
+
critic_tau: float = 0.005,
|
| 119 |
+
critic_grad_clip: float = 20.0,
|
| 120 |
+
actor_grad_clip: float = 20.0,
|
| 121 |
+
gamma: float = 0.99,
|
| 122 |
+
nstep: int = 1,
|
| 123 |
+
q_conf: bool = True,
|
| 124 |
+
):
|
| 125 |
+
self._qattention_agent = qattention_agent
|
| 126 |
+
self._alpha = alpha
|
| 127 |
+
self._alpha_auto_tune = alpha_auto_tune
|
| 128 |
+
self._crop_shape = crop_shape
|
| 129 |
+
self._critic_tau = critic_tau
|
| 130 |
+
self._critic_grad_clip = critic_grad_clip
|
| 131 |
+
self._actor_grad_clip = actor_grad_clip
|
| 132 |
+
self._camera_name = camera_name
|
| 133 |
+
self._gamma = gamma
|
| 134 |
+
self._nstep = nstep
|
| 135 |
+
self._target_entropy = -8
|
| 136 |
+
self._shared_network = shared_network
|
| 137 |
+
self._critic_network = critic_network
|
| 138 |
+
self._actor_network = actor_network
|
| 139 |
+
self._action_min_max = action_min_max
|
| 140 |
+
self._critic_lr = critic_lr
|
| 141 |
+
self._actor_lr = actor_lr
|
| 142 |
+
self._alpha_lr = alpha_lr
|
| 143 |
+
self._critic_weight_decay = critic_weight_decay
|
| 144 |
+
self._actor_weight_decay = actor_weight_decay
|
| 145 |
+
self._q_conf = q_conf
|
| 146 |
+
self._crop_augmentation = False
|
| 147 |
+
|
| 148 |
+
def build(self, training: bool, device: torch.device = None):
|
| 149 |
+
if device is None:
|
| 150 |
+
device = torch.device("cpu")
|
| 151 |
+
self._qattention_agent.build(training, device)
|
| 152 |
+
action_min_max = torch.tensor(self._action_min_max).to(device)
|
| 153 |
+
self._actor = (
|
| 154 |
+
Actor(self._actor_network, action_min_max).to(device).train(training)
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
self._action_min_max_t = torch.tensor(self._action_min_max).to(device)
|
| 158 |
+
|
| 159 |
+
grid_for_crop = (
|
| 160 |
+
torch.arange(0, self._crop_shape[0], device=device)
|
| 161 |
+
.unsqueeze(0)
|
| 162 |
+
.repeat(self._crop_shape[0], 1)
|
| 163 |
+
.unsqueeze(-1)
|
| 164 |
+
)
|
| 165 |
+
self._grid_for_crop = torch.cat(
|
| 166 |
+
[grid_for_crop.transpose(1, 0), grid_for_crop], dim=2
|
| 167 |
+
).unsqueeze(0)
|
| 168 |
+
self._q = (
|
| 169 |
+
QFunction(self._critic_network, self._shared_network, self._q_conf)
|
| 170 |
+
.to(device)
|
| 171 |
+
.train(training)
|
| 172 |
+
)
|
| 173 |
+
if training:
|
| 174 |
+
self._q_target = (
|
| 175 |
+
QFunction(self._critic_network, self._shared_network, self._q_conf)
|
| 176 |
+
.to(device)
|
| 177 |
+
.train(False)
|
| 178 |
+
)
|
| 179 |
+
utils.soft_updates(self._q, self._q_target, 1.0)
|
| 180 |
+
|
| 181 |
+
self._crop_shape_t = torch.tensor(
|
| 182 |
+
[list(self._crop_shape)], dtype=torch.int32, device=device
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
# Freeze target critic.
|
| 186 |
+
for p in self._q_target.parameters():
|
| 187 |
+
p.requires_grad = False
|
| 188 |
+
|
| 189 |
+
self._log_alpha = 0
|
| 190 |
+
if self._alpha_auto_tune:
|
| 191 |
+
self._log_alpha = torch.tensor(
|
| 192 |
+
(np.log(self._alpha)),
|
| 193 |
+
dtype=torch.float,
|
| 194 |
+
requires_grad=True,
|
| 195 |
+
device=device,
|
| 196 |
+
)
|
| 197 |
+
if training:
|
| 198 |
+
self._alpha_optimizer = torch.optim.Adam(
|
| 199 |
+
[self._log_alpha], lr=self._alpha_lr
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
self._critic_optimizer = torch.optim.Adam(
|
| 203 |
+
self._q.parameters(),
|
| 204 |
+
lr=self._critic_lr,
|
| 205 |
+
weight_decay=self._critic_weight_decay,
|
| 206 |
+
)
|
| 207 |
+
self._actor_optimizer = torch.optim.Adam(
|
| 208 |
+
self._actor.parameters(),
|
| 209 |
+
lr=self._actor_lr,
|
| 210 |
+
weight_decay=self._actor_weight_decay,
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
logging.info(
|
| 214 |
+
"# NBP Critic Params: %d"
|
| 215 |
+
% sum(p.numel() for p in self._q.parameters() if p.requires_grad)
|
| 216 |
+
)
|
| 217 |
+
logging.info(
|
| 218 |
+
"# NBP Actor Params: %d"
|
| 219 |
+
% sum(p.numel() for p in self._actor.parameters() if p.requires_grad)
|
| 220 |
+
)
|
| 221 |
+
else:
|
| 222 |
+
for p in self._actor.parameters():
|
| 223 |
+
p.requires_grad = False
|
| 224 |
+
|
| 225 |
+
self._device = device
|
| 226 |
+
|
| 227 |
+
@property
|
| 228 |
+
def alpha(self):
|
| 229 |
+
return self._log_alpha.exp() if self._alpha_auto_tune else self._alpha
|
| 230 |
+
|
| 231 |
+
def _extract_crop(self, pixel_action, observation):
|
| 232 |
+
# Pixel action will now be (B, 2)
|
| 233 |
+
observation = stack_on_channel(observation)
|
| 234 |
+
h = observation.shape[-1]
|
| 235 |
+
top_left_corner = torch.clamp(
|
| 236 |
+
pixel_action - self._crop_shape[0] // 2, 0, h - self._crop_shape[1]
|
| 237 |
+
)
|
| 238 |
+
grid = self._grid_for_crop + top_left_corner.unsqueeze(1).unsqueeze(1)
|
| 239 |
+
grid = ((grid / float(h)) * 2.0) - 1.0
|
| 240 |
+
grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1)
|
| 241 |
+
crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True)
|
| 242 |
+
return crop
|
| 243 |
+
|
| 244 |
+
def _preprocess_inputs(self, replay_sample, pixel_action, pixel_action_tp1):
|
| 245 |
+
observations = [
|
| 246 |
+
self._extract_crop(
|
| 247 |
+
pixel_action, replay_sample["%s_rgb" % self._camera_name]
|
| 248 |
+
),
|
| 249 |
+
self._extract_crop(
|
| 250 |
+
pixel_action, replay_sample["%s_point_cloud" % self._camera_name]
|
| 251 |
+
),
|
| 252 |
+
]
|
| 253 |
+
tp1_observations = [
|
| 254 |
+
self._extract_crop(
|
| 255 |
+
pixel_action_tp1, replay_sample["%s_rgb_tp1" % self._camera_name]
|
| 256 |
+
),
|
| 257 |
+
self._extract_crop(
|
| 258 |
+
pixel_action_tp1,
|
| 259 |
+
replay_sample["%s_point_cloud_tp1" % self._camera_name],
|
| 260 |
+
),
|
| 261 |
+
]
|
| 262 |
+
return observations, tp1_observations
|
| 263 |
+
|
| 264 |
+
def _clip_action(self, a):
|
| 265 |
+
return torch.min(
|
| 266 |
+
torch.max(a, self._action_min_max_t[0:1]), self._action_min_max_t[1:2]
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
def _update_critic(self, replay_sample: dict) -> None:
|
| 270 |
+
action = replay_sample["action"]
|
| 271 |
+
reward = replay_sample["reward"]
|
| 272 |
+
|
| 273 |
+
robot_state = stack_on_channel(replay_sample["low_dim_state"][:, -1:])
|
| 274 |
+
robot_state_tp1 = stack_on_channel(replay_sample["low_dim_state_tp1"][:, -1:])
|
| 275 |
+
|
| 276 |
+
# Get last of time stack and first of plan stack
|
| 277 |
+
pixel_action = replay_sample["%s_pixel_coord" % self._camera_name][:, -1]
|
| 278 |
+
pixel_action_tp1 = replay_sample["%s_pixel_coord_tp1" % self._camera_name][
|
| 279 |
+
:, -1
|
| 280 |
+
]
|
| 281 |
+
|
| 282 |
+
if self._crop_augmentation:
|
| 283 |
+
shifted = (
|
| 284 |
+
(torch.rand_like(pixel_action.float()) * self._crop_shape_t).int()
|
| 285 |
+
- self._crop_shape_t // 2
|
| 286 |
+
) * replay_sample["demo"].int().unsqueeze(1)
|
| 287 |
+
pixel_action += shifted
|
| 288 |
+
pixel_action_tp1 += shifted
|
| 289 |
+
|
| 290 |
+
# Don't want timeouts to be classed as terminals
|
| 291 |
+
terminal = replay_sample["terminal"].float() - replay_sample["timeout"].float()
|
| 292 |
+
|
| 293 |
+
observations, tp1_observations = self._preprocess_inputs(
|
| 294 |
+
replay_sample, pixel_action, pixel_action_tp1
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
q1, q2, _, _ = self._q(observations, robot_state, action)
|
| 298 |
+
|
| 299 |
+
with torch.no_grad():
|
| 300 |
+
obs_feats = self._q.shared(tp1_observations)
|
| 301 |
+
_, pi_tp1, logp_pi_tp1, _ = self._actor(obs_feats, robot_state_tp1)
|
| 302 |
+
|
| 303 |
+
q1_pi_tp1_targ, q2_pi_tp1_targ, _, _ = self._q_target(
|
| 304 |
+
tp1_observations, robot_state_tp1, pi_tp1
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
min_q_pi_targ = torch.min(q1_pi_tp1_targ[:, 0], q2_pi_tp1_targ[:, 0])
|
| 308 |
+
next_value = min_q_pi_targ - self.alpha * logp_pi_tp1
|
| 309 |
+
q_backup = (
|
| 310 |
+
reward.unsqueeze(-1)
|
| 311 |
+
+ (self._gamma**self._nstep)
|
| 312 |
+
* (1.0 - terminal.unsqueeze(-1))
|
| 313 |
+
* next_value
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA)
|
| 317 |
+
|
| 318 |
+
self._critic_summaries = {}
|
| 319 |
+
if self._q_conf:
|
| 320 |
+
w = 1.0
|
| 321 |
+
q1_delta = (
|
| 322 |
+
F.smooth_l1_loss(q1[:, 0], q_backup, reduction="none") * q1[:, 1]
|
| 323 |
+
- w * q1[:, 1].log()
|
| 324 |
+
)
|
| 325 |
+
q2_delta = (
|
| 326 |
+
F.smooth_l1_loss(q2[:, 0], q_backup, reduction="none") * q2[:, 1]
|
| 327 |
+
- w * q2[:, 1].log()
|
| 328 |
+
)
|
| 329 |
+
self._critic_summaries = {
|
| 330 |
+
"q_conf_loss": -(w * q1[:, 1].log()).mean(),
|
| 331 |
+
"q_conf_mean": q1[:, 1].mean(),
|
| 332 |
+
}
|
| 333 |
+
else:
|
| 334 |
+
q1_delta = F.smooth_l1_loss(q1[:, 0], q_backup, reduction="none")
|
| 335 |
+
q2_delta = F.smooth_l1_loss(q2[:, 0], q_backup, reduction="none")
|
| 336 |
+
|
| 337 |
+
q1_delta, q2_delta = q1_delta.mean(1), q2_delta.mean(1)
|
| 338 |
+
q1_bellman_loss = (q1_delta * loss_weights).mean()
|
| 339 |
+
q2_bellman_loss = (q2_delta * loss_weights).mean()
|
| 340 |
+
|
| 341 |
+
critic_loss = q1_bellman_loss + q2_bellman_loss
|
| 342 |
+
|
| 343 |
+
self._critic_summaries.update(
|
| 344 |
+
{
|
| 345 |
+
"q1_bellman_loss": q1_bellman_loss,
|
| 346 |
+
"q2_bellman_loss": q2_bellman_loss,
|
| 347 |
+
"q1_mean": q1[:, 0].mean().item(),
|
| 348 |
+
"q2_mean": q2[:, 0].mean().item(),
|
| 349 |
+
"alpha": self.alpha,
|
| 350 |
+
}
|
| 351 |
+
)
|
| 352 |
+
self._crop_summary = observations
|
| 353 |
+
self._crop_summary_tp1 = tp1_observations
|
| 354 |
+
|
| 355 |
+
new_pri = torch.sqrt((q1_delta + q2_delta) / 2.0 + 1e-10)
|
| 356 |
+
self._new_priority = (new_pri / torch.max(new_pri)).detach()
|
| 357 |
+
self._grad_step(
|
| 358 |
+
critic_loss,
|
| 359 |
+
self._critic_optimizer,
|
| 360 |
+
self._q.parameters(),
|
| 361 |
+
self._critic_grad_clip,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
def _update_actor(self, replay_sample: dict) -> None:
|
| 365 |
+
robot_state = stack_on_channel(replay_sample["low_dim_state"][:, -1:])
|
| 366 |
+
pixel_action = replay_sample["%s_pixel_coord" % self._camera_name][:, -1]
|
| 367 |
+
|
| 368 |
+
if self._crop_augmentation:
|
| 369 |
+
shifted = (
|
| 370 |
+
(torch.rand_like(pixel_action.float()) * self._crop_shape_t).int()
|
| 371 |
+
- self._crop_shape_t // 2
|
| 372 |
+
) * replay_sample["demo"].int().unsqueeze(1)
|
| 373 |
+
pixel_action += shifted
|
| 374 |
+
|
| 375 |
+
# Crop the observations
|
| 376 |
+
observations = [
|
| 377 |
+
self._extract_crop(
|
| 378 |
+
pixel_action, replay_sample["%s_rgb" % self._camera_name]
|
| 379 |
+
),
|
| 380 |
+
self._extract_crop(
|
| 381 |
+
pixel_action, replay_sample["%s_point_cloud" % self._camera_name]
|
| 382 |
+
),
|
| 383 |
+
]
|
| 384 |
+
|
| 385 |
+
with torch.no_grad():
|
| 386 |
+
obs_feats = self._q.shared(observations)
|
| 387 |
+
|
| 388 |
+
mu, pi, self._logp_pi, log_scale_diag = self._actor(obs_feats, robot_state)
|
| 389 |
+
|
| 390 |
+
_, _, q1_pi, q2_pi = self._q(observations, robot_state, pi)
|
| 391 |
+
|
| 392 |
+
min_q_pi = torch.min(q1_pi, q2_pi)[:, 0]
|
| 393 |
+
pi_loss = self.alpha * self._logp_pi - min_q_pi
|
| 394 |
+
|
| 395 |
+
loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA)
|
| 396 |
+
pi_loss = (pi_loss * loss_weights).mean()
|
| 397 |
+
|
| 398 |
+
self._actor_summaries = {
|
| 399 |
+
"pi/loss": pi_loss,
|
| 400 |
+
"pi/q1_pi_mean": q1_pi.mean(),
|
| 401 |
+
"pi/q2_pi_mean": q2_pi.mean(),
|
| 402 |
+
"pi/mu": mu.mean(),
|
| 403 |
+
"pi/pi": pi.mean(),
|
| 404 |
+
"pi/log_pi": self._logp_pi.mean(),
|
| 405 |
+
"pi/log_scale_diag": log_scale_diag.mean(),
|
| 406 |
+
}
|
| 407 |
+
self._grad_step(
|
| 408 |
+
pi_loss,
|
| 409 |
+
self._actor_optimizer,
|
| 410 |
+
self._actor.parameters(),
|
| 411 |
+
self._actor_grad_clip,
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
def _update_alpha(self):
|
| 415 |
+
alpha_loss = -(
|
| 416 |
+
self.alpha * (self._logp_pi + self._target_entropy).detach()
|
| 417 |
+
).mean()
|
| 418 |
+
self._grad_step(alpha_loss, self._alpha_optimizer)
|
| 419 |
+
|
| 420 |
+
def _grad_step(self, loss, opt, model_params=None, clip=None):
|
| 421 |
+
opt.zero_grad()
|
| 422 |
+
loss.backward()
|
| 423 |
+
if clip is not None and model_params is not None:
|
| 424 |
+
nn.utils.clip_grad_value_(model_params, clip)
|
| 425 |
+
opt.step()
|
| 426 |
+
|
| 427 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 428 |
+
info = self._qattention_agent.update(step, replay_sample)
|
| 429 |
+
|
| 430 |
+
self._update_critic(replay_sample)
|
| 431 |
+
|
| 432 |
+
# Freeze critic so you don't waste computational effort
|
| 433 |
+
# computing gradients for them during the policy learning step.
|
| 434 |
+
for p in self._q.parameters():
|
| 435 |
+
p.requires_grad = False
|
| 436 |
+
|
| 437 |
+
self._update_actor(replay_sample)
|
| 438 |
+
if self._alpha_auto_tune:
|
| 439 |
+
self._update_alpha()
|
| 440 |
+
|
| 441 |
+
# UnFreeze critic.
|
| 442 |
+
for p in self._q.parameters():
|
| 443 |
+
p.requires_grad = True
|
| 444 |
+
|
| 445 |
+
utils.soft_updates(self._q, self._q_target, self._critic_tau)
|
| 446 |
+
pixel_agent_priority = info["priority"]
|
| 447 |
+
return {
|
| 448 |
+
"priority": ((self._new_priority + pixel_agent_priority) / 2.0)
|
| 449 |
+
** REPLAY_ALPHA
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
|
| 453 |
+
with torch.no_grad():
|
| 454 |
+
act_res = self._qattention_agent.act(step, observation, deterministic)
|
| 455 |
+
observations = [
|
| 456 |
+
self._extract_crop(
|
| 457 |
+
act_res.action.unsqueeze(0),
|
| 458 |
+
observation["%s_rgb" % self._camera_name],
|
| 459 |
+
),
|
| 460 |
+
self._extract_crop(
|
| 461 |
+
act_res.action.unsqueeze(0),
|
| 462 |
+
observation["%s_point_cloud" % self._camera_name],
|
| 463 |
+
),
|
| 464 |
+
]
|
| 465 |
+
self._act_crop_summaries = observations
|
| 466 |
+
robot_state = stack_on_channel(observation["low_dim_state"][:, -1:])
|
| 467 |
+
obs_feats = self._q.shared(observations)
|
| 468 |
+
mu, pi, _, _ = self._actor(obs_feats, robot_state)
|
| 469 |
+
act_res.action = (mu if deterministic else pi)[0]
|
| 470 |
+
act_res.info.update({"rgb_crop": observations[0]})
|
| 471 |
+
return act_res
|
| 472 |
+
|
| 473 |
+
def update_summaries(self) -> List[Summary]:
|
| 474 |
+
summaries = [
|
| 475 |
+
ImageSummary("%s/crops/rgb" % NAME, (self._crop_summary[0] + 1.0) / 2.0),
|
| 476 |
+
ImageSummary("%s/crops/point_cloud" % NAME, self._crop_summary[1]),
|
| 477 |
+
ImageSummary(
|
| 478 |
+
"%s/crops_tp1/rgb" % NAME, (self._crop_summary_tp1[0] + 1.0) / 2.0
|
| 479 |
+
),
|
| 480 |
+
ImageSummary("%s/crops_tp1/point_cloud" % NAME, self._crop_summary_tp1[1]),
|
| 481 |
+
]
|
| 482 |
+
|
| 483 |
+
for n, v in list(self._critic_summaries.items()) + list(
|
| 484 |
+
self._actor_summaries.items()
|
| 485 |
+
):
|
| 486 |
+
summaries.append(ScalarSummary("%s/%s" % (NAME, n), v))
|
| 487 |
+
|
| 488 |
+
for tag, param in list(self._q.named_parameters()) + list(
|
| 489 |
+
self._actor.named_parameters()
|
| 490 |
+
):
|
| 491 |
+
summaries.append(
|
| 492 |
+
HistogramSummary("%s/gradient/%s" % (NAME, tag), param.grad)
|
| 493 |
+
)
|
| 494 |
+
summaries.append(HistogramSummary("%s/weight/%s" % (NAME, tag), param.data))
|
| 495 |
+
|
| 496 |
+
pixel_summaries = self._qattention_agent.update_summaries()
|
| 497 |
+
return pixel_summaries + summaries
|
| 498 |
+
|
| 499 |
+
def act_summaries(self) -> List[Summary]:
|
| 500 |
+
summaries = [
|
| 501 |
+
ImageSummary(
|
| 502 |
+
"%s/crops/act/rgb" % NAME, (self._act_crop_summaries[0] + 1.0) / 2.0
|
| 503 |
+
),
|
| 504 |
+
ImageSummary(
|
| 505 |
+
"%s/crops/act/point_cloud" % NAME, self._act_crop_summaries[1]
|
| 506 |
+
),
|
| 507 |
+
]
|
| 508 |
+
return summaries + self._qattention_agent.act_summaries()
|
| 509 |
+
|
| 510 |
+
def load_weights(self, savedir: str):
|
| 511 |
+
self._qattention_agent.load_weights(savedir)
|
| 512 |
+
self._actor.load_state_dict(
|
| 513 |
+
torch.load(
|
| 514 |
+
os.path.join(savedir, "pose_actor.pt"), map_location=torch.device("cpu")
|
| 515 |
+
)
|
| 516 |
+
)
|
| 517 |
+
self._q.load_state_dict(
|
| 518 |
+
torch.load(
|
| 519 |
+
os.path.join(savedir, "pose_q.pt"), map_location=torch.device("cpu")
|
| 520 |
+
)
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
def save_weights(self, savedir: str):
|
| 524 |
+
self._qattention_agent.save_weights(savedir)
|
| 525 |
+
torch.save(self._actor.state_dict(), os.path.join(savedir, "pose_actor.pt"))
|
| 526 |
+
torch.save(self._q.state_dict(), os.path.join(savedir, "pose_q.pt"))
|
external/peract_bimanual/agents/arm/qattention_agent.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import PIL
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
|
| 12 |
+
from yarr.agents.agent import (
|
| 13 |
+
Agent,
|
| 14 |
+
ActResult,
|
| 15 |
+
ScalarSummary,
|
| 16 |
+
HistogramSummary,
|
| 17 |
+
ImageSummary,
|
| 18 |
+
Summary,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
from helpers import utils
|
| 22 |
+
from helpers.utils import stack_on_channel
|
| 23 |
+
|
| 24 |
+
NAME = "QAttentionAgent"
|
| 25 |
+
REPLAY_BETA = 1.0
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class QFunction(nn.Module):
|
| 29 |
+
def __init__(self, unet: nn.Module):
|
| 30 |
+
super(QFunction, self).__init__()
|
| 31 |
+
self._qnet = copy.deepcopy(unet)
|
| 32 |
+
self._qnet2 = copy.deepcopy(unet)
|
| 33 |
+
self._qnet.build()
|
| 34 |
+
self._qnet2.build()
|
| 35 |
+
|
| 36 |
+
def _argmax_2d(self, tensor):
|
| 37 |
+
t_shape = tensor.shape
|
| 38 |
+
m = tensor.view(t_shape[0], -1).argmax(1).view(-1, 1)
|
| 39 |
+
indices = torch.cat((m // t_shape[-1], m % t_shape[-1]), dim=1)
|
| 40 |
+
return indices
|
| 41 |
+
|
| 42 |
+
def forward(self, x, robot_state):
|
| 43 |
+
q = self._qnet(x, robot_state)[:, 0]
|
| 44 |
+
q2 = self._qnet2(x, robot_state)[:, 0]
|
| 45 |
+
coords = self._argmax_2d(torch.min(q, q2))
|
| 46 |
+
return q, q2, coords
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class QAttentionAgent(Agent):
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
pixel_unet: nn.Module,
|
| 53 |
+
camera_name: str,
|
| 54 |
+
tau: float = 0.005,
|
| 55 |
+
gamma: float = 0.99,
|
| 56 |
+
nstep: int = 1,
|
| 57 |
+
lr: float = 0.0001,
|
| 58 |
+
weight_decay: float = 1e-5,
|
| 59 |
+
lambda_qreg: float = 1e-6,
|
| 60 |
+
grad_clip: float = 20.0,
|
| 61 |
+
include_low_dim_state: bool = False,
|
| 62 |
+
):
|
| 63 |
+
self._pixel_unet = pixel_unet
|
| 64 |
+
self._camera_name = camera_name
|
| 65 |
+
self._tau = tau
|
| 66 |
+
self._gamma = gamma
|
| 67 |
+
self._nstep = nstep
|
| 68 |
+
self._lr = lr
|
| 69 |
+
self._weight_decay = weight_decay
|
| 70 |
+
self._lambda_qreg = lambda_qreg
|
| 71 |
+
self._grad_clip = grad_clip
|
| 72 |
+
self._include_low_dim_state = include_low_dim_state
|
| 73 |
+
|
| 74 |
+
def build(self, training: bool, device: torch.device = None):
|
| 75 |
+
if device is None:
|
| 76 |
+
device = torch.device("cpu")
|
| 77 |
+
self._q = QFunction(self._pixel_unet).to(device).train(training)
|
| 78 |
+
self._q_target = None
|
| 79 |
+
if training:
|
| 80 |
+
self._q_target = QFunction(self._pixel_unet).to(device).train(False)
|
| 81 |
+
for p in self._q_target.parameters():
|
| 82 |
+
p.requires_grad = False
|
| 83 |
+
utils.soft_updates(self._q, self._q_target, 1.0)
|
| 84 |
+
self._optimizer = torch.optim.Adam(
|
| 85 |
+
self._q.parameters(), lr=self._lr, weight_decay=self._weight_decay
|
| 86 |
+
)
|
| 87 |
+
logging.info(
|
| 88 |
+
"# Q-attention Params: %d"
|
| 89 |
+
% sum(p.numel() for p in self._q.parameters() if p.requires_grad)
|
| 90 |
+
)
|
| 91 |
+
else:
|
| 92 |
+
for p in self._q.parameters():
|
| 93 |
+
p.requires_grad = False
|
| 94 |
+
self._device = device
|
| 95 |
+
|
| 96 |
+
def _get_q_from_pixel_coord(self, q, coord):
|
| 97 |
+
b, h, w = q.shape
|
| 98 |
+
flat_indicies = (coord[:, 0] * w + coord[:, 1])[:, None].long()
|
| 99 |
+
return q.view(b, h * w).gather(1, flat_indicies)
|
| 100 |
+
|
| 101 |
+
def _preprocess_inputs(self, replay_sample):
|
| 102 |
+
observations = [
|
| 103 |
+
stack_on_channel(replay_sample["%s_rgb" % self._camera_name]),
|
| 104 |
+
stack_on_channel(replay_sample["%s_point_cloud" % self._camera_name]),
|
| 105 |
+
]
|
| 106 |
+
tp1_observations = [
|
| 107 |
+
stack_on_channel(replay_sample["%s_rgb_tp1" % self._camera_name]),
|
| 108 |
+
stack_on_channel(replay_sample["%s_point_cloud_tp1" % self._camera_name]),
|
| 109 |
+
]
|
| 110 |
+
return observations, tp1_observations
|
| 111 |
+
|
| 112 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 113 |
+
pixel_action = replay_sample["%s_pixel_coord" % self._camera_name][:, -1].int()
|
| 114 |
+
reward = replay_sample["reward"]
|
| 115 |
+
reward = torch.where(reward > 0, reward, torch.zeros_like(reward))
|
| 116 |
+
|
| 117 |
+
robot_state = robot_state_tp1 = None
|
| 118 |
+
if self._include_low_dim_state:
|
| 119 |
+
robot_state = stack_on_channel(replay_sample["low_dim_state"])
|
| 120 |
+
robot_state_tp1 = stack_on_channel(replay_sample["low_dim_state_tp1"])
|
| 121 |
+
|
| 122 |
+
# Don't want timeouts to be classed as terminals
|
| 123 |
+
terminal = replay_sample["terminal"].float() - replay_sample["timeout"].float()
|
| 124 |
+
|
| 125 |
+
obs, obs_tp1 = self._preprocess_inputs(replay_sample)
|
| 126 |
+
q, q2, coords = self._q(obs, robot_state)
|
| 127 |
+
|
| 128 |
+
with torch.no_grad():
|
| 129 |
+
# (B, h, w)
|
| 130 |
+
_, _, coords_tp1 = self._q(obs_tp1, robot_state_tp1)
|
| 131 |
+
q_tp1_targ, q2_tp1_targ, _ = self._q_target(obs_tp1, robot_state_tp1)
|
| 132 |
+
q_tp1_targ = torch.min(q_tp1_targ, q2_tp1_targ)
|
| 133 |
+
q_tp1_targ = self._get_q_from_pixel_coord(q_tp1_targ, coords_tp1)
|
| 134 |
+
target = (
|
| 135 |
+
reward.unsqueeze(1)
|
| 136 |
+
+ (self._gamma**self._nstep)
|
| 137 |
+
* (1 - terminal.unsqueeze(1))
|
| 138 |
+
* q_tp1_targ
|
| 139 |
+
)
|
| 140 |
+
target = torch.clamp(target, 0.0, 100.0)
|
| 141 |
+
|
| 142 |
+
q_pred = self._get_q_from_pixel_coord(q, pixel_action)
|
| 143 |
+
delta = F.smooth_l1_loss(q_pred, target, reduction="none").mean(1)
|
| 144 |
+
|
| 145 |
+
delta += F.smooth_l1_loss(
|
| 146 |
+
self._get_q_from_pixel_coord(q2, pixel_action), target, reduction="none"
|
| 147 |
+
).mean(1)
|
| 148 |
+
q_reg = (
|
| 149 |
+
(0.5 * torch.sum(q**2)) + (0.5 * torch.sum(q2**2))
|
| 150 |
+
) * self._lambda_qreg
|
| 151 |
+
|
| 152 |
+
loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA)
|
| 153 |
+
total_loss = ((delta) * loss_weights).mean() + q_reg
|
| 154 |
+
new_priority = ((delta) + 1e-10).sqrt()
|
| 155 |
+
new_priority /= new_priority.max()
|
| 156 |
+
|
| 157 |
+
self._summaries = {
|
| 158 |
+
"losses/bellman": delta.mean(),
|
| 159 |
+
"losses/qreg": q_reg.mean(),
|
| 160 |
+
"q/mean": q.mean(),
|
| 161 |
+
"q/action_q": q_pred.mean(),
|
| 162 |
+
}
|
| 163 |
+
self._qvalues = q[:1]
|
| 164 |
+
self._rgb_observation = replay_sample["front_rgb"][0, -1]
|
| 165 |
+
self._optimizer.zero_grad()
|
| 166 |
+
total_loss.backward()
|
| 167 |
+
if self._grad_clip is not None:
|
| 168 |
+
nn.utils.clip_grad_value_(self._q.parameters(), self._grad_clip)
|
| 169 |
+
self._optimizer.step()
|
| 170 |
+
utils.soft_updates(self._q, self._q_target, self._tau)
|
| 171 |
+
|
| 172 |
+
return {
|
| 173 |
+
"priority": new_priority,
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
|
| 177 |
+
with torch.no_grad():
|
| 178 |
+
observations = [
|
| 179 |
+
stack_on_channel(observation["%s_rgb" % self._camera_name]),
|
| 180 |
+
stack_on_channel(observation["%s_point_cloud" % self._camera_name]),
|
| 181 |
+
]
|
| 182 |
+
robot_state = None
|
| 183 |
+
if self._include_low_dim_state:
|
| 184 |
+
robot_state = stack_on_channel(observation["low_dim_state"])
|
| 185 |
+
# Coords are stored as (y, x)
|
| 186 |
+
q, q2, coords = self._q(observations, robot_state)
|
| 187 |
+
self._act_qvalues = torch.min(q, q2)[:1]
|
| 188 |
+
self._rgb_observation = observation["front_rgb"][0, -1]
|
| 189 |
+
return ActResult(
|
| 190 |
+
coords[0],
|
| 191 |
+
observation_elements={
|
| 192 |
+
"%s_pixel_coord" % self._camera_name: coords[0],
|
| 193 |
+
},
|
| 194 |
+
info={"q_values": self._act_qvalues},
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
@staticmethod
|
| 198 |
+
def generate_heatmap(q_values, rgb_obs):
|
| 199 |
+
norm_q = torch.clamp(q_values / 100.0, 0, 1)
|
| 200 |
+
heatmap = torch.cat(
|
| 201 |
+
[norm_q, torch.zeros_like(norm_q), torch.zeros_like(norm_q)]
|
| 202 |
+
)
|
| 203 |
+
img = transforms.functional.to_pil_image(rgb_obs)
|
| 204 |
+
h_img = transforms.functional.to_pil_image(heatmap).convert("RGB")
|
| 205 |
+
ret = PIL.Image.blend(img, h_img, 0.75)
|
| 206 |
+
return transforms.ToTensor()(ret).unsqueeze_(0)
|
| 207 |
+
|
| 208 |
+
def update_summaries(self) -> List[Summary]:
|
| 209 |
+
summaries = [
|
| 210 |
+
ImageSummary(
|
| 211 |
+
"%s/Q" % NAME,
|
| 212 |
+
QAttentionAgent.generate_heatmap(
|
| 213 |
+
self._qvalues.cpu(), ((self._rgb_observation + 1) / 2.0).cpu()
|
| 214 |
+
),
|
| 215 |
+
)
|
| 216 |
+
]
|
| 217 |
+
for n, v in self._summaries.items():
|
| 218 |
+
summaries.append(ScalarSummary("%s/%s" % (NAME, n), v))
|
| 219 |
+
|
| 220 |
+
for tag, param in self._q.named_parameters():
|
| 221 |
+
assert not torch.isnan(param.grad.abs() <= 1.0).all()
|
| 222 |
+
summaries.append(
|
| 223 |
+
HistogramSummary("%s/gradient/%s" % (NAME, tag), param.grad)
|
| 224 |
+
)
|
| 225 |
+
summaries.append(HistogramSummary("%s/weight/%s" % (NAME, tag), param.data))
|
| 226 |
+
return summaries
|
| 227 |
+
|
| 228 |
+
def act_summaries(self) -> List[Summary]:
|
| 229 |
+
return [
|
| 230 |
+
ImageSummary(
|
| 231 |
+
"%s/Q_act" % NAME,
|
| 232 |
+
QAttentionAgent.generate_heatmap(
|
| 233 |
+
self._act_qvalues.cpu(), ((self._rgb_observation + 1) / 2.0).cpu()
|
| 234 |
+
),
|
| 235 |
+
)
|
| 236 |
+
]
|
| 237 |
+
|
| 238 |
+
def load_weights(self, savedir: str):
|
| 239 |
+
self._q.load_state_dict(
|
| 240 |
+
torch.load(
|
| 241 |
+
os.path.join(savedir, "pixel_agent_q.pt"),
|
| 242 |
+
map_location=torch.device("cpu"),
|
| 243 |
+
)
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
def save_weights(self, savedir: str):
|
| 247 |
+
torch.save(self._q.state_dict(), os.path.join(savedir, "pixel_agent_q.pt"))
|
external/peract_bimanual/agents/baselines/__init__.py
ADDED
|
File without changes
|
external/peract_bimanual/agents/baselines/bc_lang/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
import agents.baselines.bc_lang.launch_utils
|
external/peract_bimanual/agents/baselines/bc_lang/bc_lang_agent.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from yarr.agents.agent import Agent, Summary, ActResult, ScalarSummary, HistogramSummary
|
| 10 |
+
|
| 11 |
+
from helpers import utils
|
| 12 |
+
from helpers.utils import stack_on_channel
|
| 13 |
+
|
| 14 |
+
from helpers.clip.core.clip import build_model, load_clip
|
| 15 |
+
|
| 16 |
+
NAME = "BCLangAgent"
|
| 17 |
+
REPLAY_ALPHA = 0.7
|
| 18 |
+
REPLAY_BETA = 1.0
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Actor(nn.Module):
|
| 22 |
+
def __init__(self, actor_network: nn.Module):
|
| 23 |
+
super(Actor, self).__init__()
|
| 24 |
+
self._actor_network = copy.deepcopy(actor_network)
|
| 25 |
+
self._actor_network.build()
|
| 26 |
+
|
| 27 |
+
def forward(self, observations, robot_state, lang_goal_emb):
|
| 28 |
+
mu = self._actor_network(observations, robot_state, lang_goal_emb)
|
| 29 |
+
return mu
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class BCLangAgent(Agent):
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
actor_network: nn.Module,
|
| 36 |
+
camera_name: str,
|
| 37 |
+
lr: float = 0.01,
|
| 38 |
+
weight_decay: float = 1e-5,
|
| 39 |
+
grad_clip: float = 20.0,
|
| 40 |
+
):
|
| 41 |
+
self._camera_name = camera_name
|
| 42 |
+
self._actor_network = actor_network
|
| 43 |
+
self._lr = lr
|
| 44 |
+
self._weight_decay = weight_decay
|
| 45 |
+
self._grad_clip = grad_clip
|
| 46 |
+
|
| 47 |
+
def build(self, training: bool, device: torch.device = None):
|
| 48 |
+
if device is None:
|
| 49 |
+
device = torch.device("cpu")
|
| 50 |
+
self._actor = Actor(self._actor_network).to(device).train(training)
|
| 51 |
+
if training:
|
| 52 |
+
self._actor_optimizer = torch.optim.Adam(
|
| 53 |
+
self._actor.parameters(), lr=self._lr, weight_decay=self._weight_decay
|
| 54 |
+
)
|
| 55 |
+
logging.info(
|
| 56 |
+
"# Actor Params: %d"
|
| 57 |
+
% sum(p.numel() for p in self._actor.parameters() if p.requires_grad)
|
| 58 |
+
)
|
| 59 |
+
else:
|
| 60 |
+
for p in self._actor.parameters():
|
| 61 |
+
p.requires_grad = False
|
| 62 |
+
|
| 63 |
+
model, _ = load_clip("RN50", jit=False)
|
| 64 |
+
self._clip_rn50 = build_model(model.state_dict())
|
| 65 |
+
self._clip_rn50 = self._clip_rn50.float().to(device)
|
| 66 |
+
self._clip_rn50.eval()
|
| 67 |
+
del model
|
| 68 |
+
|
| 69 |
+
self._device = device
|
| 70 |
+
|
| 71 |
+
def _grad_step(self, loss, opt, model_params=None, clip=None):
|
| 72 |
+
opt.zero_grad()
|
| 73 |
+
loss.backward()
|
| 74 |
+
if clip is not None and model_params is not None:
|
| 75 |
+
nn.utils.clip_grad_value_(model_params, clip)
|
| 76 |
+
opt.step()
|
| 77 |
+
|
| 78 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 79 |
+
lang_goal_emb = replay_sample["lang_goal_emb"]
|
| 80 |
+
robot_state = replay_sample["low_dim_state"]
|
| 81 |
+
observations = [
|
| 82 |
+
replay_sample["%s_rgb" % self._camera_name],
|
| 83 |
+
replay_sample["%s_point_cloud" % self._camera_name],
|
| 84 |
+
]
|
| 85 |
+
mu = self._actor(observations, robot_state, lang_goal_emb)
|
| 86 |
+
loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA)
|
| 87 |
+
delta = F.mse_loss(mu, replay_sample["action"], reduction="none").mean(1)
|
| 88 |
+
loss = (delta * loss_weights).mean()
|
| 89 |
+
self._grad_step(
|
| 90 |
+
loss, self._actor_optimizer, self._actor.parameters(), self._grad_clip
|
| 91 |
+
)
|
| 92 |
+
self._summaries = {
|
| 93 |
+
"pi/loss": loss,
|
| 94 |
+
"pi/mu": mu.mean(),
|
| 95 |
+
}
|
| 96 |
+
return {"total_losses": loss}
|
| 97 |
+
|
| 98 |
+
def _normalize_quat(self, x):
|
| 99 |
+
return x / x.square().sum(dim=1).sqrt().unsqueeze(-1)
|
| 100 |
+
|
| 101 |
+
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
|
| 102 |
+
lang_goal_tokens = observation.get("lang_goal_tokens", None).long()
|
| 103 |
+
|
| 104 |
+
with torch.no_grad():
|
| 105 |
+
lang_goal_tokens = lang_goal_tokens.to(device=self._device)
|
| 106 |
+
lang_goal_emb, _ = self._clip_rn50.encode_text_with_embeddings(
|
| 107 |
+
lang_goal_tokens[0]
|
| 108 |
+
)
|
| 109 |
+
lang_goal_emb = lang_goal_emb.to(device=self._device)
|
| 110 |
+
|
| 111 |
+
observations = [
|
| 112 |
+
observation["%s_rgb" % self._camera_name][0].to(self._device),
|
| 113 |
+
observation["%s_point_cloud" % self._camera_name][0].to(self._device),
|
| 114 |
+
]
|
| 115 |
+
robot_state = observation["low_dim_state"][0].to(self._device)
|
| 116 |
+
|
| 117 |
+
mu = self._actor(observations, robot_state, lang_goal_emb)
|
| 118 |
+
mu = torch.cat([mu[:, :3], self._normalize_quat(mu[:, 3:7]), mu[:, 7:]], dim=-1)
|
| 119 |
+
ignore_collisions = torch.Tensor([1.0]).to(mu.device)
|
| 120 |
+
mu0 = torch.cat([mu[0], ignore_collisions])
|
| 121 |
+
return ActResult(mu0.detach().cpu())
|
| 122 |
+
|
| 123 |
+
def update_summaries(self) -> List[Summary]:
|
| 124 |
+
summaries = []
|
| 125 |
+
for n, v in self._summaries.items():
|
| 126 |
+
summaries.append(ScalarSummary("%s/%s" % (NAME, n), v))
|
| 127 |
+
|
| 128 |
+
for tag, param in self._actor.named_parameters():
|
| 129 |
+
summaries.append(
|
| 130 |
+
HistogramSummary("%s/gradient/%s" % (NAME, tag), param.grad)
|
| 131 |
+
)
|
| 132 |
+
summaries.append(HistogramSummary("%s/weight/%s" % (NAME, tag), param.data))
|
| 133 |
+
|
| 134 |
+
return summaries
|
| 135 |
+
|
| 136 |
+
def act_summaries(self) -> List[Summary]:
|
| 137 |
+
return []
|
| 138 |
+
|
| 139 |
+
def load_weights(self, savedir: str):
|
| 140 |
+
self._actor.load_state_dict(
|
| 141 |
+
torch.load(
|
| 142 |
+
os.path.join(savedir, "bc_actor.pt"), map_location=torch.device("cpu")
|
| 143 |
+
)
|
| 144 |
+
)
|
| 145 |
+
print("Loaded weights from %s" % savedir)
|
| 146 |
+
|
| 147 |
+
def save_weights(self, savedir: str):
|
| 148 |
+
torch.save(self._actor.state_dict(), os.path.join(savedir, "bc_actor.pt"))
|
external/peract_bimanual/agents/baselines/bc_lang/launch_utils.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from ARM
|
| 2 |
+
# Source: https://github.com/stepjam/ARM
|
| 3 |
+
# License: https://github.com/stepjam/ARM/LICENSE
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from omegaconf import DictConfig
|
| 10 |
+
from rlbench.backend.observation import Observation
|
| 11 |
+
from rlbench.observation_config import ObservationConfig
|
| 12 |
+
import rlbench.utils as rlbench_utils
|
| 13 |
+
from rlbench.demo import Demo
|
| 14 |
+
from yarr.replay_buffer.prioritized_replay_buffer import (
|
| 15 |
+
PrioritizedReplayBuffer,
|
| 16 |
+
ObservationElement,
|
| 17 |
+
)
|
| 18 |
+
from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer
|
| 19 |
+
from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
|
| 20 |
+
from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer
|
| 21 |
+
|
| 22 |
+
from helpers import demo_loading_utils, utils
|
| 23 |
+
from helpers import observation_utils
|
| 24 |
+
from agents.baselines.bc_lang.bc_lang_agent import BCLangAgent
|
| 25 |
+
from helpers.custom_rlbench_env import CustomRLBenchEnv
|
| 26 |
+
from helpers.network_utils import SiameseNet, CNNLangAndFcsNet
|
| 27 |
+
from helpers.preprocess_agent import PreprocessAgent
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
from torch.multiprocessing import Process, Value, Manager
|
| 31 |
+
from helpers.clip.core.clip import build_model, load_clip, tokenize
|
| 32 |
+
|
| 33 |
+
LOW_DIM_SIZE = 4
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def create_replay(
|
| 37 |
+
batch_size: int,
|
| 38 |
+
timesteps: int,
|
| 39 |
+
prioritisation: bool,
|
| 40 |
+
task_uniform: bool,
|
| 41 |
+
save_dir: str,
|
| 42 |
+
cameras: list,
|
| 43 |
+
image_size=[128, 128],
|
| 44 |
+
replay_size=3e5,
|
| 45 |
+
):
|
| 46 |
+
lang_feat_dim = 1024
|
| 47 |
+
|
| 48 |
+
# low_dim_state
|
| 49 |
+
observation_elements = []
|
| 50 |
+
observation_elements.append(
|
| 51 |
+
ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
# rgb, depth, point cloud, intrinsics, extrinsics
|
| 55 |
+
for cname in cameras:
|
| 56 |
+
observation_elements.append(
|
| 57 |
+
ObservationElement(
|
| 58 |
+
"%s_rgb" % cname,
|
| 59 |
+
(
|
| 60 |
+
3,
|
| 61 |
+
*image_size,
|
| 62 |
+
),
|
| 63 |
+
np.float32,
|
| 64 |
+
)
|
| 65 |
+
)
|
| 66 |
+
observation_elements.append(
|
| 67 |
+
ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32)
|
| 68 |
+
) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
|
| 69 |
+
observation_elements.append(
|
| 70 |
+
ObservationElement(
|
| 71 |
+
"%s_camera_extrinsics" % cname,
|
| 72 |
+
(
|
| 73 |
+
4,
|
| 74 |
+
4,
|
| 75 |
+
),
|
| 76 |
+
np.float32,
|
| 77 |
+
)
|
| 78 |
+
)
|
| 79 |
+
observation_elements.append(
|
| 80 |
+
ObservationElement(
|
| 81 |
+
"%s_camera_intrinsics" % cname,
|
| 82 |
+
(
|
| 83 |
+
3,
|
| 84 |
+
3,
|
| 85 |
+
),
|
| 86 |
+
np.float32,
|
| 87 |
+
)
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
observation_elements.extend(
|
| 91 |
+
[
|
| 92 |
+
ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
|
| 93 |
+
ReplayElement("task", (), str),
|
| 94 |
+
ReplayElement(
|
| 95 |
+
"lang_goal", (1,), object
|
| 96 |
+
), # language goal string for debugging and visualization
|
| 97 |
+
]
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
extra_replay_elements = [
|
| 101 |
+
ReplayElement("demo", (), np.bool),
|
| 102 |
+
]
|
| 103 |
+
|
| 104 |
+
replay_buffer = TaskUniformReplayBuffer(
|
| 105 |
+
save_dir=save_dir,
|
| 106 |
+
batch_size=batch_size,
|
| 107 |
+
timesteps=timesteps,
|
| 108 |
+
replay_capacity=int(replay_size),
|
| 109 |
+
action_shape=(8,),
|
| 110 |
+
action_dtype=np.float32,
|
| 111 |
+
reward_shape=(),
|
| 112 |
+
reward_dtype=np.float32,
|
| 113 |
+
update_horizon=1,
|
| 114 |
+
observation_elements=observation_elements,
|
| 115 |
+
extra_replay_elements=extra_replay_elements,
|
| 116 |
+
)
|
| 117 |
+
return replay_buffer
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _get_action(obs_tp1: Observation):
|
| 121 |
+
quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
|
| 122 |
+
if quat[-1] < 0:
|
| 123 |
+
quat = -quat
|
| 124 |
+
return np.concatenate(
|
| 125 |
+
[obs_tp1.gripper_pose[:3], quat, [float(obs_tp1.gripper_open)]]
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _add_keypoints_to_replay(
|
| 130 |
+
cfg: DictConfig,
|
| 131 |
+
task: str,
|
| 132 |
+
replay: ReplayBuffer,
|
| 133 |
+
inital_obs: Observation,
|
| 134 |
+
demo: Demo,
|
| 135 |
+
episode_keypoints: List[int],
|
| 136 |
+
cameras: List[str],
|
| 137 |
+
description: str = "",
|
| 138 |
+
clip_model=None,
|
| 139 |
+
device="cpu",
|
| 140 |
+
):
|
| 141 |
+
prev_action = None
|
| 142 |
+
obs = inital_obs
|
| 143 |
+
all_actions = []
|
| 144 |
+
for k, keypoint in enumerate(episode_keypoints):
|
| 145 |
+
obs_tp1 = demo[keypoint]
|
| 146 |
+
action = _get_action(obs_tp1)
|
| 147 |
+
all_actions.append(action)
|
| 148 |
+
terminal = k == len(episode_keypoints) - 1
|
| 149 |
+
reward = float(terminal) if terminal else 0
|
| 150 |
+
|
| 151 |
+
obs_dict = observation_utils.extract_obs(
|
| 152 |
+
obs,
|
| 153 |
+
t=k,
|
| 154 |
+
prev_action=prev_action,
|
| 155 |
+
cameras=cameras,
|
| 156 |
+
episode_length=cfg.rlbench.episode_length,
|
| 157 |
+
robot_name=cfg.method.robot_name,
|
| 158 |
+
)
|
| 159 |
+
del obs_dict["ignore_collisions"]
|
| 160 |
+
tokens = tokenize([description]).numpy()
|
| 161 |
+
token_tensor = torch.from_numpy(tokens).to(device)
|
| 162 |
+
lang_feats, lang_embs = clip_model.encode_text_with_embeddings(token_tensor)
|
| 163 |
+
obs_dict["lang_goal_emb"] = lang_feats[0].float().detach().cpu().numpy()
|
| 164 |
+
|
| 165 |
+
final_obs = {
|
| 166 |
+
"task": task,
|
| 167 |
+
"lang_goal": np.array([description], dtype=object),
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
prev_action = np.copy(action)
|
| 171 |
+
others = {"demo": True}
|
| 172 |
+
others.update(final_obs)
|
| 173 |
+
others.update(obs_dict)
|
| 174 |
+
timeout = False
|
| 175 |
+
replay.add(action, reward, terminal, timeout, **others)
|
| 176 |
+
obs = obs_tp1 # Set the next obs
|
| 177 |
+
# Final step
|
| 178 |
+
obs_dict_tp1 = observation_utils.extract_obs(
|
| 179 |
+
obs_tp1,
|
| 180 |
+
t=k + 1,
|
| 181 |
+
prev_action=prev_action,
|
| 182 |
+
cameras=cameras,
|
| 183 |
+
episode_length=cfg.rlbench.episode_length,
|
| 184 |
+
robot_name=cfg.method.robot_name,
|
| 185 |
+
)
|
| 186 |
+
obs_dict_tp1["lang_goal_emb"] = lang_feats[0].float().detach().cpu().numpy()
|
| 187 |
+
# del obs_dict_tp1['lang_goal_tokens']
|
| 188 |
+
del obs_dict_tp1["ignore_collisions"]
|
| 189 |
+
# obs_dict_tp1['task'] = task
|
| 190 |
+
obs_dict_tp1.update(final_obs)
|
| 191 |
+
replay.add_final(**obs_dict_tp1)
|
| 192 |
+
return all_actions
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def fill_replay(
|
| 196 |
+
cfg: DictConfig,
|
| 197 |
+
obs_config: ObservationConfig,
|
| 198 |
+
rank: int,
|
| 199 |
+
replay: ReplayBuffer,
|
| 200 |
+
task: str,
|
| 201 |
+
num_demos: int,
|
| 202 |
+
demo_augmentation: bool,
|
| 203 |
+
demo_augmentation_every_n: int,
|
| 204 |
+
cameras: List[str],
|
| 205 |
+
clip_model=None,
|
| 206 |
+
device="cpu",
|
| 207 |
+
):
|
| 208 |
+
if clip_model is None:
|
| 209 |
+
model, _ = load_clip("RN50", jit=False, device=device)
|
| 210 |
+
clip_model = build_model(model.state_dict())
|
| 211 |
+
clip_model.to(device)
|
| 212 |
+
del model
|
| 213 |
+
|
| 214 |
+
logging.debug("Filling %s replay ..." % task)
|
| 215 |
+
all_actions = []
|
| 216 |
+
for d_idx in range(num_demos):
|
| 217 |
+
# load demo from disk
|
| 218 |
+
demo = rlbench_utils.get_stored_demos(
|
| 219 |
+
amount=1,
|
| 220 |
+
image_paths=False,
|
| 221 |
+
dataset_root=cfg.rlbench.demo_path,
|
| 222 |
+
variation_number=-1,
|
| 223 |
+
task_name=task,
|
| 224 |
+
obs_config=obs_config,
|
| 225 |
+
random_selection=False,
|
| 226 |
+
from_episode_number=d_idx,
|
| 227 |
+
)[0]
|
| 228 |
+
|
| 229 |
+
descs = demo._observations[0].misc["descriptions"]
|
| 230 |
+
|
| 231 |
+
# extract keypoints (a.k.a keyframes)
|
| 232 |
+
episode_keypoints = demo_loading_utils.keypoint_discovery(demo)
|
| 233 |
+
|
| 234 |
+
if rank == 0:
|
| 235 |
+
logging.info(
|
| 236 |
+
f"Loading Demo({d_idx}) - found {len(episode_keypoints)} keypoints - {task}"
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
for i in range(len(demo) - 1):
|
| 240 |
+
if not demo_augmentation and i > 0:
|
| 241 |
+
break
|
| 242 |
+
if i % demo_augmentation_every_n != 0:
|
| 243 |
+
continue
|
| 244 |
+
|
| 245 |
+
obs = demo[i]
|
| 246 |
+
desc = descs[0]
|
| 247 |
+
# if our starting point is past one of the keypoints, then remove it
|
| 248 |
+
while len(episode_keypoints) > 0 and i >= episode_keypoints[0]:
|
| 249 |
+
episode_keypoints = episode_keypoints[1:]
|
| 250 |
+
if len(episode_keypoints) == 0:
|
| 251 |
+
break
|
| 252 |
+
all_actions.extend(
|
| 253 |
+
_add_keypoints_to_replay(
|
| 254 |
+
cfg,
|
| 255 |
+
task,
|
| 256 |
+
replay,
|
| 257 |
+
obs,
|
| 258 |
+
demo,
|
| 259 |
+
episode_keypoints,
|
| 260 |
+
cameras,
|
| 261 |
+
description=desc,
|
| 262 |
+
clip_model=clip_model,
|
| 263 |
+
device=device,
|
| 264 |
+
)
|
| 265 |
+
)
|
| 266 |
+
logging.debug("Replay filled with demos.")
|
| 267 |
+
return all_actions
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def fill_multi_task_replay(
|
| 271 |
+
cfg: DictConfig,
|
| 272 |
+
obs_config: ObservationConfig,
|
| 273 |
+
rank: int,
|
| 274 |
+
replay: ReplayBuffer,
|
| 275 |
+
tasks: List[str],
|
| 276 |
+
num_demos: int,
|
| 277 |
+
demo_augmentation: bool,
|
| 278 |
+
demo_augmentation_every_n: int,
|
| 279 |
+
cameras: List[str],
|
| 280 |
+
clip_model=None,
|
| 281 |
+
):
|
| 282 |
+
manager = Manager()
|
| 283 |
+
store = manager.dict()
|
| 284 |
+
|
| 285 |
+
# create a MP dict for storing indicies
|
| 286 |
+
# TODO(mohit): this shouldn't be initialized here
|
| 287 |
+
del replay._task_idxs
|
| 288 |
+
task_idxs = manager.dict()
|
| 289 |
+
replay._task_idxs = task_idxs
|
| 290 |
+
replay._create_storage(store)
|
| 291 |
+
replay.add_count = Value("i", 0)
|
| 292 |
+
|
| 293 |
+
# fill replay buffer in parallel across tasks
|
| 294 |
+
max_parallel_processes = cfg.replay.max_parallel_processes
|
| 295 |
+
processes = []
|
| 296 |
+
n = np.arange(len(tasks))
|
| 297 |
+
split_n = utils.split_list(n, max_parallel_processes)
|
| 298 |
+
for split in split_n:
|
| 299 |
+
for e_idx, task_idx in enumerate(split):
|
| 300 |
+
task = tasks[int(task_idx)]
|
| 301 |
+
model_device = torch.device(
|
| 302 |
+
"cuda:%s" % (e_idx % torch.cuda.device_count())
|
| 303 |
+
if torch.cuda.is_available()
|
| 304 |
+
else "cpu"
|
| 305 |
+
)
|
| 306 |
+
p = Process(
|
| 307 |
+
target=fill_replay,
|
| 308 |
+
args=(
|
| 309 |
+
cfg,
|
| 310 |
+
obs_config,
|
| 311 |
+
rank,
|
| 312 |
+
replay,
|
| 313 |
+
task,
|
| 314 |
+
num_demos,
|
| 315 |
+
demo_augmentation,
|
| 316 |
+
demo_augmentation_every_n,
|
| 317 |
+
cameras,
|
| 318 |
+
clip_model,
|
| 319 |
+
model_device,
|
| 320 |
+
),
|
| 321 |
+
)
|
| 322 |
+
p.start()
|
| 323 |
+
processes.append(p)
|
| 324 |
+
|
| 325 |
+
for p in processes:
|
| 326 |
+
p.join()
|
| 327 |
+
|
| 328 |
+
logging.debug("Replay filled with multi demos.")
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def create_agent(cfg: DictConfig):
|
| 332 |
+
camera_name = cfg.rlbench.cameras
|
| 333 |
+
activation = cfg.method.activation
|
| 334 |
+
lr = cfg.method.lr
|
| 335 |
+
weight_decay = cfg.method.weight_decay
|
| 336 |
+
image_resolution = cfg.rlbench.camera_resolution
|
| 337 |
+
grad_clip = cfg.method.grad_clip
|
| 338 |
+
|
| 339 |
+
siamese_net = SiameseNet(
|
| 340 |
+
input_channels=[3, 3],
|
| 341 |
+
filters=[16],
|
| 342 |
+
kernel_sizes=[5],
|
| 343 |
+
strides=[1],
|
| 344 |
+
activation=activation,
|
| 345 |
+
norm=None,
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
actor_net = CNNLangAndFcsNet(
|
| 349 |
+
siamese_net=siamese_net,
|
| 350 |
+
input_resolution=image_resolution,
|
| 351 |
+
filters=[32, 64, 64],
|
| 352 |
+
kernel_sizes=[3, 3, 3],
|
| 353 |
+
strides=[2, 2, 2],
|
| 354 |
+
norm=None,
|
| 355 |
+
activation=activation,
|
| 356 |
+
fc_layers=[128, 64, 3 + 4 + 1],
|
| 357 |
+
low_dim_state_len=LOW_DIM_SIZE,
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
bc_agent = BCLangAgent(
|
| 361 |
+
actor_network=actor_net,
|
| 362 |
+
camera_name=camera_name,
|
| 363 |
+
lr=lr,
|
| 364 |
+
weight_decay=weight_decay,
|
| 365 |
+
grad_clip=grad_clip,
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
return PreprocessAgent(pose_agent=bc_agent)
|
external/peract_bimanual/agents/baselines/vit_bc_lang/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
import agents.baselines.vit_bc_lang.launch_utils
|
external/peract_bimanual/agents/baselines/vit_bc_lang/launch_utils.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from ARM
|
| 2 |
+
# Source: https://github.com/stepjam/ARM
|
| 3 |
+
# License: https://github.com/stepjam/ARM/LICENSE
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from omegaconf import DictConfig
|
| 10 |
+
from rlbench.backend.observation import Observation
|
| 11 |
+
from rlbench.observation_config import ObservationConfig
|
| 12 |
+
import rlbench.utils as rlbench_utils
|
| 13 |
+
from rlbench.demo import Demo
|
| 14 |
+
from yarr.replay_buffer.prioritized_replay_buffer import (
|
| 15 |
+
PrioritizedReplayBuffer,
|
| 16 |
+
ObservationElement,
|
| 17 |
+
)
|
| 18 |
+
from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer
|
| 19 |
+
from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
|
| 20 |
+
from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer
|
| 21 |
+
|
| 22 |
+
from helpers import demo_loading_utils, utils
|
| 23 |
+
from helpers import observation_utils
|
| 24 |
+
from agents.baselines.vit_bc_lang.vit_bc_lang_agent import ViTBCLangAgent
|
| 25 |
+
from helpers.custom_rlbench_env import CustomRLBenchEnv
|
| 26 |
+
from helpers.network_utils import ViTLangAndFcsNet, ViT
|
| 27 |
+
from helpers.preprocess_agent import PreprocessAgent
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
from torch.multiprocessing import Process, Value, Manager
|
| 31 |
+
from helpers.clip.core.clip import build_model, load_clip, tokenize
|
| 32 |
+
|
| 33 |
+
LOW_DIM_SIZE = 4
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def create_replay(
|
| 37 |
+
batch_size: int,
|
| 38 |
+
timesteps: int,
|
| 39 |
+
prioritisation: bool,
|
| 40 |
+
task_uniform: bool,
|
| 41 |
+
save_dir: str,
|
| 42 |
+
cameras: list,
|
| 43 |
+
image_size=[128, 128],
|
| 44 |
+
replay_size=3e5,
|
| 45 |
+
):
|
| 46 |
+
lang_feat_dim = 1024
|
| 47 |
+
|
| 48 |
+
# low_dim_state
|
| 49 |
+
observation_elements = []
|
| 50 |
+
observation_elements.append(
|
| 51 |
+
ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
# rgb, depth, point cloud, intrinsics, extrinsics
|
| 55 |
+
for cname in cameras:
|
| 56 |
+
observation_elements.append(
|
| 57 |
+
ObservationElement(
|
| 58 |
+
"%s_rgb" % cname,
|
| 59 |
+
(
|
| 60 |
+
3,
|
| 61 |
+
*image_size,
|
| 62 |
+
),
|
| 63 |
+
np.float32,
|
| 64 |
+
)
|
| 65 |
+
)
|
| 66 |
+
observation_elements.append(
|
| 67 |
+
ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32)
|
| 68 |
+
) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
|
| 69 |
+
observation_elements.append(
|
| 70 |
+
ObservationElement(
|
| 71 |
+
"%s_camera_extrinsics" % cname,
|
| 72 |
+
(
|
| 73 |
+
4,
|
| 74 |
+
4,
|
| 75 |
+
),
|
| 76 |
+
np.float32,
|
| 77 |
+
)
|
| 78 |
+
)
|
| 79 |
+
observation_elements.append(
|
| 80 |
+
ObservationElement(
|
| 81 |
+
"%s_camera_intrinsics" % cname,
|
| 82 |
+
(
|
| 83 |
+
3,
|
| 84 |
+
3,
|
| 85 |
+
),
|
| 86 |
+
np.float32,
|
| 87 |
+
)
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
observation_elements.extend(
|
| 91 |
+
[
|
| 92 |
+
ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
|
| 93 |
+
ReplayElement("task", (), str),
|
| 94 |
+
ReplayElement(
|
| 95 |
+
"lang_goal", (1,), object
|
| 96 |
+
), # language goal string for debugging and visualization
|
| 97 |
+
]
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
extra_replay_elements = [
|
| 101 |
+
ReplayElement("demo", (), np.bool),
|
| 102 |
+
]
|
| 103 |
+
|
| 104 |
+
replay_buffer = TaskUniformReplayBuffer(
|
| 105 |
+
save_dir=save_dir,
|
| 106 |
+
batch_size=batch_size,
|
| 107 |
+
timesteps=timesteps,
|
| 108 |
+
replay_capacity=int(replay_size),
|
| 109 |
+
action_shape=(8,),
|
| 110 |
+
action_dtype=np.float32,
|
| 111 |
+
reward_shape=(),
|
| 112 |
+
reward_dtype=np.float32,
|
| 113 |
+
update_horizon=1,
|
| 114 |
+
observation_elements=observation_elements,
|
| 115 |
+
extra_replay_elements=extra_replay_elements,
|
| 116 |
+
)
|
| 117 |
+
return replay_buffer
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _get_action(obs_tp1: Observation):
|
| 121 |
+
quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
|
| 122 |
+
if quat[-1] < 0:
|
| 123 |
+
quat = -quat
|
| 124 |
+
return np.concatenate(
|
| 125 |
+
[obs_tp1.gripper_pose[:3], quat, [float(obs_tp1.gripper_open)]]
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _add_keypoints_to_replay(
|
| 130 |
+
cfg: DictConfig,
|
| 131 |
+
task: str,
|
| 132 |
+
replay: ReplayBuffer,
|
| 133 |
+
inital_obs: Observation,
|
| 134 |
+
demo: Demo,
|
| 135 |
+
episode_keypoints: List[int],
|
| 136 |
+
cameras: List[str],
|
| 137 |
+
description: str = "",
|
| 138 |
+
clip_model=None,
|
| 139 |
+
device="cpu",
|
| 140 |
+
):
|
| 141 |
+
prev_action = None
|
| 142 |
+
obs = inital_obs
|
| 143 |
+
all_actions = []
|
| 144 |
+
for k, keypoint in enumerate(episode_keypoints):
|
| 145 |
+
obs_tp1 = demo[keypoint]
|
| 146 |
+
action = _get_action(obs_tp1)
|
| 147 |
+
all_actions.append(action)
|
| 148 |
+
terminal = k == len(episode_keypoints) - 1
|
| 149 |
+
reward = float(terminal) if terminal else 0
|
| 150 |
+
|
| 151 |
+
obs_dict = observation_utils.extract_obs(
|
| 152 |
+
obs,
|
| 153 |
+
t=k,
|
| 154 |
+
prev_action=prev_action,
|
| 155 |
+
cameras=cameras,
|
| 156 |
+
episode_length=cfg.rlbench.episode_length,
|
| 157 |
+
robot_name=cfg.method.robot_name,
|
| 158 |
+
)
|
| 159 |
+
del obs_dict["ignore_collisions"]
|
| 160 |
+
tokens = tokenize([description]).numpy()
|
| 161 |
+
token_tensor = torch.from_numpy(tokens).to(device)
|
| 162 |
+
lang_feats, lang_embs = clip_model.encode_text_with_embeddings(token_tensor)
|
| 163 |
+
obs_dict["lang_goal_emb"] = lang_feats[0].float().detach().cpu().numpy()
|
| 164 |
+
|
| 165 |
+
final_obs = {
|
| 166 |
+
"task": task,
|
| 167 |
+
"lang_goal": np.array([description], dtype=object),
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
prev_action = np.copy(action)
|
| 171 |
+
others = {"demo": True}
|
| 172 |
+
others.update(final_obs)
|
| 173 |
+
others.update(obs_dict)
|
| 174 |
+
timeout = False
|
| 175 |
+
replay.add(action, reward, terminal, timeout, **others)
|
| 176 |
+
obs = obs_tp1 # Set the next obs
|
| 177 |
+
# Final step
|
| 178 |
+
obs_dict_tp1 = observation_utils.extract_obs(
|
| 179 |
+
obs_tp1,
|
| 180 |
+
t=k + 1,
|
| 181 |
+
prev_action=prev_action,
|
| 182 |
+
cameras=cameras,
|
| 183 |
+
episode_length=cfg.rlbench.episode_length,
|
| 184 |
+
robot_name=cfg.method.robot_name,
|
| 185 |
+
)
|
| 186 |
+
obs_dict_tp1["lang_goal_emb"] = lang_feats[0].float().detach().cpu().numpy()
|
| 187 |
+
# del obs_dict_tp1['lang_goal_tokens']
|
| 188 |
+
del obs_dict_tp1["ignore_collisions"]
|
| 189 |
+
# obs_dict_tp1['task'] = task
|
| 190 |
+
obs_dict_tp1.update(final_obs)
|
| 191 |
+
replay.add_final(**obs_dict_tp1)
|
| 192 |
+
return all_actions
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def fill_replay(
|
| 196 |
+
cfg: DictConfig,
|
| 197 |
+
obs_config: ObservationConfig,
|
| 198 |
+
rank: int,
|
| 199 |
+
replay: ReplayBuffer,
|
| 200 |
+
task: str,
|
| 201 |
+
num_demos: int,
|
| 202 |
+
demo_augmentation: bool,
|
| 203 |
+
demo_augmentation_every_n: int,
|
| 204 |
+
cameras: List[str],
|
| 205 |
+
clip_model=None,
|
| 206 |
+
device="cpu",
|
| 207 |
+
):
|
| 208 |
+
if clip_model is None:
|
| 209 |
+
model, _ = load_clip("RN50", jit=False, device=device)
|
| 210 |
+
clip_model = build_model(model.state_dict())
|
| 211 |
+
clip_model.to(device)
|
| 212 |
+
del model
|
| 213 |
+
|
| 214 |
+
logging.debug("Filling %s replay ..." % task)
|
| 215 |
+
all_actions = []
|
| 216 |
+
for d_idx in range(num_demos):
|
| 217 |
+
# load demo from disk
|
| 218 |
+
demo = rlbench_utils.get_stored_demos(
|
| 219 |
+
amount=1,
|
| 220 |
+
image_paths=False,
|
| 221 |
+
dataset_root=cfg.rlbench.demo_path,
|
| 222 |
+
variation_number=-1,
|
| 223 |
+
task_name=task,
|
| 224 |
+
obs_config=obs_config,
|
| 225 |
+
random_selection=False,
|
| 226 |
+
from_episode_number=d_idx,
|
| 227 |
+
)[0]
|
| 228 |
+
|
| 229 |
+
descs = demo._observations[0].misc["descriptions"]
|
| 230 |
+
|
| 231 |
+
# extract keypoints (a.k.a keyframes)
|
| 232 |
+
episode_keypoints = demo_loading_utils.keypoint_discovery(demo)
|
| 233 |
+
|
| 234 |
+
if rank == 0:
|
| 235 |
+
logging.info(
|
| 236 |
+
f"Loading Demo({d_idx}) - found {len(episode_keypoints)} keypoints - {task}"
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
for i in range(len(demo) - 1):
|
| 240 |
+
if not demo_augmentation and i > 0:
|
| 241 |
+
break
|
| 242 |
+
if i % demo_augmentation_every_n != 0:
|
| 243 |
+
continue
|
| 244 |
+
|
| 245 |
+
obs = demo[i]
|
| 246 |
+
desc = descs[0]
|
| 247 |
+
# if our starting point is past one of the keypoints, then remove it
|
| 248 |
+
while len(episode_keypoints) > 0 and i >= episode_keypoints[0]:
|
| 249 |
+
episode_keypoints = episode_keypoints[1:]
|
| 250 |
+
if len(episode_keypoints) == 0:
|
| 251 |
+
break
|
| 252 |
+
all_actions.extend(
|
| 253 |
+
_add_keypoints_to_replay(
|
| 254 |
+
cfg,
|
| 255 |
+
task,
|
| 256 |
+
replay,
|
| 257 |
+
obs,
|
| 258 |
+
demo,
|
| 259 |
+
episode_keypoints,
|
| 260 |
+
cameras,
|
| 261 |
+
description=desc,
|
| 262 |
+
clip_model=clip_model,
|
| 263 |
+
device=device,
|
| 264 |
+
)
|
| 265 |
+
)
|
| 266 |
+
logging.debug("Replay filled with demos.")
|
| 267 |
+
return all_actions
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def fill_multi_task_replay(
|
| 271 |
+
cfg: DictConfig,
|
| 272 |
+
obs_config: ObservationConfig,
|
| 273 |
+
rank: int,
|
| 274 |
+
replay: ReplayBuffer,
|
| 275 |
+
tasks: List[str],
|
| 276 |
+
num_demos: int,
|
| 277 |
+
demo_augmentation: bool,
|
| 278 |
+
demo_augmentation_every_n: int,
|
| 279 |
+
cameras: List[str],
|
| 280 |
+
clip_model=None,
|
| 281 |
+
):
|
| 282 |
+
manager = Manager()
|
| 283 |
+
store = manager.dict()
|
| 284 |
+
|
| 285 |
+
# create a MP dict for storing indicies
|
| 286 |
+
# TODO(mohit): this shouldn't be initialized here
|
| 287 |
+
del replay._task_idxs
|
| 288 |
+
task_idxs = manager.dict()
|
| 289 |
+
replay._task_idxs = task_idxs
|
| 290 |
+
replay._create_storage(store)
|
| 291 |
+
replay.add_count = Value("i", 0)
|
| 292 |
+
|
| 293 |
+
# fill replay buffer in parallel across tasks
|
| 294 |
+
max_parallel_processes = cfg.replay.max_parallel_processes
|
| 295 |
+
processes = []
|
| 296 |
+
n = np.arange(len(tasks))
|
| 297 |
+
split_n = utils.split_list(n, max_parallel_processes)
|
| 298 |
+
for split in split_n:
|
| 299 |
+
for e_idx, task_idx in enumerate(split):
|
| 300 |
+
task = tasks[int(task_idx)]
|
| 301 |
+
model_device = torch.device(
|
| 302 |
+
"cuda:%s" % (e_idx % torch.cuda.device_count())
|
| 303 |
+
if torch.cuda.is_available()
|
| 304 |
+
else "cpu"
|
| 305 |
+
)
|
| 306 |
+
p = Process(
|
| 307 |
+
target=fill_replay,
|
| 308 |
+
args=(
|
| 309 |
+
cfg,
|
| 310 |
+
obs_config,
|
| 311 |
+
rank,
|
| 312 |
+
replay,
|
| 313 |
+
task,
|
| 314 |
+
num_demos,
|
| 315 |
+
demo_augmentation,
|
| 316 |
+
demo_augmentation_every_n,
|
| 317 |
+
cameras,
|
| 318 |
+
clip_model,
|
| 319 |
+
model_device,
|
| 320 |
+
),
|
| 321 |
+
)
|
| 322 |
+
p.start()
|
| 323 |
+
processes.append(p)
|
| 324 |
+
|
| 325 |
+
for p in processes:
|
| 326 |
+
p.join()
|
| 327 |
+
|
| 328 |
+
logging.debug("Replay filled with multi demos.")
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def create_agent(cfg: DictConfig):
|
| 332 |
+
camera_name = cfg.rlbench.cameras
|
| 333 |
+
activation = cfg.method.activation
|
| 334 |
+
lr = cfg.method.lr
|
| 335 |
+
weight_decay = cfg.method.weight_decay
|
| 336 |
+
image_resolution = cfg.rlbench.camera_resolution
|
| 337 |
+
grad_clip = cfg.method.grad_clip
|
| 338 |
+
|
| 339 |
+
vit = ViT(
|
| 340 |
+
image_size=128,
|
| 341 |
+
patch_size=8,
|
| 342 |
+
num_classes=16,
|
| 343 |
+
dim=64,
|
| 344 |
+
depth=6,
|
| 345 |
+
heads=8,
|
| 346 |
+
mlp_dim=64,
|
| 347 |
+
dropout=0.1,
|
| 348 |
+
emb_dropout=0.1,
|
| 349 |
+
channels=6,
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
actor_net = ViTLangAndFcsNet(
|
| 353 |
+
vit=vit,
|
| 354 |
+
input_resolution=image_resolution,
|
| 355 |
+
filters=[64, 96, 128],
|
| 356 |
+
kernel_sizes=[1, 1, 1],
|
| 357 |
+
strides=[1, 1, 1],
|
| 358 |
+
norm=None,
|
| 359 |
+
activation=activation,
|
| 360 |
+
fc_layers=[128, 64, 3 + 4 + 1],
|
| 361 |
+
low_dim_state_len=LOW_DIM_SIZE,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
bc_agent = ViTBCLangAgent(
|
| 365 |
+
actor_network=actor_net,
|
| 366 |
+
camera_name=camera_name,
|
| 367 |
+
lr=lr,
|
| 368 |
+
weight_decay=weight_decay,
|
| 369 |
+
grad_clip=grad_clip,
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
return PreprocessAgent(pose_agent=bc_agent)
|
external/peract_bimanual/agents/baselines/vit_bc_lang/vit_bc_lang_agent.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from yarr.agents.agent import Agent, Summary, ActResult, ScalarSummary, HistogramSummary
|
| 10 |
+
|
| 11 |
+
from helpers import utils
|
| 12 |
+
from helpers.utils import stack_on_channel
|
| 13 |
+
|
| 14 |
+
from helpers.clip.core.clip import build_model, load_clip
|
| 15 |
+
|
| 16 |
+
NAME = "ViTBCLangAgent"
|
| 17 |
+
REPLAY_ALPHA = 0.7
|
| 18 |
+
REPLAY_BETA = 1.0
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Actor(nn.Module):
|
| 22 |
+
def __init__(self, actor_network: nn.Module):
|
| 23 |
+
super(Actor, self).__init__()
|
| 24 |
+
self._actor_network = copy.deepcopy(actor_network)
|
| 25 |
+
self._actor_network.build()
|
| 26 |
+
|
| 27 |
+
def forward(self, observations, robot_state, lang_goal_emb):
|
| 28 |
+
mu = self._actor_network(observations, robot_state, lang_goal_emb)
|
| 29 |
+
return mu
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ViTBCLangAgent(Agent):
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
actor_network: nn.Module,
|
| 36 |
+
camera_name: str,
|
| 37 |
+
lr: float = 0.01,
|
| 38 |
+
weight_decay: float = 1e-5,
|
| 39 |
+
grad_clip: float = 20.0,
|
| 40 |
+
):
|
| 41 |
+
self._camera_name = camera_name
|
| 42 |
+
self._actor_network = actor_network
|
| 43 |
+
self._lr = lr
|
| 44 |
+
self._weight_decay = weight_decay
|
| 45 |
+
self._grad_clip = grad_clip
|
| 46 |
+
|
| 47 |
+
def build(self, training: bool, device: torch.device = None):
|
| 48 |
+
if device is None:
|
| 49 |
+
device = torch.device("cpu")
|
| 50 |
+
self._actor = Actor(self._actor_network).to(device).train(training)
|
| 51 |
+
if training:
|
| 52 |
+
self._actor_optimizer = torch.optim.Adam(
|
| 53 |
+
self._actor.parameters(), lr=self._lr, weight_decay=self._weight_decay
|
| 54 |
+
)
|
| 55 |
+
logging.info(
|
| 56 |
+
"# Actor Params: %d"
|
| 57 |
+
% sum(p.numel() for p in self._actor.parameters() if p.requires_grad)
|
| 58 |
+
)
|
| 59 |
+
else:
|
| 60 |
+
for p in self._actor.parameters():
|
| 61 |
+
p.requires_grad = False
|
| 62 |
+
|
| 63 |
+
model, _ = load_clip("RN50", jit=False)
|
| 64 |
+
self._clip_rn50 = build_model(model.state_dict())
|
| 65 |
+
self._clip_rn50 = self._clip_rn50.float().to(device)
|
| 66 |
+
self._clip_rn50.eval()
|
| 67 |
+
del model
|
| 68 |
+
|
| 69 |
+
self._device = device
|
| 70 |
+
|
| 71 |
+
def _grad_step(self, loss, opt, model_params=None, clip=None):
|
| 72 |
+
opt.zero_grad()
|
| 73 |
+
loss.backward()
|
| 74 |
+
if clip is not None and model_params is not None:
|
| 75 |
+
nn.utils.clip_grad_value_(model_params, clip)
|
| 76 |
+
opt.step()
|
| 77 |
+
|
| 78 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 79 |
+
lang_goal_emb = replay_sample["lang_goal_emb"]
|
| 80 |
+
robot_state = replay_sample["low_dim_state"]
|
| 81 |
+
observations = [
|
| 82 |
+
replay_sample["%s_rgb" % self._camera_name],
|
| 83 |
+
replay_sample["%s_point_cloud" % self._camera_name],
|
| 84 |
+
]
|
| 85 |
+
mu = self._actor(observations, robot_state, lang_goal_emb)
|
| 86 |
+
loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA)
|
| 87 |
+
delta = F.mse_loss(mu, replay_sample["action"], reduction="none").mean(1)
|
| 88 |
+
loss = (delta * loss_weights).mean()
|
| 89 |
+
self._grad_step(
|
| 90 |
+
loss, self._actor_optimizer, self._actor.parameters(), self._grad_clip
|
| 91 |
+
)
|
| 92 |
+
self._summaries = {
|
| 93 |
+
"pi/loss": loss,
|
| 94 |
+
"pi/mu": mu.mean(),
|
| 95 |
+
}
|
| 96 |
+
return {"total_losses": loss}
|
| 97 |
+
|
| 98 |
+
def _normalize_quat(self, x):
|
| 99 |
+
return x / x.square().sum(dim=1).sqrt().unsqueeze(-1)
|
| 100 |
+
|
| 101 |
+
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
|
| 102 |
+
lang_goal_tokens = observation.get("lang_goal_tokens", None).long()
|
| 103 |
+
|
| 104 |
+
with torch.no_grad():
|
| 105 |
+
lang_goal_tokens = lang_goal_tokens.to(device=self._device)
|
| 106 |
+
lang_goal_emb, _ = self._clip_rn50.encode_text_with_embeddings(
|
| 107 |
+
lang_goal_tokens[0]
|
| 108 |
+
)
|
| 109 |
+
lang_goal_emb = lang_goal_emb.to(device=self._device)
|
| 110 |
+
|
| 111 |
+
observations = [
|
| 112 |
+
observation["%s_rgb" % self._camera_name][0].to(self._device),
|
| 113 |
+
observation["%s_point_cloud" % self._camera_name][0].to(self._device),
|
| 114 |
+
]
|
| 115 |
+
robot_state = observation["low_dim_state"][0].to(self._device)
|
| 116 |
+
|
| 117 |
+
mu = self._actor(observations, robot_state, lang_goal_emb)
|
| 118 |
+
mu = torch.cat([mu[:, :3], self._normalize_quat(mu[:, 3:7]), mu[:, 7:]], dim=-1)
|
| 119 |
+
ignore_collisions = torch.Tensor([1.0]).to(mu.device)
|
| 120 |
+
mu0 = torch.cat([mu[0], ignore_collisions])
|
| 121 |
+
return ActResult(mu0.detach().cpu())
|
| 122 |
+
|
| 123 |
+
def update_summaries(self) -> List[Summary]:
|
| 124 |
+
summaries = []
|
| 125 |
+
for n, v in self._summaries.items():
|
| 126 |
+
summaries.append(ScalarSummary("%s/%s" % (NAME, n), v))
|
| 127 |
+
|
| 128 |
+
for tag, param in self._actor.named_parameters():
|
| 129 |
+
summaries.append(
|
| 130 |
+
HistogramSummary("%s/gradient/%s" % (NAME, tag), param.grad)
|
| 131 |
+
)
|
| 132 |
+
summaries.append(HistogramSummary("%s/weight/%s" % (NAME, tag), param.data))
|
| 133 |
+
|
| 134 |
+
return summaries
|
| 135 |
+
|
| 136 |
+
def act_summaries(self) -> List[Summary]:
|
| 137 |
+
return []
|
| 138 |
+
|
| 139 |
+
def load_weights(self, savedir: str):
|
| 140 |
+
self._actor.load_state_dict(
|
| 141 |
+
torch.load(
|
| 142 |
+
os.path.join(savedir, "bc_actor.pt"), map_location=torch.device("cpu")
|
| 143 |
+
)
|
| 144 |
+
)
|
| 145 |
+
print("Loaded weights from %s" % savedir)
|
| 146 |
+
|
| 147 |
+
def save_weights(self, savedir: str):
|
| 148 |
+
torch.save(self._actor.state_dict(), os.path.join(savedir, "bc_actor.pt"))
|
external/peract_bimanual/agents/bimanual_peract/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
import agents.bimanual_peract.launch_utils
|
external/peract_bimanual/agents/bimanual_peract/launch_utils.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from ARM
|
| 2 |
+
# Source: https://github.com/stepjam/ARM
|
| 3 |
+
# License: https://github.com/stepjam/ARM/LICENSE
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from helpers.preprocess_agent import PreprocessAgent
|
| 7 |
+
|
| 8 |
+
from agents.bimanual_peract.perceiver_lang_io import PerceiverVoxelLangEncoder
|
| 9 |
+
from agents.bimanual_peract.qattention_peract_bc_agent import QAttentionPerActBCAgent
|
| 10 |
+
from agents.bimanual_peract.qattention_stack_agent import QAttentionStackAgent
|
| 11 |
+
|
| 12 |
+
from omegaconf import DictConfig
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def create_agent(cfg: DictConfig):
|
| 16 |
+
depth_0bounds = cfg.rlbench.scene_bounds
|
| 17 |
+
cam_resolution = cfg.rlbench.camera_resolution
|
| 18 |
+
|
| 19 |
+
num_rotation_classes = int(360.0 // cfg.method.rotation_resolution)
|
| 20 |
+
qattention_agents = []
|
| 21 |
+
for depth, vox_size in enumerate(cfg.method.voxel_sizes):
|
| 22 |
+
last = depth == len(cfg.method.voxel_sizes) - 1
|
| 23 |
+
perceiver_encoder = PerceiverVoxelLangEncoder(
|
| 24 |
+
depth=cfg.method.transformer_depth,
|
| 25 |
+
iterations=cfg.method.transformer_iterations,
|
| 26 |
+
voxel_size=vox_size,
|
| 27 |
+
initial_dim=3 + 3 + 1 + 3,
|
| 28 |
+
low_dim_size=cfg.method.low_dim_size,
|
| 29 |
+
layer=depth,
|
| 30 |
+
num_rotation_classes=num_rotation_classes if last else 0,
|
| 31 |
+
num_grip_classes=2 if last else 0,
|
| 32 |
+
num_collision_classes=2 if last else 0,
|
| 33 |
+
input_axis=3,
|
| 34 |
+
num_latents=cfg.method.num_latents,
|
| 35 |
+
latent_dim=cfg.method.latent_dim,
|
| 36 |
+
cross_heads=cfg.method.cross_heads,
|
| 37 |
+
latent_heads=cfg.method.latent_heads,
|
| 38 |
+
cross_dim_head=cfg.method.cross_dim_head,
|
| 39 |
+
latent_dim_head=cfg.method.latent_dim_head,
|
| 40 |
+
weight_tie_layers=False,
|
| 41 |
+
activation=cfg.method.activation,
|
| 42 |
+
pos_encoding_with_lang=cfg.method.pos_encoding_with_lang,
|
| 43 |
+
input_dropout=cfg.method.input_dropout,
|
| 44 |
+
attn_dropout=cfg.method.attn_dropout,
|
| 45 |
+
decoder_dropout=cfg.method.decoder_dropout,
|
| 46 |
+
lang_fusion_type=cfg.method.lang_fusion_type,
|
| 47 |
+
voxel_patch_size=cfg.method.voxel_patch_size,
|
| 48 |
+
voxel_patch_stride=cfg.method.voxel_patch_stride,
|
| 49 |
+
no_skip_connection=cfg.method.no_skip_connection,
|
| 50 |
+
no_perceiver=cfg.method.no_perceiver,
|
| 51 |
+
no_language=cfg.method.no_language,
|
| 52 |
+
final_dim=cfg.method.final_dim,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
qattention_agent = QAttentionPerActBCAgent(
|
| 56 |
+
layer=depth,
|
| 57 |
+
coordinate_bounds=depth_0bounds,
|
| 58 |
+
perceiver_encoder=perceiver_encoder,
|
| 59 |
+
camera_names=cfg.rlbench.cameras,
|
| 60 |
+
voxel_size=vox_size,
|
| 61 |
+
bounds_offset=cfg.method.bounds_offset[depth - 1] if depth > 0 else None,
|
| 62 |
+
image_crop_size=cfg.method.image_crop_size,
|
| 63 |
+
lr=cfg.method.lr,
|
| 64 |
+
training_iterations=cfg.framework.training_iterations,
|
| 65 |
+
lr_scheduler=cfg.method.lr_scheduler,
|
| 66 |
+
num_warmup_steps=cfg.method.num_warmup_steps,
|
| 67 |
+
trans_loss_weight=cfg.method.trans_loss_weight,
|
| 68 |
+
rot_loss_weight=cfg.method.rot_loss_weight,
|
| 69 |
+
grip_loss_weight=cfg.method.grip_loss_weight,
|
| 70 |
+
collision_loss_weight=cfg.method.collision_loss_weight,
|
| 71 |
+
include_low_dim_state=True,
|
| 72 |
+
image_resolution=cam_resolution,
|
| 73 |
+
batch_size=cfg.replay.batch_size,
|
| 74 |
+
voxel_feature_size=3,
|
| 75 |
+
lambda_weight_l2=cfg.method.lambda_weight_l2,
|
| 76 |
+
num_rotation_classes=num_rotation_classes,
|
| 77 |
+
rotation_resolution=cfg.method.rotation_resolution,
|
| 78 |
+
transform_augmentation=cfg.method.transform_augmentation.apply_se3,
|
| 79 |
+
transform_augmentation_xyz=cfg.method.transform_augmentation.aug_xyz,
|
| 80 |
+
transform_augmentation_rpy=cfg.method.transform_augmentation.aug_rpy,
|
| 81 |
+
transform_augmentation_rot_resolution=cfg.method.transform_augmentation.aug_rot_resolution,
|
| 82 |
+
optimizer_type=cfg.method.optimizer,
|
| 83 |
+
num_devices=cfg.ddp.num_devices,
|
| 84 |
+
)
|
| 85 |
+
qattention_agents.append(qattention_agent)
|
| 86 |
+
|
| 87 |
+
rotation_agent = QAttentionStackAgent(
|
| 88 |
+
qattention_agents=qattention_agents,
|
| 89 |
+
rotation_resolution=cfg.method.rotation_resolution,
|
| 90 |
+
camera_names=cfg.rlbench.cameras,
|
| 91 |
+
)
|
| 92 |
+
preprocess_agent = PreprocessAgent(pose_agent=rotation_agent)
|
| 93 |
+
return preprocess_agent
|
external/peract_bimanual/agents/bimanual_peract/perceiver_lang_io.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Perceiver IO implementation adpated for manipulation
|
| 2 |
+
# Source: https://github.com/lucidrains/perceiver-pytorch
|
| 3 |
+
# License: https://github.com/lucidrains/perceiver-pytorch/blob/main/LICENSE
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from einops import repeat
|
| 10 |
+
|
| 11 |
+
from perceiver_pytorch.perceiver_pytorch import cache_fn
|
| 12 |
+
from perceiver_pytorch.perceiver_pytorch import PreNorm, FeedForward, Attention
|
| 13 |
+
|
| 14 |
+
from helpers.network_utils import (
|
| 15 |
+
DenseBlock,
|
| 16 |
+
SpatialSoftmax3D,
|
| 17 |
+
Conv3DBlock,
|
| 18 |
+
Conv3DUpsampleBlock,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# PerceiverIO adapted for 6-DoF manipulation
|
| 23 |
+
class PerceiverVoxelLangEncoder(nn.Module):
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
depth, # number of self-attention layers
|
| 27 |
+
iterations, # number cross-attention iterations (PerceiverIO uses just 1)
|
| 28 |
+
voxel_size, # N voxels per side (size: N*N*N)
|
| 29 |
+
initial_dim, # 10 dimensions - dimension of the input sequence to be encoded
|
| 30 |
+
low_dim_size, # 4 dimensions - proprioception: {gripper_open, left_finger, right_finger, timestep}
|
| 31 |
+
layer=0,
|
| 32 |
+
num_rotation_classes=72, # 5 degree increments (5*72=360) for each of the 3-axis
|
| 33 |
+
num_grip_classes=2, # open or not open
|
| 34 |
+
num_collision_classes=2, # collisions allowed or not allowed
|
| 35 |
+
input_axis=3, # 3D tensors have 3 axes
|
| 36 |
+
num_latents=512, # number of latent vectors
|
| 37 |
+
im_channels=64, # intermediate channel size
|
| 38 |
+
latent_dim=512, # dimensions of latent vectors
|
| 39 |
+
cross_heads=1, # number of cross-attention heads
|
| 40 |
+
latent_heads=8, # number of latent heads
|
| 41 |
+
cross_dim_head=64,
|
| 42 |
+
latent_dim_head=64,
|
| 43 |
+
activation="relu",
|
| 44 |
+
weight_tie_layers=False,
|
| 45 |
+
pos_encoding_with_lang=True,
|
| 46 |
+
input_dropout=0.1,
|
| 47 |
+
attn_dropout=0.1,
|
| 48 |
+
decoder_dropout=0.0,
|
| 49 |
+
lang_fusion_type="seq",
|
| 50 |
+
voxel_patch_size=9,
|
| 51 |
+
voxel_patch_stride=8,
|
| 52 |
+
no_skip_connection=False,
|
| 53 |
+
no_perceiver=False,
|
| 54 |
+
no_language=False,
|
| 55 |
+
final_dim=64,
|
| 56 |
+
):
|
| 57 |
+
super().__init__()
|
| 58 |
+
self.depth = depth
|
| 59 |
+
self.layer = layer
|
| 60 |
+
self.init_dim = int(initial_dim)
|
| 61 |
+
self.iterations = iterations
|
| 62 |
+
self.input_axis = input_axis
|
| 63 |
+
self.voxel_size = voxel_size
|
| 64 |
+
self.low_dim_size = low_dim_size
|
| 65 |
+
self.im_channels = im_channels
|
| 66 |
+
self.pos_encoding_with_lang = pos_encoding_with_lang
|
| 67 |
+
self.lang_fusion_type = lang_fusion_type
|
| 68 |
+
self.voxel_patch_size = voxel_patch_size
|
| 69 |
+
self.voxel_patch_stride = voxel_patch_stride
|
| 70 |
+
self.num_rotation_classes = num_rotation_classes
|
| 71 |
+
self.num_grip_classes = num_grip_classes
|
| 72 |
+
self.num_collision_classes = num_collision_classes
|
| 73 |
+
self.final_dim = final_dim
|
| 74 |
+
self.input_dropout = input_dropout
|
| 75 |
+
self.attn_dropout = attn_dropout
|
| 76 |
+
self.decoder_dropout = decoder_dropout
|
| 77 |
+
self.no_skip_connection = no_skip_connection
|
| 78 |
+
self.no_perceiver = no_perceiver
|
| 79 |
+
self.no_language = no_language
|
| 80 |
+
|
| 81 |
+
# patchified input dimensions
|
| 82 |
+
spatial_size = voxel_size // self.voxel_patch_stride # 100/5 = 20
|
| 83 |
+
|
| 84 |
+
# 64 voxel features + 64 proprio features (+ 64 lang goal features if concattenated)
|
| 85 |
+
self.input_dim_before_seq = (
|
| 86 |
+
self.im_channels * 3
|
| 87 |
+
if self.lang_fusion_type == "concat"
|
| 88 |
+
else self.im_channels * 2
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# CLIP language feature dimensions
|
| 92 |
+
lang_feat_dim, lang_emb_dim, lang_max_seq_len = 1024, 512, 77
|
| 93 |
+
|
| 94 |
+
# learnable positional encoding
|
| 95 |
+
if self.pos_encoding_with_lang:
|
| 96 |
+
self.pos_encoding = nn.Parameter(
|
| 97 |
+
torch.randn(
|
| 98 |
+
1, lang_max_seq_len + spatial_size**3, self.input_dim_before_seq
|
| 99 |
+
)
|
| 100 |
+
)
|
| 101 |
+
else:
|
| 102 |
+
# assert self.lang_fusion_type == 'concat', 'Only concat is supported for pos encoding without lang.'
|
| 103 |
+
self.pos_encoding = nn.Parameter(
|
| 104 |
+
torch.randn(
|
| 105 |
+
1,
|
| 106 |
+
spatial_size,
|
| 107 |
+
spatial_size,
|
| 108 |
+
spatial_size,
|
| 109 |
+
self.input_dim_before_seq,
|
| 110 |
+
)
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
# voxel input preprocessing 1x1 conv encoder
|
| 114 |
+
self.input_preprocess = Conv3DBlock(
|
| 115 |
+
self.init_dim,
|
| 116 |
+
self.im_channels,
|
| 117 |
+
kernel_sizes=1,
|
| 118 |
+
strides=1,
|
| 119 |
+
norm=None,
|
| 120 |
+
activation=activation,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# patchify conv
|
| 124 |
+
self.patchify = Conv3DBlock(
|
| 125 |
+
self.input_preprocess.out_channels,
|
| 126 |
+
self.im_channels,
|
| 127 |
+
kernel_sizes=self.voxel_patch_size,
|
| 128 |
+
strides=self.voxel_patch_stride,
|
| 129 |
+
norm=None,
|
| 130 |
+
activation=activation,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# language preprocess
|
| 134 |
+
if self.lang_fusion_type == "concat":
|
| 135 |
+
self.lang_preprocess = nn.Linear(lang_feat_dim, self.im_channels)
|
| 136 |
+
elif self.lang_fusion_type == "seq":
|
| 137 |
+
self.lang_preprocess = nn.Linear(lang_emb_dim, self.im_channels * 2)
|
| 138 |
+
|
| 139 |
+
# proprioception
|
| 140 |
+
if self.low_dim_size > 0:
|
| 141 |
+
self.proprio_preprocess = DenseBlock(
|
| 142 |
+
self.low_dim_size,
|
| 143 |
+
self.im_channels,
|
| 144 |
+
norm=None,
|
| 145 |
+
activation=activation,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# pooling functions
|
| 149 |
+
self.local_maxp = nn.MaxPool3d(3, 2, padding=1)
|
| 150 |
+
self.global_maxp = nn.AdaptiveMaxPool3d(1)
|
| 151 |
+
|
| 152 |
+
# 1st 3D softmax
|
| 153 |
+
self.ss0 = SpatialSoftmax3D(
|
| 154 |
+
self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
|
| 155 |
+
)
|
| 156 |
+
flat_size = self.im_channels * 4
|
| 157 |
+
|
| 158 |
+
# latent vectors (that are randomly initialized)
|
| 159 |
+
self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
|
| 160 |
+
|
| 161 |
+
# encoder cross attention
|
| 162 |
+
self.cross_attend_blocks = nn.ModuleList(
|
| 163 |
+
[
|
| 164 |
+
PreNorm(
|
| 165 |
+
latent_dim,
|
| 166 |
+
Attention(
|
| 167 |
+
latent_dim,
|
| 168 |
+
self.input_dim_before_seq,
|
| 169 |
+
heads=cross_heads,
|
| 170 |
+
dim_head=cross_dim_head,
|
| 171 |
+
dropout=input_dropout,
|
| 172 |
+
),
|
| 173 |
+
context_dim=self.input_dim_before_seq,
|
| 174 |
+
),
|
| 175 |
+
PreNorm(latent_dim, FeedForward(latent_dim)),
|
| 176 |
+
PreNorm(latent_dim, FeedForward(latent_dim)),
|
| 177 |
+
]
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
get_latent_attn = lambda: PreNorm(
|
| 181 |
+
latent_dim,
|
| 182 |
+
Attention(
|
| 183 |
+
latent_dim,
|
| 184 |
+
heads=latent_heads,
|
| 185 |
+
dim_head=latent_dim_head,
|
| 186 |
+
dropout=attn_dropout,
|
| 187 |
+
),
|
| 188 |
+
)
|
| 189 |
+
get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim))
|
| 190 |
+
get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))
|
| 191 |
+
|
| 192 |
+
# self attention layers
|
| 193 |
+
self.layers = nn.ModuleList([])
|
| 194 |
+
cache_args = {"_cache": weight_tie_layers}
|
| 195 |
+
|
| 196 |
+
for i in range(depth):
|
| 197 |
+
self.layers.append(
|
| 198 |
+
nn.ModuleList(
|
| 199 |
+
[
|
| 200 |
+
get_latent_attn(**cache_args),
|
| 201 |
+
get_latent_ff(**cache_args),
|
| 202 |
+
get_latent_attn(**cache_args),
|
| 203 |
+
get_latent_ff(**cache_args),
|
| 204 |
+
]
|
| 205 |
+
)
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
self.combined_latent_attn = get_latent_attn(**cache_args)
|
| 209 |
+
self.combined_latent_ff = get_latent_ff(**cache_args)
|
| 210 |
+
|
| 211 |
+
# decoder cross attention
|
| 212 |
+
self.decoder_cross_attn_right = PreNorm(
|
| 213 |
+
self.input_dim_before_seq,
|
| 214 |
+
Attention(
|
| 215 |
+
self.input_dim_before_seq,
|
| 216 |
+
latent_dim,
|
| 217 |
+
heads=cross_heads,
|
| 218 |
+
dim_head=cross_dim_head,
|
| 219 |
+
dropout=decoder_dropout,
|
| 220 |
+
),
|
| 221 |
+
context_dim=latent_dim,
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
self.decoder_cross_attn_left = PreNorm(
|
| 225 |
+
self.input_dim_before_seq,
|
| 226 |
+
Attention(
|
| 227 |
+
self.input_dim_before_seq,
|
| 228 |
+
latent_dim,
|
| 229 |
+
heads=cross_heads,
|
| 230 |
+
dim_head=cross_dim_head,
|
| 231 |
+
dropout=decoder_dropout,
|
| 232 |
+
),
|
| 233 |
+
context_dim=latent_dim,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
# upsample conv
|
| 237 |
+
self.up0 = Conv3DUpsampleBlock(
|
| 238 |
+
self.input_dim_before_seq,
|
| 239 |
+
self.final_dim,
|
| 240 |
+
kernel_sizes=self.voxel_patch_size,
|
| 241 |
+
strides=self.voxel_patch_stride,
|
| 242 |
+
norm=None,
|
| 243 |
+
activation=activation,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
# 2nd 3D softmax
|
| 247 |
+
self.ss1 = SpatialSoftmax3D(
|
| 248 |
+
spatial_size, spatial_size, spatial_size, self.input_dim_before_seq
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
flat_size += self.input_dim_before_seq * 4
|
| 252 |
+
|
| 253 |
+
# final 3D softmax
|
| 254 |
+
self.final = Conv3DBlock(
|
| 255 |
+
self.im_channels
|
| 256 |
+
if (self.no_perceiver or self.no_skip_connection)
|
| 257 |
+
else self.im_channels * 2,
|
| 258 |
+
self.im_channels,
|
| 259 |
+
kernel_sizes=3,
|
| 260 |
+
strides=1,
|
| 261 |
+
norm=None,
|
| 262 |
+
activation=activation,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
self.right_trans_decoder = Conv3DBlock(
|
| 266 |
+
self.final_dim,
|
| 267 |
+
1,
|
| 268 |
+
kernel_sizes=3,
|
| 269 |
+
strides=1,
|
| 270 |
+
norm=None,
|
| 271 |
+
activation=None,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
self.left_trans_decoder = Conv3DBlock(
|
| 275 |
+
self.final_dim,
|
| 276 |
+
1,
|
| 277 |
+
kernel_sizes=3,
|
| 278 |
+
strides=1,
|
| 279 |
+
norm=None,
|
| 280 |
+
activation=None,
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
# rotation, gripper, and collision MLP layers
|
| 284 |
+
if self.num_rotation_classes > 0:
|
| 285 |
+
self.ss_final = SpatialSoftmax3D(
|
| 286 |
+
self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
flat_size += self.im_channels * 4
|
| 290 |
+
|
| 291 |
+
self.right_dense0 = DenseBlock(flat_size, 256, None, activation)
|
| 292 |
+
self.right_dense1 = DenseBlock(256, self.final_dim, None, activation)
|
| 293 |
+
|
| 294 |
+
self.left_dense0 = DenseBlock(flat_size, 256, None, activation)
|
| 295 |
+
self.left_dense1 = DenseBlock(256, self.final_dim, None, activation)
|
| 296 |
+
|
| 297 |
+
self.right_rot_grip_collision_ff = DenseBlock(
|
| 298 |
+
self.final_dim,
|
| 299 |
+
self.num_rotation_classes * 3
|
| 300 |
+
+ self.num_grip_classes
|
| 301 |
+
+ self.num_collision_classes,
|
| 302 |
+
None,
|
| 303 |
+
None,
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
self.left_rot_grip_collision_ff = DenseBlock(
|
| 307 |
+
self.final_dim,
|
| 308 |
+
self.num_rotation_classes * 3
|
| 309 |
+
+ self.num_grip_classes
|
| 310 |
+
+ self.num_collision_classes,
|
| 311 |
+
None,
|
| 312 |
+
None,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
def encode_text(self, x):
|
| 316 |
+
with torch.no_grad():
|
| 317 |
+
text_feat, text_emb = self._clip_rn50.encode_text_with_embeddings(x)
|
| 318 |
+
|
| 319 |
+
text_feat = text_feat.detach()
|
| 320 |
+
text_emb = text_emb.detach()
|
| 321 |
+
text_mask = torch.where(x == 0, x, 1) # [1, max_token_len]
|
| 322 |
+
return text_feat, text_emb
|
| 323 |
+
|
| 324 |
+
def forward(
|
| 325 |
+
self,
|
| 326 |
+
ins,
|
| 327 |
+
proprio,
|
| 328 |
+
lang_goal_emb,
|
| 329 |
+
lang_token_embs,
|
| 330 |
+
prev_layer_voxel_grid,
|
| 331 |
+
bounds,
|
| 332 |
+
prev_layer_bounds,
|
| 333 |
+
mask=None,
|
| 334 |
+
):
|
| 335 |
+
# preprocess input
|
| 336 |
+
d0 = self.input_preprocess(ins) # [B,10,100,100,100] -> [B,64,100,100,100]
|
| 337 |
+
|
| 338 |
+
# aggregated features from 1st softmax and maxpool for MLP decoders
|
| 339 |
+
feats = [self.ss0(d0.contiguous()), self.global_maxp(d0).view(ins.shape[0], -1)]
|
| 340 |
+
|
| 341 |
+
# patchify input (5x5x5 patches)
|
| 342 |
+
ins = self.patchify(d0) # [B,64,100,100,100] -> [B,64,20,20,20]
|
| 343 |
+
|
| 344 |
+
b, c, d, h, w, device = *ins.shape, ins.device
|
| 345 |
+
axis = [d, h, w]
|
| 346 |
+
assert (
|
| 347 |
+
len(axis) == self.input_axis
|
| 348 |
+
), "input must have the same number of axis as input_axis"
|
| 349 |
+
|
| 350 |
+
# concat proprio
|
| 351 |
+
if self.low_dim_size > 0:
|
| 352 |
+
p = self.proprio_preprocess(proprio) # [B,4] -> [B,64]
|
| 353 |
+
p = p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
|
| 354 |
+
ins = torch.cat([ins, p], dim=1) # [B,128,20,20,20]
|
| 355 |
+
|
| 356 |
+
# language ablation
|
| 357 |
+
if self.no_language:
|
| 358 |
+
lang_goal_emb = torch.zeros_like(lang_goal_emb)
|
| 359 |
+
lang_token_embs = torch.zeros_like(lang_token_embs)
|
| 360 |
+
|
| 361 |
+
# option 1: tile and concat lang goal to input
|
| 362 |
+
if self.lang_fusion_type == "concat":
|
| 363 |
+
lang_emb = lang_goal_emb
|
| 364 |
+
lang_emb = lang_emb.to(dtype=ins.dtype)
|
| 365 |
+
l = self.lang_preprocess(lang_emb)
|
| 366 |
+
l = l.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
|
| 367 |
+
ins = torch.cat([ins, l], dim=1)
|
| 368 |
+
|
| 369 |
+
# channel last
|
| 370 |
+
ins = rearrange(ins, "b d ... -> b ... d") # [B,20,20,20,128]
|
| 371 |
+
|
| 372 |
+
# add pos encoding to grid
|
| 373 |
+
if not self.pos_encoding_with_lang:
|
| 374 |
+
ins = ins + self.pos_encoding
|
| 375 |
+
|
| 376 |
+
######################## NOTE #############################
|
| 377 |
+
# NOTE: If you add positional encodings ^here the lang embs
|
| 378 |
+
# won't have positional encodings. I accidently forgot
|
| 379 |
+
# to turn this off for all the experiments in the paper.
|
| 380 |
+
# So I guess those models were using language embs
|
| 381 |
+
# as a bag of words :( But it doesn't matter much for
|
| 382 |
+
# RLBench tasks since we don't test for novel instructions
|
| 383 |
+
# at test time anyway. The recommend way is to add
|
| 384 |
+
# positional encodings to the final input sequence
|
| 385 |
+
# fed into the Perceiver Transformer, as done below
|
| 386 |
+
# (and also in the Colab tutorial).
|
| 387 |
+
###########################################################
|
| 388 |
+
|
| 389 |
+
# concat to channels of and flatten axis
|
| 390 |
+
queries_orig_shape = ins.shape
|
| 391 |
+
|
| 392 |
+
# rearrange input to be channel last
|
| 393 |
+
ins = rearrange(ins, "b ... d -> b (...) d") # [B,8000,128]
|
| 394 |
+
ins_wo_prev_layers = ins
|
| 395 |
+
|
| 396 |
+
# option 2: add lang token embs as a sequence
|
| 397 |
+
if self.lang_fusion_type == "seq":
|
| 398 |
+
l = self.lang_preprocess(lang_token_embs) # [B,77,1024] -> [B,77,128]
|
| 399 |
+
ins = torch.cat((l, ins), dim=1) # [B,8077,128]
|
| 400 |
+
|
| 401 |
+
# add pos encoding to language + flattened grid (the recommended way)
|
| 402 |
+
if self.pos_encoding_with_lang:
|
| 403 |
+
ins = ins + self.pos_encoding
|
| 404 |
+
|
| 405 |
+
# batchify latents
|
| 406 |
+
x = repeat(self.latents, "n d -> b n d", b=b)
|
| 407 |
+
|
| 408 |
+
cross_attn, cross_ff_right, cross_ff_left = self.cross_attend_blocks
|
| 409 |
+
|
| 410 |
+
for it in range(self.iterations):
|
| 411 |
+
# encoder cross attention
|
| 412 |
+
x = cross_attn(x, context=ins, mask=mask) + x
|
| 413 |
+
|
| 414 |
+
# x.size() = [1, num_latents, latent_dim]
|
| 415 |
+
x_right, x_left = x.chunk(2, dim=1)
|
| 416 |
+
|
| 417 |
+
x_right = cross_ff_right(x_right) + x_right
|
| 418 |
+
x_left = cross_ff_left(x_left) + x_left
|
| 419 |
+
|
| 420 |
+
# self-attention layers
|
| 421 |
+
for (
|
| 422 |
+
self_attn_right,
|
| 423 |
+
self_ff_right,
|
| 424 |
+
self_attn_left,
|
| 425 |
+
self_ff_left,
|
| 426 |
+
) in self.layers:
|
| 427 |
+
x_right = self_attn_right(x_right) + x_right
|
| 428 |
+
x_right = self_ff_right(x_right) + x_right
|
| 429 |
+
|
| 430 |
+
x_left = self_attn_left(x_left) + x_left
|
| 431 |
+
x_left = self_ff_left(x_left) + x_left
|
| 432 |
+
|
| 433 |
+
x = torch.concat([x_right, x_left], dim=1)
|
| 434 |
+
x = self.combined_latent_attn(x) + x
|
| 435 |
+
x = self.combined_latent_ff(x) + x
|
| 436 |
+
|
| 437 |
+
x_right, x_left = x.chunk(2, dim=1)
|
| 438 |
+
|
| 439 |
+
# decoder cross attention
|
| 440 |
+
latents_right = self.decoder_cross_attn_right(ins, context=x_right)
|
| 441 |
+
latents_left = self.decoder_cross_attn_left(ins, context=x_left)
|
| 442 |
+
|
| 443 |
+
# crop out the language part of the output sequence
|
| 444 |
+
if self.lang_fusion_type == "seq":
|
| 445 |
+
latents_right = latents_right[:, l.shape[1] :]
|
| 446 |
+
latents_left = latents_left[:, l.shape[1] :]
|
| 447 |
+
|
| 448 |
+
# reshape back to voxel grid
|
| 449 |
+
latents_right = latents_right.view(
|
| 450 |
+
b, *queries_orig_shape[1:-1], latents_right.shape[-1]
|
| 451 |
+
) # [B,20,20,20,64]
|
| 452 |
+
latents_right = rearrange(
|
| 453 |
+
latents_right, "b ... d -> b d ..."
|
| 454 |
+
) # [B,64,20,20,20]
|
| 455 |
+
|
| 456 |
+
# reshape back to voxel grid
|
| 457 |
+
latents_left = latents_left.view(
|
| 458 |
+
b, *queries_orig_shape[1:-1], latents_left.shape[-1]
|
| 459 |
+
) # [B,20,20,20,64]
|
| 460 |
+
latents_left = rearrange(latents_left, "b ... d -> b d ...") # [B,64,20,20,20]
|
| 461 |
+
|
| 462 |
+
# aggregated features from 2nd softmax and maxpool for MLP decoders
|
| 463 |
+
|
| 464 |
+
feats_right = feats.copy()
|
| 465 |
+
feats_left = feats
|
| 466 |
+
|
| 467 |
+
feats_right.extend(
|
| 468 |
+
[
|
| 469 |
+
self.ss1(latents_right.contiguous()),
|
| 470 |
+
self.global_maxp(latents_right).view(b, -1),
|
| 471 |
+
]
|
| 472 |
+
)
|
| 473 |
+
feats_left.extend(
|
| 474 |
+
[
|
| 475 |
+
self.ss1(latents_left.contiguous()),
|
| 476 |
+
self.global_maxp(latents_left).view(b, -1),
|
| 477 |
+
]
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
# upsample
|
| 481 |
+
u0_right = self.up0(latents_right)
|
| 482 |
+
u0_left = self.up0(latents_left)
|
| 483 |
+
|
| 484 |
+
# ablations
|
| 485 |
+
if self.no_skip_connection:
|
| 486 |
+
u_right = self.final(u0_right)
|
| 487 |
+
u_left = self.final(u0_left)
|
| 488 |
+
elif self.no_perceiver:
|
| 489 |
+
u_right = self.final(d0)
|
| 490 |
+
u_left = self.final(d0)
|
| 491 |
+
else:
|
| 492 |
+
u_right = self.final(torch.cat([d0, u0_right], dim=1))
|
| 493 |
+
u_left = self.final(torch.cat([d0, u0_left], dim=1))
|
| 494 |
+
|
| 495 |
+
# translation decoder
|
| 496 |
+
right_trans = self.right_trans_decoder(u_right)
|
| 497 |
+
left_trans = self.left_trans_decoder(u_left)
|
| 498 |
+
|
| 499 |
+
# rotation, gripper, and collision MLPs
|
| 500 |
+
rot_and_grip_out = None
|
| 501 |
+
if self.num_rotation_classes > 0:
|
| 502 |
+
feats_right.extend(
|
| 503 |
+
[
|
| 504 |
+
self.ss_final(u_right.contiguous()),
|
| 505 |
+
self.global_maxp(u_right).view(b, -1),
|
| 506 |
+
]
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
right_dense0 = self.right_dense0(torch.cat(feats_right, dim=1))
|
| 510 |
+
right_dense1 = self.right_dense1(right_dense0) # [B,72*3+2+2]
|
| 511 |
+
|
| 512 |
+
right_rot_and_grip_collision_out = self.right_rot_grip_collision_ff(
|
| 513 |
+
right_dense1
|
| 514 |
+
)
|
| 515 |
+
right_rot_and_grip_out = right_rot_and_grip_collision_out[
|
| 516 |
+
:, : -self.num_collision_classes
|
| 517 |
+
]
|
| 518 |
+
right_collision_out = right_rot_and_grip_collision_out[
|
| 519 |
+
:, -self.num_collision_classes :
|
| 520 |
+
]
|
| 521 |
+
|
| 522 |
+
feats_left.extend(
|
| 523 |
+
[
|
| 524 |
+
self.ss_final(u_left.contiguous()),
|
| 525 |
+
self.global_maxp(u_left).view(b, -1),
|
| 526 |
+
]
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
left_dense0 = self.left_dense0(torch.cat(feats_left, dim=1))
|
| 530 |
+
left_dense1 = self.left_dense1(left_dense0) # [B,72*3+2+2]
|
| 531 |
+
|
| 532 |
+
left_rot_and_grip_collision_out = self.left_rot_grip_collision_ff(
|
| 533 |
+
left_dense1
|
| 534 |
+
)
|
| 535 |
+
left_rot_and_grip_out = left_rot_and_grip_collision_out[
|
| 536 |
+
:, : -self.num_collision_classes
|
| 537 |
+
]
|
| 538 |
+
left_collision_out = left_rot_and_grip_collision_out[
|
| 539 |
+
:, -self.num_collision_classes :
|
| 540 |
+
]
|
| 541 |
+
|
| 542 |
+
return (
|
| 543 |
+
right_trans,
|
| 544 |
+
right_rot_and_grip_out,
|
| 545 |
+
right_collision_out,
|
| 546 |
+
left_trans,
|
| 547 |
+
left_rot_and_grip_out,
|
| 548 |
+
left_collision_out,
|
| 549 |
+
)
|
external/peract_bimanual/agents/bimanual_peract/qattention_peract_bc_agent.py
ADDED
|
@@ -0,0 +1,1063 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
from pytorch3d import transforms as torch3d_tf
|
| 12 |
+
from yarr.agents.agent import (
|
| 13 |
+
Agent,
|
| 14 |
+
ActResult,
|
| 15 |
+
ScalarSummary,
|
| 16 |
+
HistogramSummary,
|
| 17 |
+
ImageSummary,
|
| 18 |
+
Summary,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
from helpers import utils
|
| 22 |
+
from helpers.utils import visualise_voxel, stack_on_channel
|
| 23 |
+
from voxel.voxel_grid import VoxelGrid
|
| 24 |
+
from voxel.augmentation import apply_se3_augmentation
|
| 25 |
+
from einops import rearrange
|
| 26 |
+
from helpers.clip.core.clip import build_model, load_clip
|
| 27 |
+
|
| 28 |
+
import transformers
|
| 29 |
+
from helpers.optim.lamb import Lamb
|
| 30 |
+
|
| 31 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 32 |
+
|
| 33 |
+
NAME = "QAttentionAgent"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class QFunction(nn.Module):
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
perceiver_encoder: nn.Module,
|
| 40 |
+
voxelizer: VoxelGrid,
|
| 41 |
+
bounds_offset: float,
|
| 42 |
+
rotation_resolution: float,
|
| 43 |
+
device,
|
| 44 |
+
training,
|
| 45 |
+
):
|
| 46 |
+
super(QFunction, self).__init__()
|
| 47 |
+
self._rotation_resolution = rotation_resolution
|
| 48 |
+
self._voxelizer = voxelizer
|
| 49 |
+
self._bounds_offset = bounds_offset
|
| 50 |
+
self._qnet = perceiver_encoder.to(device)
|
| 51 |
+
|
| 52 |
+
# distributed training
|
| 53 |
+
if training:
|
| 54 |
+
self._qnet = DDP(self._qnet, device_ids=[device])
|
| 55 |
+
|
| 56 |
+
def _argmax_3d(self, tensor_orig):
|
| 57 |
+
b, c, d, h, w = tensor_orig.shape # c will be one
|
| 58 |
+
idxs = tensor_orig.view(b, c, -1).argmax(-1)
|
| 59 |
+
indices = torch.cat([((idxs // h) // d), (idxs // h) % w, idxs % w], 1)
|
| 60 |
+
return indices
|
| 61 |
+
|
| 62 |
+
def choose_highest_action(self, q_trans, q_rot_grip, q_collision):
|
| 63 |
+
coords = self._argmax_3d(q_trans)
|
| 64 |
+
rot_and_grip_indicies = None
|
| 65 |
+
ignore_collision = None
|
| 66 |
+
if q_rot_grip is not None:
|
| 67 |
+
q_rot = torch.stack(
|
| 68 |
+
torch.split(
|
| 69 |
+
q_rot_grip[:, :-2], int(360 // self._rotation_resolution), dim=1
|
| 70 |
+
),
|
| 71 |
+
dim=1,
|
| 72 |
+
)
|
| 73 |
+
rot_and_grip_indicies = torch.cat(
|
| 74 |
+
[
|
| 75 |
+
q_rot[:, 0:1].argmax(-1),
|
| 76 |
+
q_rot[:, 1:2].argmax(-1),
|
| 77 |
+
q_rot[:, 2:3].argmax(-1),
|
| 78 |
+
q_rot_grip[:, -2:].argmax(-1, keepdim=True),
|
| 79 |
+
],
|
| 80 |
+
-1,
|
| 81 |
+
)
|
| 82 |
+
ignore_collision = q_collision[:, -2:].argmax(-1, keepdim=True)
|
| 83 |
+
return coords, rot_and_grip_indicies, ignore_collision
|
| 84 |
+
|
| 85 |
+
def forward(
|
| 86 |
+
self,
|
| 87 |
+
rgb_pcd,
|
| 88 |
+
proprio,
|
| 89 |
+
pcd,
|
| 90 |
+
lang_goal_emb,
|
| 91 |
+
lang_token_embs,
|
| 92 |
+
bounds=None,
|
| 93 |
+
prev_bounds=None,
|
| 94 |
+
prev_layer_voxel_grid=None,
|
| 95 |
+
):
|
| 96 |
+
# rgb_pcd will be list of list (list of [rgb, pcd])
|
| 97 |
+
b = rgb_pcd[0][0].shape[0]
|
| 98 |
+
pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(b, -1, 3) for p in pcd], 1)
|
| 99 |
+
|
| 100 |
+
# flatten RGBs and Pointclouds
|
| 101 |
+
rgb = [rp[0] for rp in rgb_pcd]
|
| 102 |
+
feat_size = rgb[0].shape[1]
|
| 103 |
+
flat_imag_features = torch.cat(
|
| 104 |
+
[p.permute(0, 2, 3, 1).reshape(b, -1, feat_size) for p in rgb], 1
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# construct voxel grid
|
| 108 |
+
voxel_grid = self._voxelizer.coords_to_bounding_voxel_grid(
|
| 109 |
+
pcd_flat, coord_features=flat_imag_features, coord_bounds=bounds
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# swap to channels fist
|
| 113 |
+
voxel_grid = voxel_grid.permute(0, 4, 1, 2, 3).detach()
|
| 114 |
+
|
| 115 |
+
# batch bounds if necessary
|
| 116 |
+
if bounds.shape[0] != b:
|
| 117 |
+
bounds = bounds.repeat(b, 1)
|
| 118 |
+
|
| 119 |
+
# forward pass
|
| 120 |
+
split_pred = self._qnet(
|
| 121 |
+
voxel_grid,
|
| 122 |
+
proprio,
|
| 123 |
+
lang_goal_emb,
|
| 124 |
+
lang_token_embs,
|
| 125 |
+
prev_layer_voxel_grid,
|
| 126 |
+
bounds,
|
| 127 |
+
prev_bounds,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
return split_pred, voxel_grid
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class QAttentionPerActBCAgent(Agent):
|
| 134 |
+
def __init__(
|
| 135 |
+
self,
|
| 136 |
+
layer: int,
|
| 137 |
+
coordinate_bounds: list,
|
| 138 |
+
perceiver_encoder: nn.Module,
|
| 139 |
+
camera_names: list,
|
| 140 |
+
batch_size: int,
|
| 141 |
+
voxel_size: int,
|
| 142 |
+
bounds_offset: float,
|
| 143 |
+
voxel_feature_size: int,
|
| 144 |
+
image_crop_size: int,
|
| 145 |
+
num_rotation_classes: int,
|
| 146 |
+
rotation_resolution: float,
|
| 147 |
+
lr: float = 0.0001,
|
| 148 |
+
lr_scheduler: bool = False,
|
| 149 |
+
training_iterations: int = 100000,
|
| 150 |
+
num_warmup_steps: int = 20000,
|
| 151 |
+
trans_loss_weight: float = 1.0,
|
| 152 |
+
rot_loss_weight: float = 1.0,
|
| 153 |
+
grip_loss_weight: float = 1.0,
|
| 154 |
+
collision_loss_weight: float = 1.0,
|
| 155 |
+
include_low_dim_state: bool = False,
|
| 156 |
+
image_resolution: list = None,
|
| 157 |
+
lambda_weight_l2: float = 0.0,
|
| 158 |
+
transform_augmentation: bool = True,
|
| 159 |
+
transform_augmentation_xyz: list = [0.0, 0.0, 0.0],
|
| 160 |
+
transform_augmentation_rpy: list = [0.0, 0.0, 180.0],
|
| 161 |
+
transform_augmentation_rot_resolution: int = 5,
|
| 162 |
+
optimizer_type: str = "adam",
|
| 163 |
+
num_devices: int = 1,
|
| 164 |
+
):
|
| 165 |
+
self._layer = layer
|
| 166 |
+
self._coordinate_bounds = coordinate_bounds
|
| 167 |
+
self._perceiver_encoder = perceiver_encoder
|
| 168 |
+
self._voxel_feature_size = voxel_feature_size
|
| 169 |
+
self._bounds_offset = bounds_offset
|
| 170 |
+
self._image_crop_size = image_crop_size
|
| 171 |
+
self._lr = lr
|
| 172 |
+
self._lr_scheduler = lr_scheduler
|
| 173 |
+
self._training_iterations = training_iterations
|
| 174 |
+
self._num_warmup_steps = num_warmup_steps
|
| 175 |
+
self._trans_loss_weight = trans_loss_weight
|
| 176 |
+
self._rot_loss_weight = rot_loss_weight
|
| 177 |
+
self._grip_loss_weight = grip_loss_weight
|
| 178 |
+
self._collision_loss_weight = collision_loss_weight
|
| 179 |
+
self._include_low_dim_state = include_low_dim_state
|
| 180 |
+
self._image_resolution = image_resolution or [128, 128]
|
| 181 |
+
self._voxel_size = voxel_size
|
| 182 |
+
self._camera_names = camera_names
|
| 183 |
+
self._num_cameras = len(camera_names)
|
| 184 |
+
self._batch_size = batch_size
|
| 185 |
+
self._lambda_weight_l2 = lambda_weight_l2
|
| 186 |
+
self._transform_augmentation = transform_augmentation
|
| 187 |
+
self._transform_augmentation_xyz = torch.from_numpy(
|
| 188 |
+
np.array(transform_augmentation_xyz)
|
| 189 |
+
)
|
| 190 |
+
self._transform_augmentation_rpy = transform_augmentation_rpy
|
| 191 |
+
self._transform_augmentation_rot_resolution = (
|
| 192 |
+
transform_augmentation_rot_resolution
|
| 193 |
+
)
|
| 194 |
+
self._optimizer_type = optimizer_type
|
| 195 |
+
self._num_devices = num_devices
|
| 196 |
+
self._num_rotation_classes = num_rotation_classes
|
| 197 |
+
self._rotation_resolution = rotation_resolution
|
| 198 |
+
|
| 199 |
+
self._cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
|
| 200 |
+
self._name = NAME + "_layer" + str(self._layer)
|
| 201 |
+
|
| 202 |
+
def build(self, training: bool, device: torch.device = None):
|
| 203 |
+
self._training = training
|
| 204 |
+
|
| 205 |
+
if device is None:
|
| 206 |
+
device = torch.device("cpu")
|
| 207 |
+
|
| 208 |
+
self._device = device
|
| 209 |
+
|
| 210 |
+
self._voxelizer = VoxelGrid(
|
| 211 |
+
coord_bounds=self._coordinate_bounds,
|
| 212 |
+
voxel_size=self._voxel_size,
|
| 213 |
+
device=device,
|
| 214 |
+
batch_size=self._batch_size if training else 1,
|
| 215 |
+
feature_size=self._voxel_feature_size,
|
| 216 |
+
max_num_coords=np.prod(self._image_resolution) * self._num_cameras,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
self._q = (
|
| 220 |
+
QFunction(
|
| 221 |
+
self._perceiver_encoder,
|
| 222 |
+
self._voxelizer,
|
| 223 |
+
self._bounds_offset,
|
| 224 |
+
self._rotation_resolution,
|
| 225 |
+
device,
|
| 226 |
+
training,
|
| 227 |
+
)
|
| 228 |
+
.to(device)
|
| 229 |
+
.train(training)
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
grid_for_crop = (
|
| 233 |
+
torch.arange(0, self._image_crop_size, device=device)
|
| 234 |
+
.unsqueeze(0)
|
| 235 |
+
.repeat(self._image_crop_size, 1)
|
| 236 |
+
.unsqueeze(-1)
|
| 237 |
+
)
|
| 238 |
+
self._grid_for_crop = torch.cat(
|
| 239 |
+
[grid_for_crop.transpose(1, 0), grid_for_crop], dim=2
|
| 240 |
+
).unsqueeze(0)
|
| 241 |
+
|
| 242 |
+
self._coordinate_bounds = torch.tensor(
|
| 243 |
+
self._coordinate_bounds, device=device
|
| 244 |
+
).unsqueeze(0)
|
| 245 |
+
|
| 246 |
+
if self._training:
|
| 247 |
+
# optimizer
|
| 248 |
+
if self._optimizer_type == "lamb":
|
| 249 |
+
self._optimizer = Lamb(
|
| 250 |
+
self._q.parameters(),
|
| 251 |
+
lr=self._lr,
|
| 252 |
+
weight_decay=self._lambda_weight_l2,
|
| 253 |
+
betas=(0.9, 0.999),
|
| 254 |
+
adam=False,
|
| 255 |
+
)
|
| 256 |
+
elif self._optimizer_type == "adam":
|
| 257 |
+
self._optimizer = torch.optim.Adam(
|
| 258 |
+
self._q.parameters(),
|
| 259 |
+
lr=self._lr,
|
| 260 |
+
weight_decay=self._lambda_weight_l2,
|
| 261 |
+
)
|
| 262 |
+
else:
|
| 263 |
+
raise Exception("Unknown optimizer type")
|
| 264 |
+
|
| 265 |
+
# learning rate scheduler
|
| 266 |
+
if self._lr_scheduler:
|
| 267 |
+
self._scheduler = (
|
| 268 |
+
transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
|
| 269 |
+
self._optimizer,
|
| 270 |
+
num_warmup_steps=self._num_warmup_steps,
|
| 271 |
+
num_training_steps=self._training_iterations,
|
| 272 |
+
num_cycles=self._training_iterations // 10000,
|
| 273 |
+
)
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
# one-hot zero tensors
|
| 277 |
+
self._action_trans_one_hot_zeros = torch.zeros(
|
| 278 |
+
(
|
| 279 |
+
self._batch_size,
|
| 280 |
+
1,
|
| 281 |
+
self._voxel_size,
|
| 282 |
+
self._voxel_size,
|
| 283 |
+
self._voxel_size,
|
| 284 |
+
),
|
| 285 |
+
dtype=int,
|
| 286 |
+
device=device,
|
| 287 |
+
)
|
| 288 |
+
self._action_rot_x_one_hot_zeros = torch.zeros(
|
| 289 |
+
(self._batch_size, self._num_rotation_classes), dtype=int, device=device
|
| 290 |
+
)
|
| 291 |
+
self._action_rot_y_one_hot_zeros = torch.zeros(
|
| 292 |
+
(self._batch_size, self._num_rotation_classes), dtype=int, device=device
|
| 293 |
+
)
|
| 294 |
+
self._action_rot_z_one_hot_zeros = torch.zeros(
|
| 295 |
+
(self._batch_size, self._num_rotation_classes), dtype=int, device=device
|
| 296 |
+
)
|
| 297 |
+
self._action_grip_one_hot_zeros = torch.zeros(
|
| 298 |
+
(self._batch_size, 2), dtype=int, device=device
|
| 299 |
+
)
|
| 300 |
+
self._action_ignore_collisions_one_hot_zeros = torch.zeros(
|
| 301 |
+
(self._batch_size, 2), dtype=int, device=device
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# print total params
|
| 305 |
+
logging.info(
|
| 306 |
+
"# Q Params: %d"
|
| 307 |
+
% sum(
|
| 308 |
+
p.numel()
|
| 309 |
+
for name, p in self._q.named_parameters()
|
| 310 |
+
if p.requires_grad and "clip" not in name
|
| 311 |
+
)
|
| 312 |
+
)
|
| 313 |
+
else:
|
| 314 |
+
for param in self._q.parameters():
|
| 315 |
+
param.requires_grad = False
|
| 316 |
+
|
| 317 |
+
# load CLIP for encoding language goals during evaluation
|
| 318 |
+
model, _ = load_clip("RN50", jit=False)
|
| 319 |
+
self._clip_rn50 = build_model(model.state_dict())
|
| 320 |
+
self._clip_rn50 = self._clip_rn50.float().to(device)
|
| 321 |
+
self._clip_rn50.eval()
|
| 322 |
+
del model
|
| 323 |
+
|
| 324 |
+
self._voxelizer.to(device)
|
| 325 |
+
self._q.to(device)
|
| 326 |
+
|
| 327 |
+
def _extract_crop(self, pixel_action, observation):
|
| 328 |
+
# Pixel action will now be (B, 2)
|
| 329 |
+
# observation = stack_on_channel(observation)
|
| 330 |
+
h = observation.shape[-1]
|
| 331 |
+
top_left_corner = torch.clamp(
|
| 332 |
+
pixel_action - self._image_crop_size // 2, 0, h - self._image_crop_size
|
| 333 |
+
)
|
| 334 |
+
grid = self._grid_for_crop + top_left_corner.unsqueeze(1)
|
| 335 |
+
grid = ((grid / float(h)) * 2.0) - 1.0 # between -1 and 1
|
| 336 |
+
# Used for cropping the images across a batch
|
| 337 |
+
# swap fro y x, to x, y
|
| 338 |
+
grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1)
|
| 339 |
+
crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True)
|
| 340 |
+
return crop
|
| 341 |
+
|
| 342 |
+
def _preprocess_inputs(self, replay_sample):
|
| 343 |
+
obs = []
|
| 344 |
+
pcds = []
|
| 345 |
+
self._crop_summary = []
|
| 346 |
+
for n in self._camera_names:
|
| 347 |
+
rgb = replay_sample["%s_rgb" % n]
|
| 348 |
+
pcd = replay_sample["%s_point_cloud" % n]
|
| 349 |
+
|
| 350 |
+
obs.append([rgb, pcd])
|
| 351 |
+
pcds.append(pcd)
|
| 352 |
+
return obs, pcds
|
| 353 |
+
|
| 354 |
+
def _act_preprocess_inputs(self, observation):
|
| 355 |
+
obs, pcds = [], []
|
| 356 |
+
for n in self._camera_names:
|
| 357 |
+
rgb = observation["%s_rgb" % n]
|
| 358 |
+
pcd = observation["%s_point_cloud" % n]
|
| 359 |
+
|
| 360 |
+
obs.append([rgb, pcd])
|
| 361 |
+
pcds.append(pcd)
|
| 362 |
+
return obs, pcds
|
| 363 |
+
|
| 364 |
+
def _get_value_from_voxel_index(self, q, voxel_idx):
|
| 365 |
+
b, c, d, h, w = q.shape
|
| 366 |
+
q_trans_flat = q.view(b, c, d * h * w)
|
| 367 |
+
flat_indicies = (
|
| 368 |
+
voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2]
|
| 369 |
+
)[:, None].int()
|
| 370 |
+
highest_idxs = flat_indicies.unsqueeze(-1).repeat(1, c, 1)
|
| 371 |
+
chosen_voxel_values = q_trans_flat.gather(2, highest_idxs)[
|
| 372 |
+
..., 0
|
| 373 |
+
] # (B, trans + rot + grip)
|
| 374 |
+
return chosen_voxel_values
|
| 375 |
+
|
| 376 |
+
def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx):
|
| 377 |
+
q_rot = torch.stack(
|
| 378 |
+
torch.split(
|
| 379 |
+
rot_grip_q[:, :-2], int(360 // self._rotation_resolution), dim=1
|
| 380 |
+
),
|
| 381 |
+
dim=1,
|
| 382 |
+
) # B, 3, 72
|
| 383 |
+
q_grip = rot_grip_q[:, -2:]
|
| 384 |
+
rot_and_grip_values = torch.cat(
|
| 385 |
+
[
|
| 386 |
+
q_rot[:, 0].gather(1, rot_and_grip_idx[:, 0:1]),
|
| 387 |
+
q_rot[:, 1].gather(1, rot_and_grip_idx[:, 1:2]),
|
| 388 |
+
q_rot[:, 2].gather(1, rot_and_grip_idx[:, 2:3]),
|
| 389 |
+
q_grip.gather(1, rot_and_grip_idx[:, 3:4]),
|
| 390 |
+
],
|
| 391 |
+
-1,
|
| 392 |
+
)
|
| 393 |
+
return rot_and_grip_values
|
| 394 |
+
|
| 395 |
+
def _celoss(self, pred, labels):
|
| 396 |
+
return self._cross_entropy_loss(pred, labels.argmax(-1))
|
| 397 |
+
|
| 398 |
+
def _softmax_q_trans(self, q):
|
| 399 |
+
q_shape = q.shape
|
| 400 |
+
return F.softmax(q.reshape(q_shape[0], -1), dim=1).reshape(q_shape)
|
| 401 |
+
|
| 402 |
+
def _softmax_q_rot_grip(self, q_rot_grip):
|
| 403 |
+
q_rot_x_flat = q_rot_grip[
|
| 404 |
+
:, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
|
| 405 |
+
]
|
| 406 |
+
q_rot_y_flat = q_rot_grip[
|
| 407 |
+
:, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
|
| 408 |
+
]
|
| 409 |
+
q_rot_z_flat = q_rot_grip[
|
| 410 |
+
:, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
|
| 411 |
+
]
|
| 412 |
+
q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
|
| 413 |
+
|
| 414 |
+
q_rot_x_flat_softmax = F.softmax(q_rot_x_flat, dim=1)
|
| 415 |
+
q_rot_y_flat_softmax = F.softmax(q_rot_y_flat, dim=1)
|
| 416 |
+
q_rot_z_flat_softmax = F.softmax(q_rot_z_flat, dim=1)
|
| 417 |
+
q_grip_flat_softmax = F.softmax(q_grip_flat, dim=1)
|
| 418 |
+
|
| 419 |
+
return torch.cat(
|
| 420 |
+
[
|
| 421 |
+
q_rot_x_flat_softmax,
|
| 422 |
+
q_rot_y_flat_softmax,
|
| 423 |
+
q_rot_z_flat_softmax,
|
| 424 |
+
q_grip_flat_softmax,
|
| 425 |
+
],
|
| 426 |
+
dim=1,
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
def _softmax_ignore_collision(self, q_collision):
|
| 430 |
+
q_collision_softmax = F.softmax(q_collision, dim=1)
|
| 431 |
+
return q_collision_softmax
|
| 432 |
+
|
| 433 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 434 |
+
right_action_trans = replay_sample["right_trans_action_indicies"][
|
| 435 |
+
:, self._layer * 3 : self._layer * 3 + 3
|
| 436 |
+
].int()
|
| 437 |
+
right_action_rot_grip = replay_sample["right_rot_grip_action_indicies"].int()
|
| 438 |
+
right_action_gripper_pose = replay_sample["right_gripper_pose"]
|
| 439 |
+
right_action_ignore_collisions = replay_sample["right_ignore_collisions"].int()
|
| 440 |
+
|
| 441 |
+
left_action_trans = replay_sample["left_trans_action_indicies"][
|
| 442 |
+
:, self._layer * 3 : self._layer * 3 + 3
|
| 443 |
+
].int()
|
| 444 |
+
left_action_rot_grip = replay_sample["left_rot_grip_action_indicies"].int()
|
| 445 |
+
left_action_gripper_pose = replay_sample["left_gripper_pose"]
|
| 446 |
+
left_action_ignore_collisions = replay_sample["left_ignore_collisions"].int()
|
| 447 |
+
|
| 448 |
+
lang_goal_emb = replay_sample["lang_goal_emb"].float()
|
| 449 |
+
lang_token_embs = replay_sample["lang_token_embs"].float()
|
| 450 |
+
prev_layer_voxel_grid = replay_sample.get("prev_layer_voxel_grid", None)
|
| 451 |
+
prev_layer_bounds = replay_sample.get("prev_layer_bounds", None)
|
| 452 |
+
device = self._device
|
| 453 |
+
|
| 454 |
+
bounds = self._coordinate_bounds.to(device)
|
| 455 |
+
if self._layer > 0:
|
| 456 |
+
right_cp = replay_sample[
|
| 457 |
+
"right_attention_coordinate_layer_%d" % (self._layer - 1)
|
| 458 |
+
]
|
| 459 |
+
|
| 460 |
+
left_cp = replay_sample[
|
| 461 |
+
"left_attention_coordinate_layer_%d" % (self._layer - 1)
|
| 462 |
+
]
|
| 463 |
+
|
| 464 |
+
right_bounds = torch.cat(
|
| 465 |
+
[right_cp - self._bounds_offset, right_cp + self._bounds_offset], dim=1
|
| 466 |
+
)
|
| 467 |
+
left_bounds = torch.cat(
|
| 468 |
+
[left_cp - self._bounds_offset, left_cp + self._bounds_offset], dim=1
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
else:
|
| 472 |
+
right_bounds = bounds
|
| 473 |
+
left_bounds = bounds
|
| 474 |
+
|
| 475 |
+
right_proprio = None
|
| 476 |
+
left_proprio = None
|
| 477 |
+
if self._include_low_dim_state:
|
| 478 |
+
right_proprio = replay_sample["right_low_dim_state"]
|
| 479 |
+
left_proprio = replay_sample["left_low_dim_state"]
|
| 480 |
+
|
| 481 |
+
# ..TODO::
|
| 482 |
+
# Can we add the coordinates of both robots?
|
| 483 |
+
#
|
| 484 |
+
|
| 485 |
+
obs, pcd = self._preprocess_inputs(replay_sample)
|
| 486 |
+
|
| 487 |
+
# batch size
|
| 488 |
+
bs = pcd[0].shape[0]
|
| 489 |
+
|
| 490 |
+
# We can move the point cloud w.r.t to the other robot's cooridinate system
|
| 491 |
+
# similar to apply_se3_augmentation
|
| 492 |
+
#
|
| 493 |
+
|
| 494 |
+
# SE(3) augmentation of point clouds and actions
|
| 495 |
+
if self._transform_augmentation:
|
| 496 |
+
from voxel import augmentation
|
| 497 |
+
|
| 498 |
+
(
|
| 499 |
+
right_action_trans,
|
| 500 |
+
right_action_rot_grip,
|
| 501 |
+
left_action_trans,
|
| 502 |
+
left_action_rot_grip,
|
| 503 |
+
pcd,
|
| 504 |
+
) = augmentation.bimanual_apply_se3_augmentation(
|
| 505 |
+
pcd,
|
| 506 |
+
right_action_gripper_pose,
|
| 507 |
+
right_action_trans,
|
| 508 |
+
right_action_rot_grip,
|
| 509 |
+
left_action_gripper_pose,
|
| 510 |
+
left_action_trans,
|
| 511 |
+
left_action_rot_grip,
|
| 512 |
+
bounds,
|
| 513 |
+
self._layer,
|
| 514 |
+
self._transform_augmentation_xyz,
|
| 515 |
+
self._transform_augmentation_rpy,
|
| 516 |
+
self._transform_augmentation_rot_resolution,
|
| 517 |
+
self._voxel_size,
|
| 518 |
+
self._rotation_resolution,
|
| 519 |
+
self._device,
|
| 520 |
+
)
|
| 521 |
+
else:
|
| 522 |
+
right_action_trans = right_action_trans.int()
|
| 523 |
+
left_action_trans = left_action_trans.int()
|
| 524 |
+
|
| 525 |
+
proprio = torch.cat((right_proprio, left_proprio), dim=1)
|
| 526 |
+
|
| 527 |
+
right_action = (
|
| 528 |
+
right_action_trans,
|
| 529 |
+
right_action_rot_grip,
|
| 530 |
+
right_action_ignore_collisions,
|
| 531 |
+
)
|
| 532 |
+
left_action = (
|
| 533 |
+
left_action_trans,
|
| 534 |
+
left_action_rot_grip,
|
| 535 |
+
left_action_ignore_collisions,
|
| 536 |
+
)
|
| 537 |
+
# forward pass
|
| 538 |
+
q, voxel_grid = self._q(
|
| 539 |
+
obs,
|
| 540 |
+
proprio,
|
| 541 |
+
pcd,
|
| 542 |
+
lang_goal_emb,
|
| 543 |
+
lang_token_embs,
|
| 544 |
+
bounds,
|
| 545 |
+
prev_layer_bounds,
|
| 546 |
+
prev_layer_voxel_grid,
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
(
|
| 550 |
+
right_q_trans,
|
| 551 |
+
right_q_rot_grip,
|
| 552 |
+
right_q_collision,
|
| 553 |
+
left_q_trans,
|
| 554 |
+
left_q_rot_grip,
|
| 555 |
+
left_q_collision,
|
| 556 |
+
) = q
|
| 557 |
+
|
| 558 |
+
# argmax to choose best action
|
| 559 |
+
(
|
| 560 |
+
right_coords,
|
| 561 |
+
right_rot_and_grip_indicies,
|
| 562 |
+
right_ignore_collision_indicies,
|
| 563 |
+
) = self._q.choose_highest_action(
|
| 564 |
+
right_q_trans, right_q_rot_grip, right_q_collision
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
(
|
| 568 |
+
left_coords,
|
| 569 |
+
left_rot_and_grip_indicies,
|
| 570 |
+
left_ignore_collision_indicies,
|
| 571 |
+
) = self._q.choose_highest_action(
|
| 572 |
+
left_q_trans, left_q_rot_grip, left_q_collision
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
(
|
| 576 |
+
right_q_trans_loss,
|
| 577 |
+
right_q_rot_loss,
|
| 578 |
+
right_q_grip_loss,
|
| 579 |
+
right_q_collision_loss,
|
| 580 |
+
) = (0.0, 0.0, 0.0, 0.0)
|
| 581 |
+
left_q_trans_loss, left_q_rot_loss, left_q_grip_loss, left_q_collision_loss = (
|
| 582 |
+
0.0,
|
| 583 |
+
0.0,
|
| 584 |
+
0.0,
|
| 585 |
+
0.0,
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
# translation one-hot
|
| 589 |
+
right_action_trans_one_hot = self._action_trans_one_hot_zeros.clone().detach()
|
| 590 |
+
left_action_trans_one_hot = self._action_trans_one_hot_zeros.clone().detach()
|
| 591 |
+
for b in range(bs):
|
| 592 |
+
right_gt_coord = right_action_trans[b, :].int()
|
| 593 |
+
right_action_trans_one_hot[
|
| 594 |
+
b, :, right_gt_coord[0], right_gt_coord[1], right_gt_coord[2]
|
| 595 |
+
] = 1
|
| 596 |
+
left_gt_coord = left_action_trans[b, :].int()
|
| 597 |
+
left_action_trans_one_hot[
|
| 598 |
+
b, :, left_gt_coord[0], left_gt_coord[1], left_gt_coord[2]
|
| 599 |
+
] = 1
|
| 600 |
+
|
| 601 |
+
# translation loss
|
| 602 |
+
right_q_trans_flat = right_q_trans.view(bs, -1)
|
| 603 |
+
right_action_trans_one_hot_flat = right_action_trans_one_hot.view(bs, -1)
|
| 604 |
+
right_q_trans_loss = self._celoss(
|
| 605 |
+
right_q_trans_flat, right_action_trans_one_hot_flat
|
| 606 |
+
)
|
| 607 |
+
left_q_trans_flat = left_q_trans.view(bs, -1)
|
| 608 |
+
left_action_trans_one_hot_flat = left_action_trans_one_hot.view(bs, -1)
|
| 609 |
+
left_q_trans_loss = self._celoss(
|
| 610 |
+
left_q_trans_flat, left_action_trans_one_hot_flat
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
q_trans_loss = right_q_trans_loss + left_q_trans_loss
|
| 614 |
+
|
| 615 |
+
with_rot_and_grip = (
|
| 616 |
+
len(right_rot_and_grip_indicies) > 0 and len(left_rot_and_grip_indicies) > 0
|
| 617 |
+
)
|
| 618 |
+
if with_rot_and_grip:
|
| 619 |
+
# rotation, gripper, and collision one-hots
|
| 620 |
+
right_action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone()
|
| 621 |
+
right_action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone()
|
| 622 |
+
right_action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone()
|
| 623 |
+
right_action_grip_one_hot = self._action_grip_one_hot_zeros.clone()
|
| 624 |
+
right_action_ignore_collisions_one_hot = (
|
| 625 |
+
self._action_ignore_collisions_one_hot_zeros.clone()
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
left_action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone()
|
| 629 |
+
left_action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone()
|
| 630 |
+
left_action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone()
|
| 631 |
+
left_action_grip_one_hot = self._action_grip_one_hot_zeros.clone()
|
| 632 |
+
left_action_ignore_collisions_one_hot = (
|
| 633 |
+
self._action_ignore_collisions_one_hot_zeros.clone()
|
| 634 |
+
)
|
| 635 |
+
|
| 636 |
+
for b in range(bs):
|
| 637 |
+
right_gt_rot_grip = right_action_rot_grip[b, :].int()
|
| 638 |
+
right_action_rot_x_one_hot[b, right_gt_rot_grip[0]] = 1
|
| 639 |
+
right_action_rot_y_one_hot[b, right_gt_rot_grip[1]] = 1
|
| 640 |
+
right_action_rot_z_one_hot[b, right_gt_rot_grip[2]] = 1
|
| 641 |
+
right_action_grip_one_hot[b, right_gt_rot_grip[3]] = 1
|
| 642 |
+
|
| 643 |
+
right_gt_ignore_collisions = right_action_ignore_collisions[b, :].int()
|
| 644 |
+
right_action_ignore_collisions_one_hot[
|
| 645 |
+
b, right_gt_ignore_collisions[0]
|
| 646 |
+
] = 1
|
| 647 |
+
|
| 648 |
+
left_gt_rot_grip = left_action_rot_grip[b, :].int()
|
| 649 |
+
left_action_rot_x_one_hot[b, left_gt_rot_grip[0]] = 1
|
| 650 |
+
left_action_rot_y_one_hot[b, left_gt_rot_grip[1]] = 1
|
| 651 |
+
left_action_rot_z_one_hot[b, left_gt_rot_grip[2]] = 1
|
| 652 |
+
left_action_grip_one_hot[b, left_gt_rot_grip[3]] = 1
|
| 653 |
+
|
| 654 |
+
left_gt_ignore_collisions = left_action_ignore_collisions[b, :].int()
|
| 655 |
+
left_action_ignore_collisions_one_hot[
|
| 656 |
+
b, left_gt_ignore_collisions[0]
|
| 657 |
+
] = 1
|
| 658 |
+
|
| 659 |
+
# flatten predictions
|
| 660 |
+
right_q_rot_x_flat = right_q_rot_grip[
|
| 661 |
+
:, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
|
| 662 |
+
]
|
| 663 |
+
right_q_rot_y_flat = right_q_rot_grip[
|
| 664 |
+
:, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
|
| 665 |
+
]
|
| 666 |
+
right_q_rot_z_flat = right_q_rot_grip[
|
| 667 |
+
:, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
|
| 668 |
+
]
|
| 669 |
+
right_q_grip_flat = right_q_rot_grip[:, 3 * self._num_rotation_classes :]
|
| 670 |
+
right_q_ignore_collisions_flat = right_q_collision
|
| 671 |
+
|
| 672 |
+
left_q_rot_x_flat = left_q_rot_grip[
|
| 673 |
+
:, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
|
| 674 |
+
]
|
| 675 |
+
left_q_rot_y_flat = left_q_rot_grip[
|
| 676 |
+
:, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
|
| 677 |
+
]
|
| 678 |
+
left_q_rot_z_flat = left_q_rot_grip[
|
| 679 |
+
:, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
|
| 680 |
+
]
|
| 681 |
+
left_q_grip_flat = left_q_rot_grip[:, 3 * self._num_rotation_classes :]
|
| 682 |
+
left_q_ignore_collisions_flat = left_q_collision
|
| 683 |
+
|
| 684 |
+
# rotation loss
|
| 685 |
+
right_q_rot_loss += self._celoss(
|
| 686 |
+
right_q_rot_x_flat, right_action_rot_x_one_hot
|
| 687 |
+
)
|
| 688 |
+
right_q_rot_loss += self._celoss(
|
| 689 |
+
right_q_rot_y_flat, right_action_rot_y_one_hot
|
| 690 |
+
)
|
| 691 |
+
right_q_rot_loss += self._celoss(
|
| 692 |
+
right_q_rot_z_flat, right_action_rot_z_one_hot
|
| 693 |
+
)
|
| 694 |
+
|
| 695 |
+
left_q_rot_loss += self._celoss(
|
| 696 |
+
left_q_rot_x_flat, left_action_rot_x_one_hot
|
| 697 |
+
)
|
| 698 |
+
left_q_rot_loss += self._celoss(
|
| 699 |
+
left_q_rot_y_flat, left_action_rot_y_one_hot
|
| 700 |
+
)
|
| 701 |
+
left_q_rot_loss += self._celoss(
|
| 702 |
+
left_q_rot_z_flat, left_action_rot_z_one_hot
|
| 703 |
+
)
|
| 704 |
+
|
| 705 |
+
# gripper loss
|
| 706 |
+
right_q_grip_loss += self._celoss(
|
| 707 |
+
right_q_grip_flat, right_action_grip_one_hot
|
| 708 |
+
)
|
| 709 |
+
left_q_grip_loss += self._celoss(left_q_grip_flat, left_action_grip_one_hot)
|
| 710 |
+
|
| 711 |
+
# collision loss
|
| 712 |
+
right_q_collision_loss += self._celoss(
|
| 713 |
+
right_q_ignore_collisions_flat, right_action_ignore_collisions_one_hot
|
| 714 |
+
)
|
| 715 |
+
left_q_collision_loss += self._celoss(
|
| 716 |
+
left_q_ignore_collisions_flat, left_action_ignore_collisions_one_hot
|
| 717 |
+
)
|
| 718 |
+
|
| 719 |
+
q_trans_loss = right_q_trans_loss + left_q_trans_loss
|
| 720 |
+
q_rot_loss = right_q_rot_loss + left_q_rot_loss
|
| 721 |
+
q_grip_loss = right_q_grip_loss + left_q_grip_loss
|
| 722 |
+
q_collision_loss = right_q_collision_loss + left_q_collision_loss
|
| 723 |
+
|
| 724 |
+
combined_losses = (
|
| 725 |
+
(q_trans_loss * self._trans_loss_weight)
|
| 726 |
+
+ (q_rot_loss * self._rot_loss_weight)
|
| 727 |
+
+ (q_grip_loss * self._grip_loss_weight)
|
| 728 |
+
+ (q_collision_loss * self._collision_loss_weight)
|
| 729 |
+
)
|
| 730 |
+
total_loss = combined_losses.mean()
|
| 731 |
+
|
| 732 |
+
self._optimizer.zero_grad()
|
| 733 |
+
total_loss.backward()
|
| 734 |
+
self._optimizer.step()
|
| 735 |
+
|
| 736 |
+
self._summaries = {
|
| 737 |
+
"losses/total_loss": total_loss,
|
| 738 |
+
"losses/trans_loss": q_trans_loss.mean(),
|
| 739 |
+
"losses/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
|
| 740 |
+
"losses/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
|
| 741 |
+
"losses/right/trans_loss": q_trans_loss.mean(),
|
| 742 |
+
"losses/right/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
|
| 743 |
+
"losses/right/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
|
| 744 |
+
"losses/right/collision_loss": q_collision_loss.mean()
|
| 745 |
+
if with_rot_and_grip
|
| 746 |
+
else 0.0,
|
| 747 |
+
"losses/left/trans_loss": q_trans_loss.mean(),
|
| 748 |
+
"losses/left/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
|
| 749 |
+
"losses/left/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
|
| 750 |
+
"losses/left/collision_loss": q_collision_loss.mean()
|
| 751 |
+
if with_rot_and_grip
|
| 752 |
+
else 0.0,
|
| 753 |
+
"losses/collision_loss": q_collision_loss.mean()
|
| 754 |
+
if with_rot_and_grip
|
| 755 |
+
else 0.0,
|
| 756 |
+
}
|
| 757 |
+
|
| 758 |
+
if self._lr_scheduler:
|
| 759 |
+
self._scheduler.step()
|
| 760 |
+
self._summaries["learning_rate"] = self._scheduler.get_last_lr()[0]
|
| 761 |
+
|
| 762 |
+
self._vis_voxel_grid = voxel_grid[0]
|
| 763 |
+
self._right_vis_translation_qvalue = self._softmax_q_trans(right_q_trans[0])
|
| 764 |
+
self._right_vis_max_coordinate = right_coords[0]
|
| 765 |
+
self._right_vis_gt_coordinate = right_action_trans[0]
|
| 766 |
+
|
| 767 |
+
self._left_vis_translation_qvalue = self._softmax_q_trans(left_q_trans[0])
|
| 768 |
+
self._left_vis_max_coordinate = left_coords[0]
|
| 769 |
+
self._left_vis_gt_coordinate = left_action_trans[0]
|
| 770 |
+
|
| 771 |
+
# Note: PerAct doesn't use multi-layer voxel grids like C2FARM
|
| 772 |
+
# stack prev_layer_voxel_grid(s) from previous layers into a list
|
| 773 |
+
if prev_layer_voxel_grid is None:
|
| 774 |
+
prev_layer_voxel_grid = [voxel_grid]
|
| 775 |
+
else:
|
| 776 |
+
prev_layer_voxel_grid = prev_layer_voxel_grid + [voxel_grid]
|
| 777 |
+
|
| 778 |
+
# stack prev_layer_bound(s) from previous layers into a list
|
| 779 |
+
if prev_layer_bounds is None:
|
| 780 |
+
prev_layer_bounds = [self._coordinate_bounds.repeat(bs, 1)]
|
| 781 |
+
else:
|
| 782 |
+
prev_layer_bounds = prev_layer_bounds + [bounds]
|
| 783 |
+
|
| 784 |
+
return {
|
| 785 |
+
"total_loss": total_loss,
|
| 786 |
+
"prev_layer_voxel_grid": prev_layer_voxel_grid,
|
| 787 |
+
"prev_layer_bounds": prev_layer_bounds,
|
| 788 |
+
}
|
| 789 |
+
|
| 790 |
+
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
|
| 791 |
+
deterministic = True
|
| 792 |
+
bounds = self._coordinate_bounds
|
| 793 |
+
prev_layer_voxel_grid = observation.get("prev_layer_voxel_grid", None)
|
| 794 |
+
prev_layer_bounds = observation.get("prev_layer_bounds", None)
|
| 795 |
+
lang_goal_tokens = observation.get("lang_goal_tokens", None).long()
|
| 796 |
+
|
| 797 |
+
# extract CLIP language embs
|
| 798 |
+
with torch.no_grad():
|
| 799 |
+
lang_goal_tokens = lang_goal_tokens.to(device=self._device)
|
| 800 |
+
(
|
| 801 |
+
lang_goal_emb,
|
| 802 |
+
lang_token_embs,
|
| 803 |
+
) = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0])
|
| 804 |
+
|
| 805 |
+
# voxelization resolution
|
| 806 |
+
res = (bounds[:, 3:] - bounds[:, :3]) / self._voxel_size
|
| 807 |
+
max_rot_index = int(360 // self._rotation_resolution)
|
| 808 |
+
right_proprio = None
|
| 809 |
+
left_proprio = None
|
| 810 |
+
|
| 811 |
+
if self._include_low_dim_state:
|
| 812 |
+
right_proprio = observation["right_low_dim_state"]
|
| 813 |
+
left_proprio = observation["left_low_dim_state"]
|
| 814 |
+
right_proprio = right_proprio[0].to(self._device)
|
| 815 |
+
left_proprio = left_proprio[0].to(self._device)
|
| 816 |
+
|
| 817 |
+
obs, pcd = self._act_preprocess_inputs(observation)
|
| 818 |
+
|
| 819 |
+
# correct batch size and device
|
| 820 |
+
obs = [[o[0][0].to(self._device), o[1][0].to(self._device)] for o in obs]
|
| 821 |
+
|
| 822 |
+
pcd = [p[0].to(self._device) for p in pcd]
|
| 823 |
+
lang_goal_emb = lang_goal_emb.to(self._device)
|
| 824 |
+
lang_token_embs = lang_token_embs.to(self._device)
|
| 825 |
+
bounds = torch.as_tensor(bounds, device=self._device)
|
| 826 |
+
prev_layer_voxel_grid = (
|
| 827 |
+
prev_layer_voxel_grid.to(self._device)
|
| 828 |
+
if prev_layer_voxel_grid is not None
|
| 829 |
+
else None
|
| 830 |
+
)
|
| 831 |
+
prev_layer_bounds = (
|
| 832 |
+
prev_layer_bounds.to(self._device)
|
| 833 |
+
if prev_layer_bounds is not None
|
| 834 |
+
else None
|
| 835 |
+
)
|
| 836 |
+
|
| 837 |
+
proprio = torch.cat((right_proprio, left_proprio), dim=1)
|
| 838 |
+
|
| 839 |
+
# inference
|
| 840 |
+
(
|
| 841 |
+
right_q_trans,
|
| 842 |
+
right_q_rot_grip,
|
| 843 |
+
right_q_ignore_collisions,
|
| 844 |
+
left_q_trans,
|
| 845 |
+
left_q_rot_grip,
|
| 846 |
+
left_q_ignore_collisions,
|
| 847 |
+
), vox_grid = self._q(
|
| 848 |
+
obs,
|
| 849 |
+
proprio,
|
| 850 |
+
pcd,
|
| 851 |
+
lang_goal_emb,
|
| 852 |
+
lang_token_embs,
|
| 853 |
+
bounds,
|
| 854 |
+
prev_layer_bounds,
|
| 855 |
+
prev_layer_voxel_grid,
|
| 856 |
+
)
|
| 857 |
+
|
| 858 |
+
# softmax Q predictions
|
| 859 |
+
right_q_trans = self._softmax_q_trans(right_q_trans)
|
| 860 |
+
left_q_trans = self._softmax_q_trans(left_q_trans)
|
| 861 |
+
|
| 862 |
+
if right_q_rot_grip is not None:
|
| 863 |
+
right_q_rot_grip = self._softmax_q_rot_grip(right_q_rot_grip)
|
| 864 |
+
|
| 865 |
+
if left_q_rot_grip is not None:
|
| 866 |
+
left_q_rot_grip = self._softmax_q_rot_grip(left_q_rot_grip)
|
| 867 |
+
|
| 868 |
+
if right_q_ignore_collisions is not None:
|
| 869 |
+
right_q_ignore_collisions = self._softmax_ignore_collision(
|
| 870 |
+
right_q_ignore_collisions
|
| 871 |
+
)
|
| 872 |
+
|
| 873 |
+
if left_q_ignore_collisions is not None:
|
| 874 |
+
left_q_ignore_collisions = self._softmax_ignore_collision(
|
| 875 |
+
left_q_ignore_collisions
|
| 876 |
+
)
|
| 877 |
+
|
| 878 |
+
# argmax Q predictions
|
| 879 |
+
(
|
| 880 |
+
right_coords,
|
| 881 |
+
right_rot_and_grip_indicies,
|
| 882 |
+
right_ignore_collisions,
|
| 883 |
+
) = self._q.choose_highest_action(
|
| 884 |
+
right_q_trans, right_q_rot_grip, right_q_ignore_collisions
|
| 885 |
+
)
|
| 886 |
+
(
|
| 887 |
+
left_coords,
|
| 888 |
+
left_rot_and_grip_indicies,
|
| 889 |
+
left_ignore_collisions,
|
| 890 |
+
) = self._q.choose_highest_action(
|
| 891 |
+
left_q_trans, left_q_rot_grip, left_q_ignore_collisions
|
| 892 |
+
)
|
| 893 |
+
|
| 894 |
+
if right_q_rot_grip is not None:
|
| 895 |
+
right_rot_grip_action = right_rot_and_grip_indicies
|
| 896 |
+
if right_q_ignore_collisions is not None:
|
| 897 |
+
right_ignore_collisions_action = right_ignore_collisions.int()
|
| 898 |
+
|
| 899 |
+
if left_q_rot_grip is not None:
|
| 900 |
+
left_rot_grip_action = left_rot_and_grip_indicies
|
| 901 |
+
if left_q_ignore_collisions is not None:
|
| 902 |
+
left_ignore_collisions_action = left_ignore_collisions.int()
|
| 903 |
+
|
| 904 |
+
right_coords = right_coords.int()
|
| 905 |
+
left_coords = left_coords.int()
|
| 906 |
+
|
| 907 |
+
right_attention_coordinate = bounds[:, :3] + res * right_coords + res / 2
|
| 908 |
+
left_attention_coordinate = bounds[:, :3] + res * left_coords + res / 2
|
| 909 |
+
|
| 910 |
+
# stack prev_layer_voxel_grid(s) into a list
|
| 911 |
+
# NOTE: PerAct doesn't used multi-layer voxel grids like C2FARM
|
| 912 |
+
if prev_layer_voxel_grid is None:
|
| 913 |
+
prev_layer_voxel_grid = [vox_grid]
|
| 914 |
+
else:
|
| 915 |
+
prev_layer_voxel_grid = prev_layer_voxel_grid + [vox_grid]
|
| 916 |
+
|
| 917 |
+
if prev_layer_bounds is None:
|
| 918 |
+
prev_layer_bounds = [bounds]
|
| 919 |
+
else:
|
| 920 |
+
prev_layer_bounds = prev_layer_bounds + [bounds]
|
| 921 |
+
|
| 922 |
+
observation_elements = {
|
| 923 |
+
"right_attention_coordinate": right_attention_coordinate,
|
| 924 |
+
"left_attention_coordinate": left_attention_coordinate,
|
| 925 |
+
"prev_layer_voxel_grid": prev_layer_voxel_grid,
|
| 926 |
+
"prev_layer_bounds": prev_layer_bounds,
|
| 927 |
+
}
|
| 928 |
+
info = {
|
| 929 |
+
"voxel_grid_depth%d" % self._layer: vox_grid,
|
| 930 |
+
"right_q_depth%d" % self._layer: right_q_trans,
|
| 931 |
+
"right_voxel_idx_depth%d" % self._layer: right_coords,
|
| 932 |
+
"left_q_depth%d" % self._layer: left_q_trans,
|
| 933 |
+
"left_voxel_idx_depth%d" % self._layer: left_coords,
|
| 934 |
+
}
|
| 935 |
+
self._act_voxel_grid = vox_grid[0]
|
| 936 |
+
self._right_act_max_coordinate = right_coords[0]
|
| 937 |
+
self._right_act_qvalues = right_q_trans[0].detach()
|
| 938 |
+
self._left_act_max_coordinate = left_coords[0]
|
| 939 |
+
self._left_act_qvalues = left_q_trans[0].detach()
|
| 940 |
+
|
| 941 |
+
action = (
|
| 942 |
+
right_coords,
|
| 943 |
+
right_rot_grip_action,
|
| 944 |
+
right_ignore_collisions,
|
| 945 |
+
left_coords,
|
| 946 |
+
left_rot_grip_action,
|
| 947 |
+
left_ignore_collisions,
|
| 948 |
+
)
|
| 949 |
+
|
| 950 |
+
return ActResult(action, observation_elements=observation_elements, info=info)
|
| 951 |
+
|
| 952 |
+
def update_summaries(self) -> List[Summary]:
|
| 953 |
+
voxel_grid = self._vis_voxel_grid.detach().cpu().numpy()
|
| 954 |
+
summaries = []
|
| 955 |
+
summaries.append(
|
| 956 |
+
ImageSummary(
|
| 957 |
+
"%s/right_update_qattention" % self._name,
|
| 958 |
+
transforms.ToTensor()(
|
| 959 |
+
visualise_voxel(
|
| 960 |
+
voxel_grid,
|
| 961 |
+
self._right_vis_translation_qvalue.detach().cpu().numpy(),
|
| 962 |
+
self._right_vis_max_coordinate.detach().cpu().numpy(),
|
| 963 |
+
self._right_vis_gt_coordinate.detach().cpu().numpy(),
|
| 964 |
+
)
|
| 965 |
+
),
|
| 966 |
+
)
|
| 967 |
+
)
|
| 968 |
+
summaries.append(
|
| 969 |
+
ImageSummary(
|
| 970 |
+
"%s/left_update_qattention" % self._name,
|
| 971 |
+
transforms.ToTensor()(
|
| 972 |
+
visualise_voxel(
|
| 973 |
+
voxel_grid,
|
| 974 |
+
self._left_vis_translation_qvalue.detach().cpu().numpy(),
|
| 975 |
+
self._left_vis_max_coordinate.detach().cpu().numpy(),
|
| 976 |
+
self._left_vis_gt_coordinate.detach().cpu().numpy(),
|
| 977 |
+
)
|
| 978 |
+
),
|
| 979 |
+
)
|
| 980 |
+
)
|
| 981 |
+
for n, v in self._summaries.items():
|
| 982 |
+
summaries.append(ScalarSummary("%s/%s" % (self._name, n), v))
|
| 983 |
+
|
| 984 |
+
for name, crop in self._crop_summary:
|
| 985 |
+
crops = (torch.cat(torch.split(crop, 3, dim=1), dim=3) + 1.0) / 2.0
|
| 986 |
+
summaries.extend([ImageSummary("%s/crops/%s" % (self._name, name), crops)])
|
| 987 |
+
|
| 988 |
+
for tag, param in self._q.named_parameters():
|
| 989 |
+
# assert not torch.isnan(param.grad.abs() <= 1.0).all()
|
| 990 |
+
summaries.append(
|
| 991 |
+
HistogramSummary("%s/gradient/%s" % (self._name, tag), param.grad)
|
| 992 |
+
)
|
| 993 |
+
summaries.append(
|
| 994 |
+
HistogramSummary("%s/weight/%s" % (self._name, tag), param.data)
|
| 995 |
+
)
|
| 996 |
+
|
| 997 |
+
return summaries
|
| 998 |
+
|
| 999 |
+
def act_summaries(self) -> List[Summary]:
|
| 1000 |
+
voxel_grid = self._act_voxel_grid.cpu().numpy()
|
| 1001 |
+
right_q_attention = self._right_act_qvalues.cpu().numpy()
|
| 1002 |
+
right_highlight_coordinate = self._right_act_max_coordinate.cpu().numpy()
|
| 1003 |
+
right_visualization = visualise_voxel(
|
| 1004 |
+
voxel_grid, right_q_attention, right_highlight_coordinate
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
left_q_attention = self._left_act_qvalues.cpu().numpy()
|
| 1008 |
+
left_highlight_coordinate = self._left_act_max_coordinate.cpu().numpy()
|
| 1009 |
+
left_visualization = visualise_voxel(
|
| 1010 |
+
voxel_grid, left_q_attention, left_highlight_coordinate
|
| 1011 |
+
)
|
| 1012 |
+
|
| 1013 |
+
return [
|
| 1014 |
+
ImageSummary(
|
| 1015 |
+
f"{self._name}/right_act_Qattention",
|
| 1016 |
+
transforms.ToTensor()(right_visualization),
|
| 1017 |
+
),
|
| 1018 |
+
ImageSummary(
|
| 1019 |
+
f"{self._name}/left_act_Qattention",
|
| 1020 |
+
transforms.ToTensor()(left_visualization),
|
| 1021 |
+
),
|
| 1022 |
+
]
|
| 1023 |
+
|
| 1024 |
+
def load_weights(self, savedir: str):
|
| 1025 |
+
device = (
|
| 1026 |
+
self._device
|
| 1027 |
+
if not self._training
|
| 1028 |
+
else torch.device("cuda:%d" % self._device)
|
| 1029 |
+
)
|
| 1030 |
+
weight_file = os.path.join(savedir, "%s.pt" % self._name)
|
| 1031 |
+
state_dict = torch.load(weight_file, map_location=device)
|
| 1032 |
+
|
| 1033 |
+
# load only keys that are in the current model
|
| 1034 |
+
merged_state_dict = self._q.state_dict()
|
| 1035 |
+
for k, v in state_dict.items():
|
| 1036 |
+
if not self._training:
|
| 1037 |
+
k = k.replace("_qnet.module", "_qnet")
|
| 1038 |
+
if k in merged_state_dict:
|
| 1039 |
+
merged_state_dict[k] = v
|
| 1040 |
+
else:
|
| 1041 |
+
if "_voxelizer" not in k:
|
| 1042 |
+
logging.warning("key %s not found in checkpoint" % k)
|
| 1043 |
+
if not self._training:
|
| 1044 |
+
# reshape voxelizer weights
|
| 1045 |
+
b = merged_state_dict["_voxelizer._ones_max_coords"].shape[0]
|
| 1046 |
+
merged_state_dict["_voxelizer._ones_max_coords"] = merged_state_dict[
|
| 1047 |
+
"_voxelizer._ones_max_coords"
|
| 1048 |
+
][0:1]
|
| 1049 |
+
flat_shape = merged_state_dict["_voxelizer._flat_output"].shape[0]
|
| 1050 |
+
merged_state_dict["_voxelizer._flat_output"] = merged_state_dict[
|
| 1051 |
+
"_voxelizer._flat_output"
|
| 1052 |
+
][0 : flat_shape // b]
|
| 1053 |
+
merged_state_dict["_voxelizer._tiled_batch_indices"] = merged_state_dict[
|
| 1054 |
+
"_voxelizer._tiled_batch_indices"
|
| 1055 |
+
][0:1]
|
| 1056 |
+
merged_state_dict["_voxelizer._index_grid"] = merged_state_dict[
|
| 1057 |
+
"_voxelizer._index_grid"
|
| 1058 |
+
][0:1]
|
| 1059 |
+
self._q.load_state_dict(merged_state_dict)
|
| 1060 |
+
print("loaded weights from %s" % weight_file)
|
| 1061 |
+
|
| 1062 |
+
def save_weights(self, savedir: str):
|
| 1063 |
+
torch.save(self._q.state_dict(), os.path.join(savedir, "%s.pt" % self._name))
|
external/peract_bimanual/agents/bimanual_peract/qattention_stack_agent.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from yarr.agents.agent import Agent, ActResult, Summary
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from helpers import utils
|
| 9 |
+
from agents.bimanual_peract.qattention_peract_bc_agent import QAttentionPerActBCAgent
|
| 10 |
+
|
| 11 |
+
NAME = "QAttentionStackAgent"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class QAttentionStackAgent(Agent):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
qattention_agents: List[QAttentionPerActBCAgent],
|
| 18 |
+
rotation_resolution: float,
|
| 19 |
+
camera_names: List[str],
|
| 20 |
+
rotation_prediction_depth: int = 0,
|
| 21 |
+
):
|
| 22 |
+
super(QAttentionStackAgent, self).__init__()
|
| 23 |
+
self._qattention_agents = qattention_agents
|
| 24 |
+
self._rotation_resolution = rotation_resolution
|
| 25 |
+
self._camera_names = camera_names
|
| 26 |
+
self._rotation_prediction_depth = rotation_prediction_depth
|
| 27 |
+
|
| 28 |
+
def build(self, training: bool, device=None) -> None:
|
| 29 |
+
self._device = device
|
| 30 |
+
if self._device is None:
|
| 31 |
+
self._device = torch.device("cpu")
|
| 32 |
+
for qa in self._qattention_agents:
|
| 33 |
+
qa.build(training, device)
|
| 34 |
+
|
| 35 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 36 |
+
priorities = 0
|
| 37 |
+
total_losses = 0.0
|
| 38 |
+
for qa in self._qattention_agents:
|
| 39 |
+
update_dict = qa.update(step, replay_sample)
|
| 40 |
+
replay_sample.update(update_dict)
|
| 41 |
+
total_losses += update_dict["total_loss"]
|
| 42 |
+
return {
|
| 43 |
+
"total_losses": total_losses,
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
|
| 47 |
+
observation_elements = {}
|
| 48 |
+
(
|
| 49 |
+
right_translation_results,
|
| 50 |
+
right_rot_grip_results,
|
| 51 |
+
right_ignore_collisions_results,
|
| 52 |
+
) = ([], [], [])
|
| 53 |
+
(
|
| 54 |
+
left_translation_results,
|
| 55 |
+
left_rot_grip_results,
|
| 56 |
+
left_ignore_collisions_results,
|
| 57 |
+
) = ([], [], [])
|
| 58 |
+
|
| 59 |
+
infos = {}
|
| 60 |
+
for depth, qagent in enumerate(self._qattention_agents):
|
| 61 |
+
act_results = qagent.act(step, observation, deterministic)
|
| 62 |
+
right_attention_coordinate = (
|
| 63 |
+
act_results.observation_elements["right_attention_coordinate"]
|
| 64 |
+
.cpu()
|
| 65 |
+
.numpy()
|
| 66 |
+
)
|
| 67 |
+
left_attention_coordinate = (
|
| 68 |
+
act_results.observation_elements["left_attention_coordinate"]
|
| 69 |
+
.cpu()
|
| 70 |
+
.numpy()
|
| 71 |
+
)
|
| 72 |
+
observation_elements[
|
| 73 |
+
"right_attention_coordinate_layer_%d" % depth
|
| 74 |
+
] = right_attention_coordinate[0]
|
| 75 |
+
observation_elements[
|
| 76 |
+
"left_attention_coordinate_layer_%d" % depth
|
| 77 |
+
] = left_attention_coordinate[0]
|
| 78 |
+
|
| 79 |
+
(
|
| 80 |
+
right_translation_idxs,
|
| 81 |
+
right_rot_grip_idxs,
|
| 82 |
+
right_ignore_collisions_idxs,
|
| 83 |
+
left_translation_idxs,
|
| 84 |
+
left_rot_grip_idxs,
|
| 85 |
+
left_ignore_collisions_idxs,
|
| 86 |
+
) = act_results.action
|
| 87 |
+
|
| 88 |
+
right_translation_results.append(right_translation_idxs)
|
| 89 |
+
if right_rot_grip_idxs is not None:
|
| 90 |
+
right_rot_grip_results.append(right_rot_grip_idxs)
|
| 91 |
+
if right_ignore_collisions_idxs is not None:
|
| 92 |
+
right_ignore_collisions_results.append(right_ignore_collisions_idxs)
|
| 93 |
+
|
| 94 |
+
left_translation_results.append(left_translation_idxs)
|
| 95 |
+
if left_rot_grip_idxs is not None:
|
| 96 |
+
left_rot_grip_results.append(left_rot_grip_idxs)
|
| 97 |
+
if left_ignore_collisions_idxs is not None:
|
| 98 |
+
left_ignore_collisions_results.append(left_ignore_collisions_idxs)
|
| 99 |
+
|
| 100 |
+
observation[
|
| 101 |
+
"right_attention_coordinate"
|
| 102 |
+
] = act_results.observation_elements["right_attention_coordinate"]
|
| 103 |
+
observation["left_attention_coordinate"] = act_results.observation_elements[
|
| 104 |
+
"left_attention_coordinate"
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
observation["prev_layer_voxel_grid"] = act_results.observation_elements[
|
| 108 |
+
"prev_layer_voxel_grid"
|
| 109 |
+
]
|
| 110 |
+
observation["prev_layer_bounds"] = act_results.observation_elements[
|
| 111 |
+
"prev_layer_bounds"
|
| 112 |
+
]
|
| 113 |
+
|
| 114 |
+
for n in self._camera_names:
|
| 115 |
+
extrinsics = observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy()
|
| 116 |
+
intrinsics = observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy()
|
| 117 |
+
px, py = utils.point_to_pixel_index(
|
| 118 |
+
right_attention_coordinate[0], extrinsics, intrinsics
|
| 119 |
+
)
|
| 120 |
+
pc_t = torch.tensor(
|
| 121 |
+
[[[py, px]]], dtype=torch.float32, device=self._device
|
| 122 |
+
)
|
| 123 |
+
observation[f"right_{n}_pixel_coord"] = pc_t
|
| 124 |
+
observation_elements[f"right_{n}_pixel_coord"] = [py, px]
|
| 125 |
+
|
| 126 |
+
px, py = utils.point_to_pixel_index(
|
| 127 |
+
left_attention_coordinate[0], extrinsics, intrinsics
|
| 128 |
+
)
|
| 129 |
+
pc_t = torch.tensor(
|
| 130 |
+
[[[py, px]]], dtype=torch.float32, device=self._device
|
| 131 |
+
)
|
| 132 |
+
observation[f"left_{n}_pixel_coord"] = pc_t
|
| 133 |
+
observation_elements[f"left_{n}_pixel_coord"] = [py, px]
|
| 134 |
+
infos.update(act_results.info)
|
| 135 |
+
|
| 136 |
+
right_rgai = torch.cat(right_rot_grip_results, 1)[0].cpu().numpy()
|
| 137 |
+
# ..todo:: utils.correct_rotation_instability does nothing so we can ignore it
|
| 138 |
+
# right_rgai = utils.correct_rotation_instability(right_rgai, self._rotation_resolution)
|
| 139 |
+
right_ignore_collisions = (
|
| 140 |
+
torch.cat(right_ignore_collisions_results, 1)[0].cpu().numpy()
|
| 141 |
+
)
|
| 142 |
+
right_trans_action_indicies = (
|
| 143 |
+
torch.cat(right_translation_results, 1)[0].cpu().numpy()
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
observation_elements[
|
| 147 |
+
"right_trans_action_indicies"
|
| 148 |
+
] = right_trans_action_indicies[:3]
|
| 149 |
+
observation_elements["right_rot_grip_action_indicies"] = right_rgai[:4]
|
| 150 |
+
|
| 151 |
+
left_rgai = torch.cat(left_rot_grip_results, 1)[0].cpu().numpy()
|
| 152 |
+
left_ignore_collisions = (
|
| 153 |
+
torch.cat(left_ignore_collisions_results, 1)[0].cpu().numpy()
|
| 154 |
+
)
|
| 155 |
+
left_trans_action_indicies = (
|
| 156 |
+
torch.cat(left_translation_results, 1)[0].cpu().numpy()
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
observation_elements["left_trans_action_indicies"] = left_trans_action_indicies[
|
| 160 |
+
3:
|
| 161 |
+
]
|
| 162 |
+
observation_elements["left_rot_grip_action_indicies"] = left_rgai[4:]
|
| 163 |
+
|
| 164 |
+
continuous_action = np.concatenate(
|
| 165 |
+
[
|
| 166 |
+
right_attention_coordinate[0],
|
| 167 |
+
utils.discrete_euler_to_quaternion(
|
| 168 |
+
right_rgai[-4:-1], self._rotation_resolution
|
| 169 |
+
),
|
| 170 |
+
right_rgai[-1:],
|
| 171 |
+
right_ignore_collisions,
|
| 172 |
+
left_attention_coordinate[0],
|
| 173 |
+
utils.discrete_euler_to_quaternion(
|
| 174 |
+
left_rgai[-4:-1], self._rotation_resolution
|
| 175 |
+
),
|
| 176 |
+
left_rgai[-1:],
|
| 177 |
+
left_ignore_collisions,
|
| 178 |
+
]
|
| 179 |
+
)
|
| 180 |
+
return ActResult(
|
| 181 |
+
continuous_action, observation_elements=observation_elements, info=infos
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
def update_summaries(self) -> List[Summary]:
|
| 185 |
+
summaries = []
|
| 186 |
+
for qa in self._qattention_agents:
|
| 187 |
+
summaries.extend(qa.update_summaries())
|
| 188 |
+
return summaries
|
| 189 |
+
|
| 190 |
+
def act_summaries(self) -> List[Summary]:
|
| 191 |
+
s = []
|
| 192 |
+
for qa in self._qattention_agents:
|
| 193 |
+
s.extend(qa.act_summaries())
|
| 194 |
+
return s
|
| 195 |
+
|
| 196 |
+
def load_weights(self, savedir: str):
|
| 197 |
+
for qa in self._qattention_agents:
|
| 198 |
+
qa.load_weights(savedir)
|
| 199 |
+
|
| 200 |
+
def save_weights(self, savedir: str):
|
| 201 |
+
for qa in self._qattention_agents:
|
| 202 |
+
qa.save_weights(savedir)
|
external/peract_bimanual/agents/c2farm_lingunet_bc/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
import agents.c2farm_lingunet_bc.launch_utils
|
external/peract_bimanual/agents/c2farm_lingunet_bc/launch_utils.py
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from ARM
|
| 2 |
+
# Source: https://github.com/stepjam/ARM
|
| 3 |
+
# License: https://github.com/stepjam/ARM/LICENSE
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from omegaconf import DictConfig
|
| 10 |
+
from rlbench.backend.observation import Observation
|
| 11 |
+
from rlbench.observation_config import ObservationConfig
|
| 12 |
+
import rlbench.utils as rlbench_utils
|
| 13 |
+
from rlbench.demo import Demo
|
| 14 |
+
from yarr.replay_buffer.prioritized_replay_buffer import ObservationElement
|
| 15 |
+
from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer
|
| 16 |
+
from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
|
| 17 |
+
from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer
|
| 18 |
+
|
| 19 |
+
from helpers import demo_loading_utils, utils
|
| 20 |
+
from helpers import observation_utils
|
| 21 |
+
from helpers.preprocess_agent import PreprocessAgent
|
| 22 |
+
from helpers.clip.core.clip import tokenize
|
| 23 |
+
from agents.c2farm_lingunet_bc.networks import QattentionLingU3DNet
|
| 24 |
+
from agents.c2farm_lingunet_bc.qattention_lingunet_bc_agent import (
|
| 25 |
+
QAttentionLingUNetBCAgent,
|
| 26 |
+
)
|
| 27 |
+
from agents.c2farm_lingunet_bc.qattention_stack_agent import QAttentionStackAgent
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
from torch.multiprocessing import Process, Value, Manager
|
| 31 |
+
from helpers.clip.core.clip import build_model, load_clip, tokenize
|
| 32 |
+
from omegaconf import DictConfig
|
| 33 |
+
|
| 34 |
+
REWARD_SCALE = 100.0
|
| 35 |
+
LOW_DIM_SIZE = 4
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def create_replay(
|
| 39 |
+
batch_size: int,
|
| 40 |
+
timesteps: int,
|
| 41 |
+
prioritisation: bool,
|
| 42 |
+
task_uniform: bool,
|
| 43 |
+
save_dir: str,
|
| 44 |
+
cameras: list,
|
| 45 |
+
voxel_sizes,
|
| 46 |
+
image_size=[128, 128],
|
| 47 |
+
replay_size=3e5,
|
| 48 |
+
):
|
| 49 |
+
trans_indicies_size = 3 * len(voxel_sizes)
|
| 50 |
+
rot_and_grip_indicies_size = 3 + 1
|
| 51 |
+
gripper_pose_size = 7
|
| 52 |
+
ignore_collisions_size = 1
|
| 53 |
+
max_token_seq_len = 77
|
| 54 |
+
lang_feat_dim = 1024
|
| 55 |
+
lang_emb_dim = 512
|
| 56 |
+
|
| 57 |
+
# low_dim_state
|
| 58 |
+
observation_elements = []
|
| 59 |
+
observation_elements.append(
|
| 60 |
+
ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# rgb, depth, point cloud, intrinsics, extrinsics
|
| 64 |
+
for cname in cameras:
|
| 65 |
+
observation_elements.append(
|
| 66 |
+
ObservationElement(
|
| 67 |
+
"%s_rgb" % cname,
|
| 68 |
+
(
|
| 69 |
+
3,
|
| 70 |
+
*image_size,
|
| 71 |
+
),
|
| 72 |
+
np.float32,
|
| 73 |
+
)
|
| 74 |
+
)
|
| 75 |
+
observation_elements.append(
|
| 76 |
+
ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32)
|
| 77 |
+
) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
|
| 78 |
+
observation_elements.append(
|
| 79 |
+
ObservationElement(
|
| 80 |
+
"%s_camera_extrinsics" % cname,
|
| 81 |
+
(
|
| 82 |
+
4,
|
| 83 |
+
4,
|
| 84 |
+
),
|
| 85 |
+
np.float32,
|
| 86 |
+
)
|
| 87 |
+
)
|
| 88 |
+
observation_elements.append(
|
| 89 |
+
ObservationElement(
|
| 90 |
+
"%s_camera_intrinsics" % cname,
|
| 91 |
+
(
|
| 92 |
+
3,
|
| 93 |
+
3,
|
| 94 |
+
),
|
| 95 |
+
np.float32,
|
| 96 |
+
)
|
| 97 |
+
)
|
| 98 |
+
observation_elements.append(
|
| 99 |
+
ObservationElement("%s_pixel_coord" % cname, (2,), np.int32)
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# discretized translation, discretized rotation, discrete ignore collision, 6-DoF gripper pose, and pre-trained language embeddings
|
| 103 |
+
observation_elements.extend(
|
| 104 |
+
[
|
| 105 |
+
ReplayElement("trans_action_indicies", (trans_indicies_size,), np.int32),
|
| 106 |
+
ReplayElement(
|
| 107 |
+
"rot_grip_action_indicies", (rot_and_grip_indicies_size,), np.int32
|
| 108 |
+
),
|
| 109 |
+
ReplayElement("ignore_collisions", (ignore_collisions_size,), np.int32),
|
| 110 |
+
ReplayElement("gripper_pose", (gripper_pose_size,), np.float32),
|
| 111 |
+
ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
|
| 112 |
+
ReplayElement(
|
| 113 |
+
"lang_token_embs",
|
| 114 |
+
(
|
| 115 |
+
max_token_seq_len,
|
| 116 |
+
lang_emb_dim,
|
| 117 |
+
),
|
| 118 |
+
np.float32,
|
| 119 |
+
), # extracted from CLIP's language encoder
|
| 120 |
+
ReplayElement("task", (), str),
|
| 121 |
+
ReplayElement(
|
| 122 |
+
"lang_goal", (1,), object
|
| 123 |
+
), # language goal string for debugging and visualization
|
| 124 |
+
]
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
for depth in range(len(voxel_sizes)):
|
| 128 |
+
observation_elements.append(
|
| 129 |
+
ReplayElement("attention_coordinate_layer_%d" % depth, (3,), np.float32)
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
extra_replay_elements = [
|
| 133 |
+
ReplayElement("demo", (), np.bool),
|
| 134 |
+
]
|
| 135 |
+
|
| 136 |
+
replay_buffer = TaskUniformReplayBuffer(
|
| 137 |
+
save_dir=save_dir,
|
| 138 |
+
batch_size=batch_size,
|
| 139 |
+
timesteps=timesteps,
|
| 140 |
+
replay_capacity=int(replay_size),
|
| 141 |
+
action_shape=(8,),
|
| 142 |
+
action_dtype=np.float32,
|
| 143 |
+
reward_shape=(),
|
| 144 |
+
reward_dtype=np.float32,
|
| 145 |
+
update_horizon=1,
|
| 146 |
+
observation_elements=observation_elements,
|
| 147 |
+
extra_replay_elements=extra_replay_elements,
|
| 148 |
+
)
|
| 149 |
+
return replay_buffer
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _get_action(
|
| 153 |
+
obs_tp1: Observation,
|
| 154 |
+
obs_tm1: Observation,
|
| 155 |
+
rlbench_scene_bounds: List[float], # metric 3D bounds of the scene
|
| 156 |
+
voxel_sizes: List[int],
|
| 157 |
+
bounds_offset: List[float],
|
| 158 |
+
rotation_resolution: int,
|
| 159 |
+
crop_augmentation: bool,
|
| 160 |
+
):
|
| 161 |
+
quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
|
| 162 |
+
if quat[-1] < 0:
|
| 163 |
+
quat = -quat
|
| 164 |
+
disc_rot = utils.quaternion_to_discrete_euler(quat, rotation_resolution)
|
| 165 |
+
disc_rot = utils.correct_rotation_instability(disc_rot, rotation_resolution)
|
| 166 |
+
|
| 167 |
+
attention_coordinate = obs_tp1.gripper_pose[:3]
|
| 168 |
+
trans_indicies, attention_coordinates = [], []
|
| 169 |
+
bounds = np.array(rlbench_scene_bounds)
|
| 170 |
+
ignore_collisions = int(obs_tm1.ignore_collisions)
|
| 171 |
+
for depth, vox_size in enumerate(
|
| 172 |
+
voxel_sizes
|
| 173 |
+
): # only single voxelization-level is used in PerAct
|
| 174 |
+
if depth > 0:
|
| 175 |
+
if crop_augmentation:
|
| 176 |
+
shift = bounds_offset[depth - 1] * 0.75
|
| 177 |
+
attention_coordinate += np.random.uniform(-shift, shift, size=(3,))
|
| 178 |
+
bounds = np.concatenate(
|
| 179 |
+
[
|
| 180 |
+
attention_coordinate - bounds_offset[depth - 1],
|
| 181 |
+
attention_coordinate + bounds_offset[depth - 1],
|
| 182 |
+
]
|
| 183 |
+
)
|
| 184 |
+
index = utils.point_to_voxel_index(obs_tp1.gripper_pose[:3], vox_size, bounds)
|
| 185 |
+
trans_indicies.extend(index.tolist())
|
| 186 |
+
res = (bounds[3:] - bounds[:3]) / vox_size
|
| 187 |
+
attention_coordinate = bounds[:3] + res * index
|
| 188 |
+
attention_coordinates.append(attention_coordinate)
|
| 189 |
+
|
| 190 |
+
rot_and_grip_indicies = disc_rot.tolist()
|
| 191 |
+
grip = float(obs_tp1.gripper_open)
|
| 192 |
+
rot_and_grip_indicies.extend([int(obs_tp1.gripper_open)])
|
| 193 |
+
return (
|
| 194 |
+
trans_indicies,
|
| 195 |
+
rot_and_grip_indicies,
|
| 196 |
+
ignore_collisions,
|
| 197 |
+
np.concatenate([obs_tp1.gripper_pose, np.array([grip])]),
|
| 198 |
+
attention_coordinates,
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def _add_keypoints_to_replay(
|
| 203 |
+
cfg: DictConfig,
|
| 204 |
+
task: str,
|
| 205 |
+
replay: ReplayBuffer,
|
| 206 |
+
inital_obs: Observation,
|
| 207 |
+
demo: Demo,
|
| 208 |
+
episode_keypoints: List[int],
|
| 209 |
+
cameras: List[str],
|
| 210 |
+
rlbench_scene_bounds: List[float],
|
| 211 |
+
voxel_sizes: List[int],
|
| 212 |
+
bounds_offset: List[float],
|
| 213 |
+
rotation_resolution: int,
|
| 214 |
+
crop_augmentation: bool,
|
| 215 |
+
description: str = "",
|
| 216 |
+
clip_model=None,
|
| 217 |
+
device="cpu",
|
| 218 |
+
):
|
| 219 |
+
prev_action = None
|
| 220 |
+
obs = inital_obs
|
| 221 |
+
for k, keypoint in enumerate(episode_keypoints):
|
| 222 |
+
obs_tp1 = demo[keypoint]
|
| 223 |
+
obs_tm1 = demo[max(0, keypoint - 1)]
|
| 224 |
+
(
|
| 225 |
+
trans_indicies,
|
| 226 |
+
rot_grip_indicies,
|
| 227 |
+
ignore_collisions,
|
| 228 |
+
action,
|
| 229 |
+
attention_coordinates,
|
| 230 |
+
) = _get_action(
|
| 231 |
+
obs_tp1,
|
| 232 |
+
obs_tm1,
|
| 233 |
+
rlbench_scene_bounds,
|
| 234 |
+
voxel_sizes,
|
| 235 |
+
bounds_offset,
|
| 236 |
+
rotation_resolution,
|
| 237 |
+
crop_augmentation,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
terminal = k == len(episode_keypoints) - 1
|
| 241 |
+
reward = float(terminal) * REWARD_SCALE if terminal else 0
|
| 242 |
+
|
| 243 |
+
obs_dict = observation_utils.extract_obs(
|
| 244 |
+
obs,
|
| 245 |
+
t=k,
|
| 246 |
+
prev_action=prev_action,
|
| 247 |
+
cameras=cameras,
|
| 248 |
+
episode_length=cfg.rlbench.episode_length,
|
| 249 |
+
robot_name=cfg.method.robot_name,
|
| 250 |
+
)
|
| 251 |
+
tokens = tokenize([description]).numpy()
|
| 252 |
+
token_tensor = torch.from_numpy(tokens).to(device)
|
| 253 |
+
sentence_emb, token_embs = clip_model.encode_text_with_embeddings(token_tensor)
|
| 254 |
+
obs_dict["lang_goal_emb"] = sentence_emb[0].float().detach().cpu().numpy()
|
| 255 |
+
obs_dict["lang_token_embs"] = token_embs[0].float().detach().cpu().numpy()
|
| 256 |
+
|
| 257 |
+
prev_action = np.copy(action)
|
| 258 |
+
|
| 259 |
+
others = {"demo": True}
|
| 260 |
+
final_obs = {
|
| 261 |
+
"trans_action_indicies": trans_indicies,
|
| 262 |
+
"rot_grip_action_indicies": rot_grip_indicies,
|
| 263 |
+
"gripper_pose": obs_tp1.gripper_pose,
|
| 264 |
+
"task": task,
|
| 265 |
+
"lang_goal": np.array([description], dtype=object),
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
for depth in range(len(voxel_sizes)):
|
| 269 |
+
final_obs["attention_coordinate_layer_%d" % depth] = attention_coordinates[
|
| 270 |
+
depth
|
| 271 |
+
]
|
| 272 |
+
for name in cameras:
|
| 273 |
+
px, py = utils.point_to_pixel_index(
|
| 274 |
+
obs_tp1.gripper_pose[:3],
|
| 275 |
+
obs_tp1.misc["%s_camera_extrinsics" % name],
|
| 276 |
+
obs_tp1.misc["%s_camera_intrinsics" % name],
|
| 277 |
+
)
|
| 278 |
+
final_obs["%s_pixel_coord" % name] = [py, px]
|
| 279 |
+
|
| 280 |
+
others.update(final_obs)
|
| 281 |
+
others.update(obs_dict)
|
| 282 |
+
|
| 283 |
+
timeout = False
|
| 284 |
+
replay.add(action, reward, terminal, timeout, **others)
|
| 285 |
+
obs = obs_tp1
|
| 286 |
+
|
| 287 |
+
# final step
|
| 288 |
+
obs_dict_tp1 = observation_utils.extract_obs(
|
| 289 |
+
obs_tp1,
|
| 290 |
+
t=k + 1,
|
| 291 |
+
prev_action=prev_action,
|
| 292 |
+
cameras=cameras,
|
| 293 |
+
episode_length=cfg.rlbench.episode_length,
|
| 294 |
+
robot_name=cfg.method.robot_name,
|
| 295 |
+
)
|
| 296 |
+
obs_dict_tp1["lang_goal_emb"] = sentence_emb[0].float().detach().cpu().numpy()
|
| 297 |
+
obs_dict_tp1["lang_token_embs"] = token_embs[0].float().detach().cpu().numpy()
|
| 298 |
+
|
| 299 |
+
obs_dict_tp1.pop("wrist_world_to_cam", None)
|
| 300 |
+
obs_dict_tp1.update(final_obs)
|
| 301 |
+
replay.add_final(**obs_dict_tp1)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def fill_replay(
|
| 305 |
+
cfg: DictConfig,
|
| 306 |
+
obs_config: ObservationConfig,
|
| 307 |
+
rank: int,
|
| 308 |
+
replay: ReplayBuffer,
|
| 309 |
+
task: str,
|
| 310 |
+
num_demos: int,
|
| 311 |
+
demo_augmentation: bool,
|
| 312 |
+
demo_augmentation_every_n: int,
|
| 313 |
+
cameras: List[str],
|
| 314 |
+
rlbench_scene_bounds: List[float], # AKA: DEPTH0_BOUNDS
|
| 315 |
+
voxel_sizes: List[int],
|
| 316 |
+
bounds_offset: List[float],
|
| 317 |
+
rotation_resolution: int,
|
| 318 |
+
crop_augmentation: bool,
|
| 319 |
+
clip_model=None,
|
| 320 |
+
device="cpu",
|
| 321 |
+
keypoint_method="heuristic",
|
| 322 |
+
):
|
| 323 |
+
if clip_model is None:
|
| 324 |
+
model, _ = load_clip("RN50", jit=False, device=device)
|
| 325 |
+
clip_model = build_model(model.state_dict())
|
| 326 |
+
clip_model.to(device)
|
| 327 |
+
del model
|
| 328 |
+
|
| 329 |
+
logging.debug("Filling %s replay ..." % task)
|
| 330 |
+
for d_idx in range(num_demos):
|
| 331 |
+
# load demo from disk
|
| 332 |
+
demo = rlbench_utils.get_stored_demos(
|
| 333 |
+
amount=1,
|
| 334 |
+
image_paths=False,
|
| 335 |
+
dataset_root=cfg.rlbench.demo_path,
|
| 336 |
+
variation_number=-1,
|
| 337 |
+
task_name=task,
|
| 338 |
+
obs_config=obs_config,
|
| 339 |
+
random_selection=False,
|
| 340 |
+
from_episode_number=d_idx,
|
| 341 |
+
)[0]
|
| 342 |
+
|
| 343 |
+
descs = demo._observations[0].misc["descriptions"]
|
| 344 |
+
|
| 345 |
+
# extract keypoints (a.k.a keyframes)
|
| 346 |
+
episode_keypoints = demo_loading_utils.keypoint_discovery(
|
| 347 |
+
demo, method=keypoint_method
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
if rank == 0:
|
| 351 |
+
logging.info(
|
| 352 |
+
f"Loading Demo({d_idx}) - found {len(episode_keypoints)} keypoints - {task}"
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
for i in range(len(demo) - 1):
|
| 356 |
+
if not demo_augmentation and i > 0:
|
| 357 |
+
break
|
| 358 |
+
if i % demo_augmentation_every_n != 0:
|
| 359 |
+
continue
|
| 360 |
+
|
| 361 |
+
obs = demo[i]
|
| 362 |
+
desc = descs[0]
|
| 363 |
+
# if our starting point is past one of the keypoints, then remove it
|
| 364 |
+
while len(episode_keypoints) > 0 and i >= episode_keypoints[0]:
|
| 365 |
+
episode_keypoints = episode_keypoints[1:]
|
| 366 |
+
if len(episode_keypoints) == 0:
|
| 367 |
+
break
|
| 368 |
+
_add_keypoints_to_replay(
|
| 369 |
+
cfg,
|
| 370 |
+
task,
|
| 371 |
+
replay,
|
| 372 |
+
obs,
|
| 373 |
+
demo,
|
| 374 |
+
episode_keypoints,
|
| 375 |
+
cameras,
|
| 376 |
+
rlbench_scene_bounds,
|
| 377 |
+
voxel_sizes,
|
| 378 |
+
bounds_offset,
|
| 379 |
+
rotation_resolution,
|
| 380 |
+
crop_augmentation,
|
| 381 |
+
description=desc,
|
| 382 |
+
clip_model=clip_model,
|
| 383 |
+
device=device,
|
| 384 |
+
)
|
| 385 |
+
logging.debug("Replay %s filled with demos." % task)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def fill_multi_task_replay(
|
| 389 |
+
cfg: DictConfig,
|
| 390 |
+
obs_config: ObservationConfig,
|
| 391 |
+
rank: int,
|
| 392 |
+
replay: ReplayBuffer,
|
| 393 |
+
tasks: List[str],
|
| 394 |
+
num_demos: int,
|
| 395 |
+
demo_augmentation: bool,
|
| 396 |
+
demo_augmentation_every_n: int,
|
| 397 |
+
cameras: List[str],
|
| 398 |
+
rlbench_scene_bounds: List[float],
|
| 399 |
+
voxel_sizes: List[int],
|
| 400 |
+
bounds_offset: List[float],
|
| 401 |
+
rotation_resolution: int,
|
| 402 |
+
crop_augmentation: bool,
|
| 403 |
+
clip_model=None,
|
| 404 |
+
keypoint_method="heuristic",
|
| 405 |
+
):
|
| 406 |
+
manager = Manager()
|
| 407 |
+
store = manager.dict()
|
| 408 |
+
|
| 409 |
+
# create a MP dict for storing indicies
|
| 410 |
+
# TODO(mohit): this shouldn't be initialized here
|
| 411 |
+
del replay._task_idxs
|
| 412 |
+
task_idxs = manager.dict()
|
| 413 |
+
replay._task_idxs = task_idxs
|
| 414 |
+
replay._create_storage(store)
|
| 415 |
+
replay.add_count = Value("i", 0)
|
| 416 |
+
|
| 417 |
+
# fill replay buffer in parallel across tasks
|
| 418 |
+
max_parallel_processes = cfg.replay.max_parallel_processes
|
| 419 |
+
processes = []
|
| 420 |
+
n = np.arange(len(tasks))
|
| 421 |
+
split_n = utils.split_list(n, max_parallel_processes)
|
| 422 |
+
for split in split_n:
|
| 423 |
+
for e_idx, task_idx in enumerate(split):
|
| 424 |
+
task = tasks[int(task_idx)]
|
| 425 |
+
model_device = torch.device(
|
| 426 |
+
"cuda:%s" % (e_idx % torch.cuda.device_count())
|
| 427 |
+
if torch.cuda.is_available()
|
| 428 |
+
else "cpu"
|
| 429 |
+
)
|
| 430 |
+
p = Process(
|
| 431 |
+
target=fill_replay,
|
| 432 |
+
args=(
|
| 433 |
+
cfg,
|
| 434 |
+
obs_config,
|
| 435 |
+
rank,
|
| 436 |
+
replay,
|
| 437 |
+
task,
|
| 438 |
+
num_demos,
|
| 439 |
+
demo_augmentation,
|
| 440 |
+
demo_augmentation_every_n,
|
| 441 |
+
cameras,
|
| 442 |
+
rlbench_scene_bounds,
|
| 443 |
+
voxel_sizes,
|
| 444 |
+
bounds_offset,
|
| 445 |
+
rotation_resolution,
|
| 446 |
+
crop_augmentation,
|
| 447 |
+
clip_model,
|
| 448 |
+
model_device,
|
| 449 |
+
keypoint_method,
|
| 450 |
+
),
|
| 451 |
+
)
|
| 452 |
+
p.start()
|
| 453 |
+
processes.append(p)
|
| 454 |
+
|
| 455 |
+
for p in processes:
|
| 456 |
+
p.join()
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def create_agent(cfg: DictConfig):
|
| 460 |
+
LATENT_SIZE = 64
|
| 461 |
+
depth_0bounds = cfg.rlbench.scene_bounds
|
| 462 |
+
cam_resolution = cfg.rlbench.camera_resolution
|
| 463 |
+
|
| 464 |
+
num_rotation_classes = int(360.0 // cfg.method.rotation_resolution)
|
| 465 |
+
qattention_agents = []
|
| 466 |
+
for depth, vox_size in enumerate(cfg.method.voxel_sizes):
|
| 467 |
+
last = depth == len(cfg.method.voxel_sizes) - 1
|
| 468 |
+
unet3d = QattentionLingU3DNet(
|
| 469 |
+
in_channels=3 + 3 + 1 + 3,
|
| 470 |
+
out_channels=1,
|
| 471 |
+
voxel_size=vox_size,
|
| 472 |
+
out_dense=((num_rotation_classes * 3) + 4) if last else 0,
|
| 473 |
+
kernels=LATENT_SIZE,
|
| 474 |
+
norm=None if "None" in cfg.method.norm else cfg.method.norm,
|
| 475 |
+
dense_feats=128,
|
| 476 |
+
activation=cfg.method.activation,
|
| 477 |
+
low_dim_size=4,
|
| 478 |
+
include_prev_layer=cfg.method.include_prev_layer and depth > 0,
|
| 479 |
+
depth=depth,
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
qattention_agent = QAttentionLingUNetBCAgent(
|
| 483 |
+
layer=depth,
|
| 484 |
+
coordinate_bounds=depth_0bounds,
|
| 485 |
+
unet3d=unet3d,
|
| 486 |
+
camera_names=cfg.rlbench.cameras,
|
| 487 |
+
batch_size=cfg.replay.batch_size,
|
| 488 |
+
voxel_size=vox_size,
|
| 489 |
+
bounds_offset=cfg.method.bounds_offset[depth - 1] if depth > 0 else None,
|
| 490 |
+
voxel_feature_size=3,
|
| 491 |
+
image_crop_size=cfg.method.image_crop_size,
|
| 492 |
+
lr=cfg.method.lr,
|
| 493 |
+
training_iterations=cfg.framework.training_iterations,
|
| 494 |
+
lr_scheduler=cfg.method.lr_scheduler,
|
| 495 |
+
num_warmup_steps=cfg.method.num_warmup_steps,
|
| 496 |
+
trans_loss_weight=cfg.method.trans_loss_weight,
|
| 497 |
+
rot_loss_weight=cfg.method.rot_loss_weight,
|
| 498 |
+
grip_loss_weight=cfg.method.grip_loss_weight,
|
| 499 |
+
collision_loss_weight=cfg.method.collision_loss_weight,
|
| 500 |
+
include_low_dim_state=True,
|
| 501 |
+
image_resolution=cam_resolution,
|
| 502 |
+
lambda_weight_l2=cfg.method.lambda_weight_l2,
|
| 503 |
+
num_rotation_classes=num_rotation_classes,
|
| 504 |
+
rotation_resolution=cfg.method.rotation_resolution,
|
| 505 |
+
transform_augmentation=cfg.method.transform_augmentation.apply_se3,
|
| 506 |
+
transform_augmentation_xyz=cfg.method.transform_augmentation.aug_xyz,
|
| 507 |
+
transform_augmentation_rpy=cfg.method.transform_augmentation.aug_rpy,
|
| 508 |
+
transform_augmentation_rot_resolution=cfg.method.transform_augmentation.aug_rot_resolution,
|
| 509 |
+
num_devices=cfg.ddp.num_devices,
|
| 510 |
+
)
|
| 511 |
+
qattention_agents.append(qattention_agent)
|
| 512 |
+
|
| 513 |
+
rotation_agent = QAttentionStackAgent(
|
| 514 |
+
qattention_agents=qattention_agents,
|
| 515 |
+
rotation_resolution=cfg.method.rotation_resolution,
|
| 516 |
+
camera_names=cfg.rlbench.cameras,
|
| 517 |
+
)
|
| 518 |
+
preprocess_agent = PreprocessAgent(pose_agent=rotation_agent)
|
| 519 |
+
return preprocess_agent
|
external/peract_bimanual/agents/c2farm_lingunet_bc/networks.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from helpers.network_utils import (
|
| 5 |
+
Conv3DInceptionBlock,
|
| 6 |
+
DenseBlock,
|
| 7 |
+
SpatialSoftmax3D,
|
| 8 |
+
Conv3DInceptionBlockUpsampleBlock,
|
| 9 |
+
Conv3DBlock,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class QattentionLingU3DNet(nn.Module):
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
in_channels: int,
|
| 17 |
+
out_channels: int,
|
| 18 |
+
out_dense: int,
|
| 19 |
+
voxel_size: int,
|
| 20 |
+
low_dim_size: int,
|
| 21 |
+
kernels: int,
|
| 22 |
+
norm: str = None,
|
| 23 |
+
activation: str = "relu",
|
| 24 |
+
dense_feats: int = 32,
|
| 25 |
+
include_prev_layer=False,
|
| 26 |
+
depth=0,
|
| 27 |
+
lingunet_dropout=0.0,
|
| 28 |
+
):
|
| 29 |
+
super(QattentionLingU3DNet, self).__init__()
|
| 30 |
+
self._in_channels = in_channels
|
| 31 |
+
self._out_channels = out_channels
|
| 32 |
+
self._norm = norm
|
| 33 |
+
self._activation = activation
|
| 34 |
+
self._kernels = kernels
|
| 35 |
+
self._low_dim_size = low_dim_size
|
| 36 |
+
self._build_calls = 0
|
| 37 |
+
self._voxel_size = voxel_size
|
| 38 |
+
self._dense_feats = dense_feats
|
| 39 |
+
self._out_dense = out_dense
|
| 40 |
+
self._include_prev_layer = include_prev_layer
|
| 41 |
+
self._depth = depth
|
| 42 |
+
|
| 43 |
+
self._lingunet_dropout = lingunet_dropout
|
| 44 |
+
self._clip_lang_feat_dim = 1024
|
| 45 |
+
|
| 46 |
+
if self._voxel_size < 16:
|
| 47 |
+
raise Exception(
|
| 48 |
+
"Voxel size for C2FARM_LINGUNET_BC should be at least 16 or higher"
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
def build(self):
|
| 52 |
+
use_residual = False
|
| 53 |
+
self._build_calls += 1
|
| 54 |
+
if self._build_calls != 1:
|
| 55 |
+
raise RuntimeError("Build needs to be called once.")
|
| 56 |
+
|
| 57 |
+
spatial_size = self._voxel_size
|
| 58 |
+
self._input_preprocess = Conv3DInceptionBlock(
|
| 59 |
+
self._in_channels,
|
| 60 |
+
self._kernels,
|
| 61 |
+
norm=self._norm,
|
| 62 |
+
activation=self._activation,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
d0_ins = self._input_preprocess.out_channels
|
| 66 |
+
if self._include_prev_layer:
|
| 67 |
+
PREV_VOXEL_CHANNELS = 0
|
| 68 |
+
d0_ins += self._input_preprocess.out_channels * self._depth
|
| 69 |
+
|
| 70 |
+
if self._low_dim_size > 0:
|
| 71 |
+
self._proprio_preprocess = DenseBlock(
|
| 72 |
+
self._low_dim_size, self._kernels, None, self._activation
|
| 73 |
+
)
|
| 74 |
+
d0_ins += self._kernels
|
| 75 |
+
|
| 76 |
+
self._down0 = Conv3DInceptionBlock(
|
| 77 |
+
d0_ins,
|
| 78 |
+
self._kernels,
|
| 79 |
+
norm=self._norm,
|
| 80 |
+
activation=self._activation,
|
| 81 |
+
residual=use_residual,
|
| 82 |
+
)
|
| 83 |
+
self._ss0 = SpatialSoftmax3D(
|
| 84 |
+
spatial_size, spatial_size, spatial_size, self._down0.out_channels
|
| 85 |
+
)
|
| 86 |
+
spatial_size //= 2
|
| 87 |
+
self._down1 = Conv3DInceptionBlock(
|
| 88 |
+
self._down0.out_channels,
|
| 89 |
+
self._kernels * 2,
|
| 90 |
+
norm=self._norm,
|
| 91 |
+
activation=self._activation,
|
| 92 |
+
residual=use_residual,
|
| 93 |
+
)
|
| 94 |
+
self._ss1 = SpatialSoftmax3D(
|
| 95 |
+
spatial_size, spatial_size, spatial_size, self._down1.out_channels
|
| 96 |
+
)
|
| 97 |
+
spatial_size //= 2
|
| 98 |
+
|
| 99 |
+
flat_size = self._down0.out_channels * 4 + self._down1.out_channels * 4
|
| 100 |
+
|
| 101 |
+
k1 = self._down1.out_channels
|
| 102 |
+
if self._voxel_size > 8:
|
| 103 |
+
k1 += self._kernels
|
| 104 |
+
self._down2 = Conv3DInceptionBlock(
|
| 105 |
+
self._down1.out_channels,
|
| 106 |
+
self._kernels * 4,
|
| 107 |
+
norm=self._norm,
|
| 108 |
+
activation=self._activation,
|
| 109 |
+
residual=use_residual,
|
| 110 |
+
)
|
| 111 |
+
self._lang_proj2 = DenseBlock(
|
| 112 |
+
self._clip_lang_feat_dim, self._down2.out_channels, None, None
|
| 113 |
+
)
|
| 114 |
+
self._dropout2 = nn.Dropout(self._lingunet_dropout)
|
| 115 |
+
flat_size += self._down2.out_channels * 4
|
| 116 |
+
self._ss2 = SpatialSoftmax3D(
|
| 117 |
+
spatial_size, spatial_size, spatial_size, self._down2.out_channels
|
| 118 |
+
)
|
| 119 |
+
spatial_size //= 2
|
| 120 |
+
k2 = self._down2.out_channels
|
| 121 |
+
if self._voxel_size > 16:
|
| 122 |
+
k2 *= 2
|
| 123 |
+
self._down3 = Conv3DInceptionBlock(
|
| 124 |
+
self._down2.out_channels,
|
| 125 |
+
self._kernels,
|
| 126 |
+
norm=self._norm,
|
| 127 |
+
activation=self._activation,
|
| 128 |
+
residual=use_residual,
|
| 129 |
+
)
|
| 130 |
+
self._lang_proj3 = DenseBlock(
|
| 131 |
+
self._clip_lang_feat_dim, self._down3.out_channels, None, None
|
| 132 |
+
)
|
| 133 |
+
self._dropout3 = nn.Dropout(self._lingunet_dropout)
|
| 134 |
+
flat_size += self._down3.out_channels * 4
|
| 135 |
+
self._ss3 = SpatialSoftmax3D(
|
| 136 |
+
spatial_size, spatial_size, spatial_size, self._down3.out_channels
|
| 137 |
+
)
|
| 138 |
+
self._up3 = Conv3DInceptionBlockUpsampleBlock(
|
| 139 |
+
self._kernels,
|
| 140 |
+
self._kernels * 4,
|
| 141 |
+
2,
|
| 142 |
+
norm=self._norm,
|
| 143 |
+
activation=self._activation,
|
| 144 |
+
residual=use_residual,
|
| 145 |
+
)
|
| 146 |
+
self._up2 = Conv3DInceptionBlockUpsampleBlock(
|
| 147 |
+
k2,
|
| 148 |
+
self._kernels,
|
| 149 |
+
2,
|
| 150 |
+
norm=self._norm,
|
| 151 |
+
activation=self._activation,
|
| 152 |
+
residual=use_residual,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
self._up1 = Conv3DInceptionBlockUpsampleBlock(
|
| 156 |
+
k1,
|
| 157 |
+
self._kernels,
|
| 158 |
+
2,
|
| 159 |
+
norm=self._norm,
|
| 160 |
+
activation=self._activation,
|
| 161 |
+
residual=use_residual,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
self._global_maxp = nn.AdaptiveMaxPool3d(1)
|
| 165 |
+
self._local_maxp = nn.MaxPool3d(3, 2, padding=1)
|
| 166 |
+
self._final = Conv3DBlock(
|
| 167 |
+
self._kernels * 2,
|
| 168 |
+
self._kernels,
|
| 169 |
+
kernel_sizes=3,
|
| 170 |
+
strides=1,
|
| 171 |
+
norm=self._norm,
|
| 172 |
+
activation=self._activation,
|
| 173 |
+
)
|
| 174 |
+
self._final2 = Conv3DBlock(
|
| 175 |
+
self._kernels,
|
| 176 |
+
self._out_channels,
|
| 177 |
+
kernel_sizes=3,
|
| 178 |
+
strides=1,
|
| 179 |
+
norm=None,
|
| 180 |
+
activation=None,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
self._ss_final = SpatialSoftmax3D(
|
| 184 |
+
self._voxel_size, self._voxel_size, self._voxel_size, self._kernels
|
| 185 |
+
)
|
| 186 |
+
flat_size += self._kernels * 4
|
| 187 |
+
|
| 188 |
+
if self._out_dense > 0:
|
| 189 |
+
self._dense0 = DenseBlock(
|
| 190 |
+
flat_size, self._dense_feats, None, self._activation
|
| 191 |
+
)
|
| 192 |
+
self._dense1 = DenseBlock(
|
| 193 |
+
self._dense_feats, self._dense_feats, None, self._activation
|
| 194 |
+
)
|
| 195 |
+
self._dense2 = DenseBlock(self._dense_feats, self._out_dense, None, None)
|
| 196 |
+
|
| 197 |
+
def _proj_feature(self, x, spatial_size, proj_fn):
|
| 198 |
+
x = proj_fn(x)
|
| 199 |
+
x = x.unsqueeze(2).unsqueeze(3).unsqueeze(4)
|
| 200 |
+
x = x.repeat(1, 1, spatial_size, spatial_size, spatial_size)
|
| 201 |
+
return x
|
| 202 |
+
|
| 203 |
+
def forward(
|
| 204 |
+
self,
|
| 205 |
+
ins,
|
| 206 |
+
proprio,
|
| 207 |
+
lang_goal_embs,
|
| 208 |
+
lang_token_embs,
|
| 209 |
+
bounds,
|
| 210 |
+
prev_bounds,
|
| 211 |
+
prev_layer_voxel_grid,
|
| 212 |
+
):
|
| 213 |
+
b, _, d, h, w = ins.shape
|
| 214 |
+
x = self._input_preprocess(ins)
|
| 215 |
+
|
| 216 |
+
if self._include_prev_layer:
|
| 217 |
+
for voxel_grid in prev_layer_voxel_grid:
|
| 218 |
+
y = self._input_preprocess(voxel_grid)
|
| 219 |
+
x = torch.cat([x, y], dim=1)
|
| 220 |
+
|
| 221 |
+
if self._low_dim_size > 0:
|
| 222 |
+
p = self._proprio_preprocess(proprio)
|
| 223 |
+
p = p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
|
| 224 |
+
x = torch.cat([x, p], dim=1)
|
| 225 |
+
|
| 226 |
+
l_feat = lang_goal_embs
|
| 227 |
+
l_feat = l_feat.to(dtype=x.dtype)
|
| 228 |
+
|
| 229 |
+
d0 = self._down0(x)
|
| 230 |
+
# l0 = self._proj_feature(l_feat, d0.shape[-1], self._lang_proj0)
|
| 231 |
+
# d0 = self._dropout0(d0 * l0)
|
| 232 |
+
ss0 = self._ss0(d0)
|
| 233 |
+
maxp0 = self._global_maxp(d0).view(b, -1)
|
| 234 |
+
|
| 235 |
+
d1 = u = self._down1(self._local_maxp(d0))
|
| 236 |
+
# l1 = self._proj_feature(l_feat, d1.shape[-1], self._lang_proj1)
|
| 237 |
+
# d1 = self._dropout1(d1 * l1)
|
| 238 |
+
ss1 = self._ss1(d1)
|
| 239 |
+
maxp1 = self._global_maxp(d1).view(b, -1)
|
| 240 |
+
|
| 241 |
+
feats = [ss0, maxp0, ss1, maxp1]
|
| 242 |
+
|
| 243 |
+
if self._voxel_size > 8:
|
| 244 |
+
d2 = u = self._down2(self._local_maxp(d1))
|
| 245 |
+
l2 = self._proj_feature(l_feat, d2.shape[-1], self._lang_proj2)
|
| 246 |
+
d2 = self._dropout2(d2 * l2)
|
| 247 |
+
feats.extend([self._ss2(d2), self._global_maxp(d2).view(b, -1)])
|
| 248 |
+
if self._voxel_size > 16:
|
| 249 |
+
d3 = self._down3(self._local_maxp(d2))
|
| 250 |
+
l3 = self._proj_feature(l_feat, d3.shape[-1], self._lang_proj3)
|
| 251 |
+
d3 = self._dropout3(d3 * l3)
|
| 252 |
+
feats.extend([self._ss3(d3), self._global_maxp(d3).view(b, -1)])
|
| 253 |
+
u3 = self._up3(d3)
|
| 254 |
+
u = torch.cat([d2, u3], dim=1)
|
| 255 |
+
u2 = self._up2(u)
|
| 256 |
+
u = torch.cat([d1, u2], dim=1)
|
| 257 |
+
|
| 258 |
+
u1 = self._up1(u)
|
| 259 |
+
f1 = self._final(torch.cat([d0, u1], dim=1))
|
| 260 |
+
trans = self._final2(f1)
|
| 261 |
+
|
| 262 |
+
feats.extend([self._ss_final(f1), self._global_maxp(f1).view(b, -1)])
|
| 263 |
+
|
| 264 |
+
self.latent_dict = {
|
| 265 |
+
"d0": d0.mean(-1).mean(-1).mean(-1),
|
| 266 |
+
"d1": d1.mean(-1).mean(-1).mean(-1),
|
| 267 |
+
"u1": u1.mean(-1).mean(-1).mean(-1),
|
| 268 |
+
"trans_out": trans,
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
rot_and_grip_out, collision_out = None, None
|
| 272 |
+
if self._out_dense > 0:
|
| 273 |
+
dense0 = self._dense0(torch.cat(feats, 1))
|
| 274 |
+
dense1 = self._dense1(dense0)
|
| 275 |
+
rot_and_grip_collision_out = self._dense2(dense1)
|
| 276 |
+
rot_and_grip_out = rot_and_grip_collision_out[:, :-2]
|
| 277 |
+
collision_out = rot_and_grip_collision_out[:, -2:]
|
| 278 |
+
self.latent_dict.update(
|
| 279 |
+
{
|
| 280 |
+
"dense0": dense0,
|
| 281 |
+
"dense1": dense1,
|
| 282 |
+
"dense2": rot_and_grip_collision_out,
|
| 283 |
+
}
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
if self._voxel_size > 8:
|
| 287 |
+
self.latent_dict.update(
|
| 288 |
+
{
|
| 289 |
+
"d2": d2.mean(-1).mean(-1).mean(-1),
|
| 290 |
+
"u2": u2.mean(-1).mean(-1).mean(-1),
|
| 291 |
+
}
|
| 292 |
+
)
|
| 293 |
+
if self._voxel_size > 16:
|
| 294 |
+
self.latent_dict.update(
|
| 295 |
+
{
|
| 296 |
+
"d3": d3.mean(-1).mean(-1).mean(-1),
|
| 297 |
+
"u3": u3.mean(-1).mean(-1).mean(-1),
|
| 298 |
+
}
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
return trans, rot_and_grip_out, collision_out
|
external/peract_bimanual/agents/c2farm_lingunet_bc/qattention_lingunet_bc_agent.py
ADDED
|
@@ -0,0 +1,790 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
from pytorch3d import transforms as torch3d_tf
|
| 12 |
+
from yarr.agents.agent import (
|
| 13 |
+
Agent,
|
| 14 |
+
ActResult,
|
| 15 |
+
ScalarSummary,
|
| 16 |
+
HistogramSummary,
|
| 17 |
+
ImageSummary,
|
| 18 |
+
Summary,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
from helpers import utils
|
| 22 |
+
from helpers.utils import visualise_voxel, stack_on_channel
|
| 23 |
+
from voxel.voxel_grid import VoxelGrid
|
| 24 |
+
from voxel.augmentation import apply_se3_augmentation
|
| 25 |
+
from einops import rearrange
|
| 26 |
+
from helpers.clip.core.clip import build_model, load_clip
|
| 27 |
+
|
| 28 |
+
import transformers
|
| 29 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 30 |
+
|
| 31 |
+
NAME = "QAttentionAgent"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class QFunction(nn.Module):
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
unet_3d: nn.Module,
|
| 38 |
+
voxelizer: VoxelGrid,
|
| 39 |
+
bounds_offset: float,
|
| 40 |
+
rotation_resolution: float,
|
| 41 |
+
device,
|
| 42 |
+
training,
|
| 43 |
+
):
|
| 44 |
+
super(QFunction, self).__init__()
|
| 45 |
+
self._rotation_resolution = rotation_resolution
|
| 46 |
+
self._voxelizer = voxelizer
|
| 47 |
+
self._bounds_offset = bounds_offset
|
| 48 |
+
self._qnet = unet_3d.to(device)
|
| 49 |
+
|
| 50 |
+
# distributed training
|
| 51 |
+
if training:
|
| 52 |
+
self._qnet = DDP(self._qnet, device_ids=[device])
|
| 53 |
+
|
| 54 |
+
def _argmax_3d(self, tensor_orig):
|
| 55 |
+
b, c, d, h, w = tensor_orig.shape # c will be one
|
| 56 |
+
idxs = tensor_orig.view(b, c, -1).argmax(-1)
|
| 57 |
+
indices = torch.cat([((idxs // h) // d), (idxs // h) % w, idxs % w], 1)
|
| 58 |
+
return indices
|
| 59 |
+
|
| 60 |
+
def choose_highest_action(self, q_trans, q_rot_grip, q_collision):
|
| 61 |
+
coords = self._argmax_3d(q_trans)
|
| 62 |
+
rot_and_grip_indicies = None
|
| 63 |
+
ignore_collision = None
|
| 64 |
+
if q_rot_grip is not None:
|
| 65 |
+
q_rot = torch.stack(
|
| 66 |
+
torch.split(
|
| 67 |
+
q_rot_grip[:, :-2], int(360 // self._rotation_resolution), dim=1
|
| 68 |
+
),
|
| 69 |
+
dim=1,
|
| 70 |
+
)
|
| 71 |
+
rot_and_grip_indicies = torch.cat(
|
| 72 |
+
[
|
| 73 |
+
q_rot[:, 0:1].argmax(-1),
|
| 74 |
+
q_rot[:, 1:2].argmax(-1),
|
| 75 |
+
q_rot[:, 2:3].argmax(-1),
|
| 76 |
+
q_rot_grip[:, -2:].argmax(-1, keepdim=True),
|
| 77 |
+
],
|
| 78 |
+
-1,
|
| 79 |
+
)
|
| 80 |
+
ignore_collision = q_collision[:, -2:].argmax(-1, keepdim=True)
|
| 81 |
+
return coords, rot_and_grip_indicies, ignore_collision
|
| 82 |
+
|
| 83 |
+
def forward(
|
| 84 |
+
self,
|
| 85 |
+
rgb_pcd,
|
| 86 |
+
proprio,
|
| 87 |
+
pcd,
|
| 88 |
+
lang_goal_emb,
|
| 89 |
+
lang_token_embs,
|
| 90 |
+
bounds=None,
|
| 91 |
+
prev_bounds=None,
|
| 92 |
+
prev_layer_voxel_grid=None,
|
| 93 |
+
):
|
| 94 |
+
# rgb_pcd will be list of list (list of [rgb, pcd])
|
| 95 |
+
b = rgb_pcd[0][0].shape[0]
|
| 96 |
+
pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(b, -1, 3) for p in pcd], 1)
|
| 97 |
+
|
| 98 |
+
# flatten RGBs and Pointclouds
|
| 99 |
+
rgb = [rp[0] for rp in rgb_pcd]
|
| 100 |
+
feat_size = rgb[0].shape[1]
|
| 101 |
+
flat_imag_features = torch.cat(
|
| 102 |
+
[p.permute(0, 2, 3, 1).reshape(b, -1, feat_size) for p in rgb], 1
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# construct voxel grid
|
| 106 |
+
voxel_grid = self._voxelizer.coords_to_bounding_voxel_grid(
|
| 107 |
+
pcd_flat, coord_features=flat_imag_features, coord_bounds=bounds
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# swap to channels fist
|
| 111 |
+
voxel_grid = voxel_grid.permute(0, 4, 1, 2, 3).detach()
|
| 112 |
+
|
| 113 |
+
# batch bounds if necessary
|
| 114 |
+
if bounds.shape[0] != b:
|
| 115 |
+
bounds = bounds.repeat(b, 1)
|
| 116 |
+
|
| 117 |
+
# forward pass
|
| 118 |
+
q_trans, q_rot_and_grip, q_ignore_collisions = self._qnet(
|
| 119 |
+
voxel_grid,
|
| 120 |
+
proprio,
|
| 121 |
+
lang_goal_emb,
|
| 122 |
+
lang_token_embs,
|
| 123 |
+
prev_layer_voxel_grid,
|
| 124 |
+
bounds,
|
| 125 |
+
prev_bounds,
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
return q_trans, q_rot_and_grip, q_ignore_collisions, voxel_grid
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class QAttentionLingUNetBCAgent(Agent):
|
| 132 |
+
def __init__(
|
| 133 |
+
self,
|
| 134 |
+
layer: int,
|
| 135 |
+
coordinate_bounds: list,
|
| 136 |
+
unet3d: nn.Module,
|
| 137 |
+
camera_names: list,
|
| 138 |
+
batch_size: int,
|
| 139 |
+
voxel_size: int,
|
| 140 |
+
bounds_offset: float,
|
| 141 |
+
voxel_feature_size: int,
|
| 142 |
+
image_crop_size: int,
|
| 143 |
+
num_rotation_classes: int,
|
| 144 |
+
rotation_resolution: float,
|
| 145 |
+
lr: float = 0.0001,
|
| 146 |
+
lr_scheduler: bool = False,
|
| 147 |
+
training_iterations: int = 100000,
|
| 148 |
+
num_warmup_steps: int = 20000,
|
| 149 |
+
trans_loss_weight: float = 1.0,
|
| 150 |
+
rot_loss_weight: float = 1.0,
|
| 151 |
+
grip_loss_weight: float = 1.0,
|
| 152 |
+
collision_loss_weight: float = 1.0,
|
| 153 |
+
include_low_dim_state: bool = False,
|
| 154 |
+
image_resolution: list = None,
|
| 155 |
+
lambda_weight_l2: float = 0.0,
|
| 156 |
+
transform_augmentation: bool = True,
|
| 157 |
+
transform_augmentation_xyz: list = [0.0, 0.0, 0.0],
|
| 158 |
+
transform_augmentation_rpy: list = [0.0, 0.0, 180.0],
|
| 159 |
+
transform_augmentation_rot_resolution: int = 5,
|
| 160 |
+
num_devices: int = 1,
|
| 161 |
+
):
|
| 162 |
+
self._layer = layer
|
| 163 |
+
self._coordinate_bounds = coordinate_bounds
|
| 164 |
+
self._unet3d = unet3d
|
| 165 |
+
self._voxel_feature_size = voxel_feature_size
|
| 166 |
+
self._bounds_offset = bounds_offset
|
| 167 |
+
self._image_crop_size = image_crop_size
|
| 168 |
+
self._lr = lr
|
| 169 |
+
self._lr_scheduler = lr_scheduler
|
| 170 |
+
self._training_iterations = training_iterations
|
| 171 |
+
self._num_warmup_steps = num_warmup_steps
|
| 172 |
+
self._trans_loss_weight = trans_loss_weight
|
| 173 |
+
self._rot_loss_weight = rot_loss_weight
|
| 174 |
+
self._grip_loss_weight = grip_loss_weight
|
| 175 |
+
self._collision_loss_weight = collision_loss_weight
|
| 176 |
+
self._include_low_dim_state = include_low_dim_state
|
| 177 |
+
self._image_resolution = image_resolution or [128, 128]
|
| 178 |
+
self._voxel_size = voxel_size
|
| 179 |
+
self._camera_names = camera_names
|
| 180 |
+
self._num_cameras = len(camera_names)
|
| 181 |
+
self._batch_size = batch_size
|
| 182 |
+
self._lambda_weight_l2 = lambda_weight_l2
|
| 183 |
+
self._transform_augmentation = transform_augmentation
|
| 184 |
+
self._transform_augmentation_xyz = torch.from_numpy(
|
| 185 |
+
np.array(transform_augmentation_xyz)
|
| 186 |
+
)
|
| 187 |
+
self._transform_augmentation_rpy = transform_augmentation_rpy
|
| 188 |
+
self._transform_augmentation_rot_resolution = (
|
| 189 |
+
transform_augmentation_rot_resolution
|
| 190 |
+
)
|
| 191 |
+
self._num_devices = num_devices
|
| 192 |
+
self._num_rotation_classes = num_rotation_classes
|
| 193 |
+
self._rotation_resolution = rotation_resolution
|
| 194 |
+
|
| 195 |
+
self._cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
|
| 196 |
+
self._name = NAME + "_layer" + str(self._layer)
|
| 197 |
+
|
| 198 |
+
def build(self, training: bool, device: torch.device = None):
|
| 199 |
+
self._training = training
|
| 200 |
+
self._device = device
|
| 201 |
+
|
| 202 |
+
if device is None:
|
| 203 |
+
device = torch.device("cpu")
|
| 204 |
+
|
| 205 |
+
self._voxelizer = VoxelGrid(
|
| 206 |
+
coord_bounds=self._coordinate_bounds,
|
| 207 |
+
voxel_size=self._voxel_size,
|
| 208 |
+
device=device,
|
| 209 |
+
batch_size=self._batch_size if training else 1,
|
| 210 |
+
feature_size=self._voxel_feature_size,
|
| 211 |
+
max_num_coords=np.prod(self._image_resolution) * self._num_cameras,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
self._unet3d.build()
|
| 215 |
+
|
| 216 |
+
self._q = (
|
| 217 |
+
QFunction(
|
| 218 |
+
self._unet3d,
|
| 219 |
+
self._voxelizer,
|
| 220 |
+
self._bounds_offset,
|
| 221 |
+
self._rotation_resolution,
|
| 222 |
+
device,
|
| 223 |
+
training,
|
| 224 |
+
)
|
| 225 |
+
.to(device)
|
| 226 |
+
.train(training)
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
grid_for_crop = (
|
| 230 |
+
torch.arange(0, self._image_crop_size, device=device)
|
| 231 |
+
.unsqueeze(0)
|
| 232 |
+
.repeat(self._image_crop_size, 1)
|
| 233 |
+
.unsqueeze(-1)
|
| 234 |
+
)
|
| 235 |
+
self._grid_for_crop = torch.cat(
|
| 236 |
+
[grid_for_crop.transpose(1, 0), grid_for_crop], dim=2
|
| 237 |
+
).unsqueeze(0)
|
| 238 |
+
|
| 239 |
+
self._coordinate_bounds = torch.tensor(
|
| 240 |
+
self._coordinate_bounds, device=device
|
| 241 |
+
).unsqueeze(0)
|
| 242 |
+
|
| 243 |
+
if self._training:
|
| 244 |
+
# optimizer
|
| 245 |
+
self._optimizer = torch.optim.Adam(
|
| 246 |
+
self._q.parameters(),
|
| 247 |
+
lr=self._lr,
|
| 248 |
+
weight_decay=self._lambda_weight_l2,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
# learning rate scheduler
|
| 252 |
+
if self._lr_scheduler:
|
| 253 |
+
self._scheduler = (
|
| 254 |
+
transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
|
| 255 |
+
self._optimizer,
|
| 256 |
+
num_warmup_steps=self._num_warmup_steps,
|
| 257 |
+
num_training_steps=self._training_iterations,
|
| 258 |
+
num_cycles=self._training_iterations // 10000,
|
| 259 |
+
)
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
# one-hot zero tensors
|
| 263 |
+
self._action_trans_one_hot_zeros = torch.zeros(
|
| 264 |
+
(
|
| 265 |
+
self._batch_size,
|
| 266 |
+
1,
|
| 267 |
+
self._voxel_size,
|
| 268 |
+
self._voxel_size,
|
| 269 |
+
self._voxel_size,
|
| 270 |
+
),
|
| 271 |
+
dtype=int,
|
| 272 |
+
device=device,
|
| 273 |
+
)
|
| 274 |
+
self._action_rot_x_one_hot_zeros = torch.zeros(
|
| 275 |
+
(self._batch_size, self._num_rotation_classes), dtype=int, device=device
|
| 276 |
+
)
|
| 277 |
+
self._action_rot_y_one_hot_zeros = torch.zeros(
|
| 278 |
+
(self._batch_size, self._num_rotation_classes), dtype=int, device=device
|
| 279 |
+
)
|
| 280 |
+
self._action_rot_z_one_hot_zeros = torch.zeros(
|
| 281 |
+
(self._batch_size, self._num_rotation_classes), dtype=int, device=device
|
| 282 |
+
)
|
| 283 |
+
self._action_grip_one_hot_zeros = torch.zeros(
|
| 284 |
+
(self._batch_size, 2), dtype=int, device=device
|
| 285 |
+
)
|
| 286 |
+
self._action_ignore_collisions_one_hot_zeros = torch.zeros(
|
| 287 |
+
(self._batch_size, 2), dtype=int, device=device
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
# print total params
|
| 291 |
+
logging.info(
|
| 292 |
+
"# Q Params: %d"
|
| 293 |
+
% sum(
|
| 294 |
+
p.numel()
|
| 295 |
+
for name, p in self._q.named_parameters()
|
| 296 |
+
if p.requires_grad and "clip" not in name
|
| 297 |
+
)
|
| 298 |
+
)
|
| 299 |
+
else:
|
| 300 |
+
for param in self._q.parameters():
|
| 301 |
+
param.requires_grad = False
|
| 302 |
+
|
| 303 |
+
# load CLIP for encoding language goals during evaluation
|
| 304 |
+
model, _ = load_clip("RN50", jit=False)
|
| 305 |
+
self._clip_rn50 = build_model(model.state_dict())
|
| 306 |
+
self._clip_rn50 = self._clip_rn50.float().to(device)
|
| 307 |
+
self._clip_rn50.eval()
|
| 308 |
+
del model
|
| 309 |
+
|
| 310 |
+
self._voxelizer.to(device)
|
| 311 |
+
self._q.to(device)
|
| 312 |
+
|
| 313 |
+
def _extract_crop(self, pixel_action, observation):
|
| 314 |
+
# Pixel action will now be (B, 2)
|
| 315 |
+
# observation = stack_on_channel(observation)
|
| 316 |
+
h = observation.shape[-1]
|
| 317 |
+
top_left_corner = torch.clamp(
|
| 318 |
+
pixel_action - self._image_crop_size // 2, 0, h - self._image_crop_size
|
| 319 |
+
)
|
| 320 |
+
grid = self._grid_for_crop + top_left_corner.unsqueeze(1).unsqueeze(1)
|
| 321 |
+
grid = ((grid / float(h)) * 2.0) - 1.0 # between -1 and 1
|
| 322 |
+
# Used for cropping the images across a batch
|
| 323 |
+
# swap fro y x, to x, y
|
| 324 |
+
grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1)
|
| 325 |
+
crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True)
|
| 326 |
+
return crop
|
| 327 |
+
|
| 328 |
+
def _preprocess_inputs(self, replay_sample):
|
| 329 |
+
obs, pcds = [], []
|
| 330 |
+
self._crop_summary = []
|
| 331 |
+
for n in self._camera_names:
|
| 332 |
+
if self._layer > 0:
|
| 333 |
+
pc_t = replay_sample["%s_pixel_coord" % n]
|
| 334 |
+
rgb = self._extract_crop(pc_t, replay_sample["%s_rgb" % n])
|
| 335 |
+
pcd = self._extract_crop(pc_t, replay_sample["%s_point_cloud" % n])
|
| 336 |
+
self._crop_summary.append((n, rgb))
|
| 337 |
+
else:
|
| 338 |
+
rgb = replay_sample["%s_rgb" % n]
|
| 339 |
+
pcd = replay_sample["%s_point_cloud" % n]
|
| 340 |
+
|
| 341 |
+
obs.append([rgb, pcd])
|
| 342 |
+
pcds.append(pcd)
|
| 343 |
+
return obs, pcds
|
| 344 |
+
|
| 345 |
+
def _act_preprocess_inputs(self, observation):
|
| 346 |
+
obs, pcds = [], []
|
| 347 |
+
for n in self._camera_names:
|
| 348 |
+
if self._layer > 0:
|
| 349 |
+
pc_t = observation["%s_pixel_coord" % n][0]
|
| 350 |
+
rgb = self._extract_crop(pc_t, observation["%s_rgb" % n][0])
|
| 351 |
+
pcd = self._extract_crop(pc_t, observation["%s_point_cloud" % n][0])
|
| 352 |
+
else:
|
| 353 |
+
rgb = observation["%s_rgb" % n][0]
|
| 354 |
+
pcd = observation["%s_point_cloud" % n][0]
|
| 355 |
+
|
| 356 |
+
obs.append([rgb, pcd])
|
| 357 |
+
pcds.append(pcd)
|
| 358 |
+
return obs, pcds
|
| 359 |
+
|
| 360 |
+
def _get_value_from_voxel_index(self, q, voxel_idx):
|
| 361 |
+
b, c, d, h, w = q.shape
|
| 362 |
+
q_trans_flat = q.view(b, c, d * h * w)
|
| 363 |
+
flat_indicies = (
|
| 364 |
+
voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2]
|
| 365 |
+
)[:, None].int()
|
| 366 |
+
highest_idxs = flat_indicies.unsqueeze(-1).repeat(1, c, 1)
|
| 367 |
+
chosen_voxel_values = q_trans_flat.gather(2, highest_idxs)[
|
| 368 |
+
..., 0
|
| 369 |
+
] # (B, trans + rot + grip)
|
| 370 |
+
return chosen_voxel_values
|
| 371 |
+
|
| 372 |
+
def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx):
|
| 373 |
+
q_rot = torch.stack(
|
| 374 |
+
torch.split(
|
| 375 |
+
rot_grip_q[:, :-2], int(360 // self._rotation_resolution), dim=1
|
| 376 |
+
),
|
| 377 |
+
dim=1,
|
| 378 |
+
) # B, 3, 72
|
| 379 |
+
q_grip = rot_grip_q[:, -2:]
|
| 380 |
+
rot_and_grip_values = torch.cat(
|
| 381 |
+
[
|
| 382 |
+
q_rot[:, 0].gather(1, rot_and_grip_idx[:, 0:1]),
|
| 383 |
+
q_rot[:, 1].gather(1, rot_and_grip_idx[:, 1:2]),
|
| 384 |
+
q_rot[:, 2].gather(1, rot_and_grip_idx[:, 2:3]),
|
| 385 |
+
q_grip.gather(1, rot_and_grip_idx[:, 3:4]),
|
| 386 |
+
],
|
| 387 |
+
-1,
|
| 388 |
+
)
|
| 389 |
+
return rot_and_grip_values
|
| 390 |
+
|
| 391 |
+
def _celoss(self, pred, labels):
|
| 392 |
+
return self._cross_entropy_loss(pred, labels.argmax(-1))
|
| 393 |
+
|
| 394 |
+
def _softmax_q_trans(self, q):
|
| 395 |
+
q_shape = q.shape
|
| 396 |
+
return F.softmax(q.reshape(q_shape[0], -1), dim=1).reshape(q_shape)
|
| 397 |
+
|
| 398 |
+
def _softmax_q_rot_grip(self, q_rot_grip):
|
| 399 |
+
q_rot_x_flat = q_rot_grip[
|
| 400 |
+
:, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
|
| 401 |
+
]
|
| 402 |
+
q_rot_y_flat = q_rot_grip[
|
| 403 |
+
:, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
|
| 404 |
+
]
|
| 405 |
+
q_rot_z_flat = q_rot_grip[
|
| 406 |
+
:, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
|
| 407 |
+
]
|
| 408 |
+
q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
|
| 409 |
+
|
| 410 |
+
q_rot_x_flat_softmax = F.softmax(q_rot_x_flat, dim=1)
|
| 411 |
+
q_rot_y_flat_softmax = F.softmax(q_rot_y_flat, dim=1)
|
| 412 |
+
q_rot_z_flat_softmax = F.softmax(q_rot_z_flat, dim=1)
|
| 413 |
+
q_grip_flat_softmax = F.softmax(q_grip_flat, dim=1)
|
| 414 |
+
|
| 415 |
+
return torch.cat(
|
| 416 |
+
[
|
| 417 |
+
q_rot_x_flat_softmax,
|
| 418 |
+
q_rot_y_flat_softmax,
|
| 419 |
+
q_rot_z_flat_softmax,
|
| 420 |
+
q_grip_flat_softmax,
|
| 421 |
+
],
|
| 422 |
+
dim=1,
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
def _softmax_ignore_collision(self, q_collision):
|
| 426 |
+
q_collision_softmax = F.softmax(q_collision, dim=1)
|
| 427 |
+
return q_collision_softmax
|
| 428 |
+
|
| 429 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 430 |
+
action_trans = replay_sample["trans_action_indicies"][
|
| 431 |
+
:, self._layer * 3 : self._layer * 3 + 3
|
| 432 |
+
].int()
|
| 433 |
+
action_rot_grip = replay_sample["rot_grip_action_indicies"].int()
|
| 434 |
+
action_gripper_pose = replay_sample["gripper_pose"]
|
| 435 |
+
action_ignore_collisions = replay_sample["ignore_collisions"].int()
|
| 436 |
+
lang_goal_emb = replay_sample["lang_goal_emb"].float()
|
| 437 |
+
lang_token_embs = replay_sample["lang_token_embs"].float()
|
| 438 |
+
prev_layer_voxel_grid = replay_sample.get("prev_layer_voxel_grid", None)
|
| 439 |
+
prev_layer_bounds = replay_sample.get("prev_layer_bounds", None)
|
| 440 |
+
device = self._device
|
| 441 |
+
|
| 442 |
+
bounds = bounds_tp1 = self._coordinate_bounds
|
| 443 |
+
if self._layer > 0:
|
| 444 |
+
cp = replay_sample["attention_coordinate_layer_%d" % (self._layer - 1)]
|
| 445 |
+
bounds = torch.cat(
|
| 446 |
+
[cp - self._bounds_offset, cp + self._bounds_offset], dim=1
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
proprio = None
|
| 450 |
+
if self._include_low_dim_state:
|
| 451 |
+
proprio = replay_sample["low_dim_state"]
|
| 452 |
+
|
| 453 |
+
obs, pcd = self._preprocess_inputs(replay_sample)
|
| 454 |
+
|
| 455 |
+
# batch size
|
| 456 |
+
bs = pcd[0].shape[0]
|
| 457 |
+
|
| 458 |
+
# SE(3) augmentation of point clouds and actions
|
| 459 |
+
if self._transform_augmentation:
|
| 460 |
+
action_trans, action_rot_grip, pcd = apply_se3_augmentation(
|
| 461 |
+
pcd,
|
| 462 |
+
action_gripper_pose,
|
| 463 |
+
action_trans,
|
| 464 |
+
action_rot_grip,
|
| 465 |
+
bounds,
|
| 466 |
+
self._layer,
|
| 467 |
+
self._transform_augmentation_xyz,
|
| 468 |
+
self._transform_augmentation_rpy,
|
| 469 |
+
self._transform_augmentation_rot_resolution,
|
| 470 |
+
self._voxel_size,
|
| 471 |
+
self._rotation_resolution,
|
| 472 |
+
self._device,
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
# forward pass
|
| 476 |
+
q_trans, q_rot_grip, q_collision, voxel_grid = self._q(
|
| 477 |
+
obs,
|
| 478 |
+
proprio,
|
| 479 |
+
pcd,
|
| 480 |
+
lang_goal_emb,
|
| 481 |
+
lang_token_embs,
|
| 482 |
+
bounds,
|
| 483 |
+
prev_layer_bounds,
|
| 484 |
+
prev_layer_voxel_grid,
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
# argmax to choose best action
|
| 488 |
+
(
|
| 489 |
+
coords,
|
| 490 |
+
rot_and_grip_indicies,
|
| 491 |
+
ignore_collision_indicies,
|
| 492 |
+
) = self._q.choose_highest_action(q_trans, q_rot_grip, q_collision)
|
| 493 |
+
|
| 494 |
+
q_trans_loss, q_rot_loss, q_grip_loss, q_collision_loss = 0.0, 0.0, 0.0, 0.0
|
| 495 |
+
|
| 496 |
+
# translation one-hot
|
| 497 |
+
action_trans_one_hot = self._action_trans_one_hot_zeros.clone()
|
| 498 |
+
for b in range(bs):
|
| 499 |
+
gt_coord = action_trans[b, :].int()
|
| 500 |
+
action_trans_one_hot[b, :, gt_coord[0], gt_coord[1], gt_coord[2]] = 1
|
| 501 |
+
|
| 502 |
+
# translation loss
|
| 503 |
+
q_trans_flat = q_trans.view(bs, -1)
|
| 504 |
+
action_trans_one_hot_flat = action_trans_one_hot.view(bs, -1)
|
| 505 |
+
q_trans_loss = self._celoss(q_trans_flat, action_trans_one_hot_flat)
|
| 506 |
+
|
| 507 |
+
with_rot_and_grip = rot_and_grip_indicies is not None
|
| 508 |
+
if with_rot_and_grip:
|
| 509 |
+
# rotation, gripper, and collision one-hots
|
| 510 |
+
action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone()
|
| 511 |
+
action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone()
|
| 512 |
+
action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone()
|
| 513 |
+
action_grip_one_hot = self._action_grip_one_hot_zeros.clone()
|
| 514 |
+
action_ignore_collisions_one_hot = (
|
| 515 |
+
self._action_ignore_collisions_one_hot_zeros.clone()
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
for b in range(bs):
|
| 519 |
+
gt_rot_grip = action_rot_grip[b, :].int()
|
| 520 |
+
action_rot_x_one_hot[b, gt_rot_grip[0]] = 1
|
| 521 |
+
action_rot_y_one_hot[b, gt_rot_grip[1]] = 1
|
| 522 |
+
action_rot_z_one_hot[b, gt_rot_grip[2]] = 1
|
| 523 |
+
action_grip_one_hot[b, gt_rot_grip[3]] = 1
|
| 524 |
+
|
| 525 |
+
gt_ignore_collisions = action_ignore_collisions[b, :].int()
|
| 526 |
+
action_ignore_collisions_one_hot[b, gt_ignore_collisions[0]] = 1
|
| 527 |
+
|
| 528 |
+
# flatten predictions
|
| 529 |
+
q_rot_x_flat = q_rot_grip[
|
| 530 |
+
:, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
|
| 531 |
+
]
|
| 532 |
+
q_rot_y_flat = q_rot_grip[
|
| 533 |
+
:, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
|
| 534 |
+
]
|
| 535 |
+
q_rot_z_flat = q_rot_grip[
|
| 536 |
+
:, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
|
| 537 |
+
]
|
| 538 |
+
q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
|
| 539 |
+
q_ignore_collisions_flat = q_collision
|
| 540 |
+
|
| 541 |
+
# rotation loss
|
| 542 |
+
q_rot_loss += self._celoss(q_rot_x_flat, action_rot_x_one_hot)
|
| 543 |
+
q_rot_loss += self._celoss(q_rot_y_flat, action_rot_y_one_hot)
|
| 544 |
+
q_rot_loss += self._celoss(q_rot_z_flat, action_rot_z_one_hot)
|
| 545 |
+
|
| 546 |
+
# gripper loss
|
| 547 |
+
q_grip_loss += self._celoss(q_grip_flat, action_grip_one_hot)
|
| 548 |
+
|
| 549 |
+
# collision loss
|
| 550 |
+
q_collision_loss += self._celoss(
|
| 551 |
+
q_ignore_collisions_flat, action_ignore_collisions_one_hot
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
combined_losses = (
|
| 555 |
+
(q_trans_loss * self._trans_loss_weight)
|
| 556 |
+
+ (q_rot_loss * self._rot_loss_weight)
|
| 557 |
+
+ (q_grip_loss * self._grip_loss_weight)
|
| 558 |
+
+ (q_collision_loss * self._collision_loss_weight)
|
| 559 |
+
)
|
| 560 |
+
total_loss = combined_losses.mean()
|
| 561 |
+
|
| 562 |
+
self._optimizer.zero_grad()
|
| 563 |
+
total_loss.backward()
|
| 564 |
+
self._optimizer.step()
|
| 565 |
+
|
| 566 |
+
self._summaries = {
|
| 567 |
+
"losses/total_loss": total_loss,
|
| 568 |
+
"losses/trans_loss": q_trans_loss.mean(),
|
| 569 |
+
"losses/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
|
| 570 |
+
"losses/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
|
| 571 |
+
"losses/collision_loss": q_collision_loss.mean()
|
| 572 |
+
if with_rot_and_grip
|
| 573 |
+
else 0.0,
|
| 574 |
+
}
|
| 575 |
+
|
| 576 |
+
if self._lr_scheduler:
|
| 577 |
+
self._scheduler.step()
|
| 578 |
+
self._summaries["learning_rate"] = self._scheduler.get_last_lr()[0]
|
| 579 |
+
|
| 580 |
+
self._vis_voxel_grid = voxel_grid[0]
|
| 581 |
+
self._vis_translation_qvalue = self._softmax_q_trans(q_trans[0])
|
| 582 |
+
self._vis_max_coordinate = coords[0]
|
| 583 |
+
self._vis_gt_coordinate = action_trans[0]
|
| 584 |
+
|
| 585 |
+
# Note: PerAct doesn't use multi-layer voxel grids like C2FARM
|
| 586 |
+
# stack prev_layer_voxel_grid(s) from previous layers into a list
|
| 587 |
+
if prev_layer_voxel_grid is None:
|
| 588 |
+
prev_layer_voxel_grid = [voxel_grid]
|
| 589 |
+
else:
|
| 590 |
+
prev_layer_voxel_grid = prev_layer_voxel_grid + [voxel_grid]
|
| 591 |
+
|
| 592 |
+
# stack prev_layer_bound(s) from previous layers into a list
|
| 593 |
+
if prev_layer_bounds is None:
|
| 594 |
+
prev_layer_bounds = [self._coordinate_bounds.repeat(bs, 1)]
|
| 595 |
+
else:
|
| 596 |
+
prev_layer_bounds = prev_layer_bounds + [bounds]
|
| 597 |
+
|
| 598 |
+
return {
|
| 599 |
+
"total_loss": total_loss,
|
| 600 |
+
"prev_layer_voxel_grid": prev_layer_voxel_grid,
|
| 601 |
+
"prev_layer_bounds": prev_layer_bounds,
|
| 602 |
+
}
|
| 603 |
+
|
| 604 |
+
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
|
| 605 |
+
deterministic = True
|
| 606 |
+
bounds = self._coordinate_bounds
|
| 607 |
+
prev_layer_voxel_grid = observation.get("prev_layer_voxel_grid", None)
|
| 608 |
+
prev_layer_bounds = observation.get("prev_layer_bounds", None)
|
| 609 |
+
lang_goal_tokens = observation.get("lang_goal_tokens", None).long()
|
| 610 |
+
|
| 611 |
+
# extract CLIP language embs
|
| 612 |
+
with torch.no_grad():
|
| 613 |
+
lang_goal_tokens = lang_goal_tokens.to(device=self._device)
|
| 614 |
+
(
|
| 615 |
+
lang_goal_emb,
|
| 616 |
+
lang_token_embs,
|
| 617 |
+
) = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0])
|
| 618 |
+
|
| 619 |
+
if self._layer > 0:
|
| 620 |
+
cp = observation["attention_coordinate"]
|
| 621 |
+
bounds = torch.cat(
|
| 622 |
+
[cp - self._bounds_offset, cp + self._bounds_offset], dim=1
|
| 623 |
+
)
|
| 624 |
+
|
| 625 |
+
# voxelization resolution
|
| 626 |
+
res = (bounds[:, 3:] - bounds[:, :3]) / self._voxel_size
|
| 627 |
+
max_rot_index = int(360 // self._rotation_resolution)
|
| 628 |
+
proprio = None
|
| 629 |
+
|
| 630 |
+
if self._include_low_dim_state:
|
| 631 |
+
proprio = observation["low_dim_state"]
|
| 632 |
+
|
| 633 |
+
obs, pcd = self._act_preprocess_inputs(observation)
|
| 634 |
+
|
| 635 |
+
# correct batch size and device
|
| 636 |
+
obs = [[o[0].to(self._device), o[1].to(self._device)] for o in obs]
|
| 637 |
+
proprio = proprio[0].to(self._device)
|
| 638 |
+
pcd = [p.to(self._device) for p in pcd]
|
| 639 |
+
lang_goal_emb = lang_goal_emb.to(self._device)
|
| 640 |
+
lang_token_embs = lang_token_embs.to(self._device)
|
| 641 |
+
bounds = torch.as_tensor(bounds, device=self._device)
|
| 642 |
+
if prev_layer_voxel_grid is not None:
|
| 643 |
+
prev_layer_voxel_grid = [
|
| 644 |
+
pvg.to(self._device) for pvg in prev_layer_voxel_grid
|
| 645 |
+
]
|
| 646 |
+
if prev_layer_bounds is not None:
|
| 647 |
+
prev_layer_bounds = [pb.to(self._device) for pb in prev_layer_bounds]
|
| 648 |
+
|
| 649 |
+
# inference
|
| 650 |
+
q_trans, q_rot_grip, q_ignore_collisions, vox_grid = self._q(
|
| 651 |
+
obs,
|
| 652 |
+
proprio,
|
| 653 |
+
pcd,
|
| 654 |
+
lang_goal_emb,
|
| 655 |
+
lang_token_embs,
|
| 656 |
+
bounds,
|
| 657 |
+
prev_layer_bounds,
|
| 658 |
+
prev_layer_voxel_grid,
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
# softmax Q predictions
|
| 662 |
+
q_trans = self._softmax_q_trans(q_trans)
|
| 663 |
+
q_rot_grip = (
|
| 664 |
+
self._softmax_q_rot_grip(q_rot_grip) if q_rot_grip is not None else None
|
| 665 |
+
)
|
| 666 |
+
q_ignore_collisions = (
|
| 667 |
+
self._softmax_ignore_collision(q_ignore_collisions)
|
| 668 |
+
if q_ignore_collisions is not None
|
| 669 |
+
else None
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
# argmax Q predictions
|
| 673 |
+
(
|
| 674 |
+
coords,
|
| 675 |
+
rot_and_grip_indicies,
|
| 676 |
+
ignore_collisions,
|
| 677 |
+
) = self._q.choose_highest_action(q_trans, q_rot_grip, q_ignore_collisions)
|
| 678 |
+
|
| 679 |
+
rot_grip_action = rot_and_grip_indicies if q_rot_grip is not None else None
|
| 680 |
+
ignore_collisions_action = (
|
| 681 |
+
ignore_collisions.int() if ignore_collisions is not None else None
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
coords = coords.int()
|
| 685 |
+
attention_coordinate = bounds[:, :3] + res * coords + res / 2
|
| 686 |
+
|
| 687 |
+
# stack prev_layer_voxel_grid(s) into a list
|
| 688 |
+
# NOTE: PerAct doesn't used multi-layer voxel grids like C2FARM
|
| 689 |
+
if prev_layer_voxel_grid is None:
|
| 690 |
+
prev_layer_voxel_grid = [vox_grid]
|
| 691 |
+
else:
|
| 692 |
+
prev_layer_voxel_grid = prev_layer_voxel_grid + [vox_grid]
|
| 693 |
+
|
| 694 |
+
if prev_layer_bounds is None:
|
| 695 |
+
prev_layer_bounds = [bounds]
|
| 696 |
+
else:
|
| 697 |
+
prev_layer_bounds = prev_layer_bounds + [bounds]
|
| 698 |
+
|
| 699 |
+
observation_elements = {
|
| 700 |
+
"attention_coordinate": attention_coordinate,
|
| 701 |
+
"prev_layer_voxel_grid": prev_layer_voxel_grid,
|
| 702 |
+
"prev_layer_bounds": prev_layer_bounds,
|
| 703 |
+
}
|
| 704 |
+
info = {
|
| 705 |
+
"voxel_grid_depth%d" % self._layer: vox_grid,
|
| 706 |
+
"q_depth%d" % self._layer: q_trans,
|
| 707 |
+
"voxel_idx_depth%d" % self._layer: coords,
|
| 708 |
+
}
|
| 709 |
+
self._act_voxel_grid = vox_grid[0]
|
| 710 |
+
self._act_max_coordinate = coords[0]
|
| 711 |
+
self._act_qvalues = q_trans[0].detach()
|
| 712 |
+
return ActResult(
|
| 713 |
+
(coords, rot_grip_action, ignore_collisions_action),
|
| 714 |
+
observation_elements=observation_elements,
|
| 715 |
+
info=info,
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
def update_summaries(self) -> List[Summary]:
|
| 719 |
+
summaries = [
|
| 720 |
+
ImageSummary(
|
| 721 |
+
"%s/update_qattention" % self._name,
|
| 722 |
+
transforms.ToTensor()(
|
| 723 |
+
visualise_voxel(
|
| 724 |
+
self._vis_voxel_grid.detach().cpu().numpy(),
|
| 725 |
+
self._vis_translation_qvalue.detach().cpu().numpy(),
|
| 726 |
+
self._vis_max_coordinate.detach().cpu().numpy(),
|
| 727 |
+
self._vis_gt_coordinate.detach().cpu().numpy(),
|
| 728 |
+
)
|
| 729 |
+
),
|
| 730 |
+
)
|
| 731 |
+
]
|
| 732 |
+
|
| 733 |
+
for n, v in self._summaries.items():
|
| 734 |
+
summaries.append(ScalarSummary("%s/%s" % (self._name, n), v))
|
| 735 |
+
|
| 736 |
+
for name, crop in self._crop_summary:
|
| 737 |
+
crops = (torch.cat(torch.split(crop, 3, dim=1), dim=3) + 1.0) / 2.0
|
| 738 |
+
summaries.extend([ImageSummary("%s/crops/%s" % (self._name, name), crops)])
|
| 739 |
+
|
| 740 |
+
for tag, param in self._q.named_parameters():
|
| 741 |
+
# assert not torch.isnan(param.grad.abs() <= 1.0).all()
|
| 742 |
+
summaries.append(
|
| 743 |
+
HistogramSummary("%s/gradient/%s" % (self._name, tag), param.grad)
|
| 744 |
+
)
|
| 745 |
+
summaries.append(
|
| 746 |
+
HistogramSummary("%s/weight/%s" % (self._name, tag), param.data)
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
return summaries
|
| 750 |
+
|
| 751 |
+
def act_summaries(self) -> List[Summary]:
|
| 752 |
+
return [
|
| 753 |
+
ImageSummary(
|
| 754 |
+
"%s/act_Qattention" % self._name,
|
| 755 |
+
transforms.ToTensor()(
|
| 756 |
+
visualise_voxel(
|
| 757 |
+
self._act_voxel_grid.cpu().numpy(),
|
| 758 |
+
self._act_qvalues.cpu().numpy(),
|
| 759 |
+
self._act_max_coordinate.cpu().numpy(),
|
| 760 |
+
)
|
| 761 |
+
),
|
| 762 |
+
)
|
| 763 |
+
]
|
| 764 |
+
|
| 765 |
+
def load_weights(self, savedir: str):
|
| 766 |
+
device = (
|
| 767 |
+
self._device
|
| 768 |
+
if not self._training
|
| 769 |
+
else torch.device("cuda:%d" % self._device)
|
| 770 |
+
)
|
| 771 |
+
state_dict = torch.load(
|
| 772 |
+
os.path.join(savedir, "%s.pt" % self._name), map_location=device
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
# load only keys that are in the current model
|
| 776 |
+
merged_state_dict = self._q.state_dict()
|
| 777 |
+
for k, v in state_dict.items():
|
| 778 |
+
if "_voxelizer" not in k:
|
| 779 |
+
if not self._training:
|
| 780 |
+
k = k.replace("_qnet.module", "_qnet")
|
| 781 |
+
|
| 782 |
+
if k in merged_state_dict:
|
| 783 |
+
merged_state_dict[k] = v
|
| 784 |
+
else:
|
| 785 |
+
logging.warning("key %s not found in checkpoint" % k)
|
| 786 |
+
self._q.load_state_dict(merged_state_dict)
|
| 787 |
+
print("loaded weights from %s" % savedir)
|
| 788 |
+
|
| 789 |
+
def save_weights(self, savedir: str):
|
| 790 |
+
torch.save(self._q.state_dict(), os.path.join(savedir, "%s.pt" % self._name))
|
external/peract_bimanual/agents/c2farm_lingunet_bc/qattention_stack_agent.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from yarr.agents.agent import Agent, ActResult, Summary
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from helpers import utils
|
| 9 |
+
from agents.c2farm_lingunet_bc.qattention_lingunet_bc_agent import (
|
| 10 |
+
QAttentionLingUNetBCAgent,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
from scipy.spatial.transform import Rotation
|
| 14 |
+
|
| 15 |
+
NAME = "QAttentionStackAgent"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class QAttentionStackAgent(Agent):
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
qattention_agents: List[QAttentionLingUNetBCAgent],
|
| 22 |
+
rotation_resolution: float,
|
| 23 |
+
camera_names: List[str],
|
| 24 |
+
rotation_prediction_depth: int = 0,
|
| 25 |
+
):
|
| 26 |
+
super(QAttentionStackAgent, self).__init__()
|
| 27 |
+
self._qattention_agents = qattention_agents
|
| 28 |
+
self._rotation_resolution = rotation_resolution
|
| 29 |
+
self._camera_names = camera_names
|
| 30 |
+
self._rotation_prediction_depth = rotation_prediction_depth
|
| 31 |
+
|
| 32 |
+
def build(self, training: bool, device=None) -> None:
|
| 33 |
+
self._device = device
|
| 34 |
+
if self._device is None:
|
| 35 |
+
self._device = torch.device("cpu")
|
| 36 |
+
for qa in self._qattention_agents:
|
| 37 |
+
qa.build(training, device)
|
| 38 |
+
|
| 39 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 40 |
+
priorities = 0
|
| 41 |
+
total_losses = 0.0
|
| 42 |
+
for qa in self._qattention_agents:
|
| 43 |
+
update_dict = qa.update(step, replay_sample)
|
| 44 |
+
replay_sample.update(update_dict)
|
| 45 |
+
total_losses += update_dict["total_loss"]
|
| 46 |
+
return {
|
| 47 |
+
"total_losses": total_losses,
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
|
| 51 |
+
observation_elements = {}
|
| 52 |
+
translation_results, rot_grip_results, ignore_collisions_results = [], [], []
|
| 53 |
+
infos = {}
|
| 54 |
+
for depth, qagent in enumerate(self._qattention_agents):
|
| 55 |
+
act_results = qagent.act(step, observation, deterministic)
|
| 56 |
+
attention_coordinate = (
|
| 57 |
+
act_results.observation_elements["attention_coordinate"].cpu().numpy()
|
| 58 |
+
)
|
| 59 |
+
observation_elements[
|
| 60 |
+
"attention_coordinate_layer_%d" % depth
|
| 61 |
+
] = attention_coordinate[0]
|
| 62 |
+
|
| 63 |
+
translation_idxs, rot_grip_idxs, ignore_collisions_idxs = act_results.action
|
| 64 |
+
translation_results.append(translation_idxs)
|
| 65 |
+
if rot_grip_idxs is not None:
|
| 66 |
+
rot_grip_results.append(rot_grip_idxs)
|
| 67 |
+
if ignore_collisions_idxs is not None:
|
| 68 |
+
ignore_collisions_results.append(ignore_collisions_idxs)
|
| 69 |
+
|
| 70 |
+
observation["attention_coordinate"] = act_results.observation_elements[
|
| 71 |
+
"attention_coordinate"
|
| 72 |
+
]
|
| 73 |
+
observation["prev_layer_voxel_grid"] = act_results.observation_elements[
|
| 74 |
+
"prev_layer_voxel_grid"
|
| 75 |
+
]
|
| 76 |
+
observation["prev_layer_bounds"] = act_results.observation_elements[
|
| 77 |
+
"prev_layer_bounds"
|
| 78 |
+
]
|
| 79 |
+
|
| 80 |
+
for n in self._camera_names:
|
| 81 |
+
px, py = utils.point_to_pixel_index(
|
| 82 |
+
attention_coordinate[0],
|
| 83 |
+
observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy(),
|
| 84 |
+
observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy(),
|
| 85 |
+
)
|
| 86 |
+
pc_t = torch.tensor(
|
| 87 |
+
[[[py, px]]], dtype=torch.float32, device=self._device
|
| 88 |
+
)
|
| 89 |
+
observation["%s_pixel_coord" % n] = pc_t
|
| 90 |
+
observation_elements["%s_pixel_coord" % n] = [py, px]
|
| 91 |
+
|
| 92 |
+
infos.update(act_results.info)
|
| 93 |
+
|
| 94 |
+
rgai = torch.cat(rot_grip_results, 1)[0].cpu().numpy()
|
| 95 |
+
ignore_collisions = float(
|
| 96 |
+
torch.cat(ignore_collisions_results, 1)[0].cpu().numpy()
|
| 97 |
+
)
|
| 98 |
+
observation_elements["trans_action_indicies"] = (
|
| 99 |
+
torch.cat(translation_results, 1)[0].cpu().numpy()
|
| 100 |
+
)
|
| 101 |
+
observation_elements["rot_grip_action_indicies"] = rgai
|
| 102 |
+
continuous_action = np.concatenate(
|
| 103 |
+
[
|
| 104 |
+
act_results.observation_elements["attention_coordinate"]
|
| 105 |
+
.cpu()
|
| 106 |
+
.numpy()[0],
|
| 107 |
+
utils.discrete_euler_to_quaternion(
|
| 108 |
+
rgai[-4:-1], self._rotation_resolution
|
| 109 |
+
),
|
| 110 |
+
rgai[-1:],
|
| 111 |
+
[ignore_collisions],
|
| 112 |
+
]
|
| 113 |
+
)
|
| 114 |
+
return ActResult(
|
| 115 |
+
continuous_action, observation_elements=observation_elements, info=infos
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
def update_summaries(self) -> List[Summary]:
|
| 119 |
+
summaries = []
|
| 120 |
+
for qa in self._qattention_agents:
|
| 121 |
+
summaries.extend(qa.update_summaries())
|
| 122 |
+
return summaries
|
| 123 |
+
|
| 124 |
+
def act_summaries(self) -> List[Summary]:
|
| 125 |
+
s = []
|
| 126 |
+
for qa in self._qattention_agents:
|
| 127 |
+
s.extend(qa.act_summaries())
|
| 128 |
+
return s
|
| 129 |
+
|
| 130 |
+
def load_weights(self, savedir: str):
|
| 131 |
+
for qa in self._qattention_agents:
|
| 132 |
+
qa.load_weights(savedir)
|
| 133 |
+
|
| 134 |
+
def save_weights(self, savedir: str):
|
| 135 |
+
for qa in self._qattention_agents:
|
| 136 |
+
qa.save_weights(savedir)
|
external/peract_bimanual/agents/peract_bc/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
import agents.peract_bc.launch_utils
|
external/peract_bimanual/agents/peract_bc/launch_utils.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from ARM
|
| 2 |
+
# Source: https://github.com/stepjam/ARM
|
| 3 |
+
# License: https://github.com/stepjam/ARM/LICENSE
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from helpers.preprocess_agent import PreprocessAgent
|
| 7 |
+
from agents.peract_bc.perceiver_lang_io import PerceiverVoxelLangEncoder
|
| 8 |
+
from agents.peract_bc.qattention_peract_bc_agent import QAttentionPerActBCAgent
|
| 9 |
+
from agents.peract_bc.qattention_stack_agent import QAttentionStackAgent
|
| 10 |
+
|
| 11 |
+
from omegaconf import DictConfig
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def create_agent(cfg: DictConfig):
|
| 15 |
+
LATENT_SIZE = 64
|
| 16 |
+
depth_0bounds = cfg.rlbench.scene_bounds
|
| 17 |
+
cam_resolution = cfg.rlbench.camera_resolution
|
| 18 |
+
|
| 19 |
+
num_rotation_classes = int(360.0 // cfg.method.rotation_resolution)
|
| 20 |
+
qattention_agents = []
|
| 21 |
+
for depth, vox_size in enumerate(cfg.method.voxel_sizes):
|
| 22 |
+
last = depth == len(cfg.method.voxel_sizes) - 1
|
| 23 |
+
perceiver_encoder = PerceiverVoxelLangEncoder(
|
| 24 |
+
depth=cfg.method.transformer_depth,
|
| 25 |
+
iterations=cfg.method.transformer_iterations,
|
| 26 |
+
voxel_size=vox_size,
|
| 27 |
+
initial_dim=3 + 3 + 1 + 3,
|
| 28 |
+
low_dim_size=cfg.method.low_dim_size,
|
| 29 |
+
layer=depth,
|
| 30 |
+
num_rotation_classes=num_rotation_classes if last else 0,
|
| 31 |
+
num_grip_classes=2 if last else 0,
|
| 32 |
+
num_collision_classes=2 if last else 0,
|
| 33 |
+
input_axis=3,
|
| 34 |
+
num_latents=cfg.method.num_latents,
|
| 35 |
+
latent_dim=cfg.method.latent_dim,
|
| 36 |
+
cross_heads=cfg.method.cross_heads,
|
| 37 |
+
latent_heads=cfg.method.latent_heads,
|
| 38 |
+
cross_dim_head=cfg.method.cross_dim_head,
|
| 39 |
+
latent_dim_head=cfg.method.latent_dim_head,
|
| 40 |
+
weight_tie_layers=False,
|
| 41 |
+
activation=cfg.method.activation,
|
| 42 |
+
pos_encoding_with_lang=cfg.method.pos_encoding_with_lang,
|
| 43 |
+
input_dropout=cfg.method.input_dropout,
|
| 44 |
+
attn_dropout=cfg.method.attn_dropout,
|
| 45 |
+
decoder_dropout=cfg.method.decoder_dropout,
|
| 46 |
+
lang_fusion_type=cfg.method.lang_fusion_type,
|
| 47 |
+
voxel_patch_size=cfg.method.voxel_patch_size,
|
| 48 |
+
voxel_patch_stride=cfg.method.voxel_patch_stride,
|
| 49 |
+
no_skip_connection=cfg.method.no_skip_connection,
|
| 50 |
+
no_perceiver=cfg.method.no_perceiver,
|
| 51 |
+
no_language=cfg.method.no_language,
|
| 52 |
+
final_dim=cfg.method.final_dim,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
qattention_agent = QAttentionPerActBCAgent(
|
| 56 |
+
layer=depth,
|
| 57 |
+
coordinate_bounds=depth_0bounds,
|
| 58 |
+
perceiver_encoder=perceiver_encoder,
|
| 59 |
+
camera_names=cfg.rlbench.cameras,
|
| 60 |
+
voxel_size=vox_size,
|
| 61 |
+
bounds_offset=cfg.method.bounds_offset[depth - 1] if depth > 0 else None,
|
| 62 |
+
image_crop_size=cfg.method.image_crop_size,
|
| 63 |
+
lr=cfg.method.lr,
|
| 64 |
+
training_iterations=cfg.framework.training_iterations,
|
| 65 |
+
lr_scheduler=cfg.method.lr_scheduler,
|
| 66 |
+
num_warmup_steps=cfg.method.num_warmup_steps,
|
| 67 |
+
trans_loss_weight=cfg.method.trans_loss_weight,
|
| 68 |
+
rot_loss_weight=cfg.method.rot_loss_weight,
|
| 69 |
+
grip_loss_weight=cfg.method.grip_loss_weight,
|
| 70 |
+
collision_loss_weight=cfg.method.collision_loss_weight,
|
| 71 |
+
include_low_dim_state=True,
|
| 72 |
+
image_resolution=cam_resolution,
|
| 73 |
+
batch_size=cfg.replay.batch_size,
|
| 74 |
+
voxel_feature_size=3,
|
| 75 |
+
lambda_weight_l2=cfg.method.lambda_weight_l2,
|
| 76 |
+
num_rotation_classes=num_rotation_classes,
|
| 77 |
+
rotation_resolution=cfg.method.rotation_resolution,
|
| 78 |
+
transform_augmentation=cfg.method.transform_augmentation.apply_se3,
|
| 79 |
+
transform_augmentation_xyz=cfg.method.transform_augmentation.aug_xyz,
|
| 80 |
+
transform_augmentation_rpy=cfg.method.transform_augmentation.aug_rpy,
|
| 81 |
+
transform_augmentation_rot_resolution=cfg.method.transform_augmentation.aug_rot_resolution,
|
| 82 |
+
optimizer_type=cfg.method.optimizer,
|
| 83 |
+
num_devices=cfg.ddp.num_devices,
|
| 84 |
+
checkpoint_name_prefix=cfg.framework.checkpoint_name_prefix,
|
| 85 |
+
)
|
| 86 |
+
qattention_agents.append(qattention_agent)
|
| 87 |
+
|
| 88 |
+
rotation_agent = QAttentionStackAgent(
|
| 89 |
+
qattention_agents=qattention_agents,
|
| 90 |
+
rotation_resolution=cfg.method.rotation_resolution,
|
| 91 |
+
camera_names=cfg.rlbench.cameras,
|
| 92 |
+
)
|
| 93 |
+
preprocess_agent = PreprocessAgent(pose_agent=rotation_agent)
|
| 94 |
+
return preprocess_agent
|
external/peract_bimanual/agents/peract_bc/perceiver_lang_io.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Perceiver IO implementation adpated for manipulation
|
| 2 |
+
# Source: https://github.com/lucidrains/perceiver-pytorch
|
| 3 |
+
# License: https://github.com/lucidrains/perceiver-pytorch/blob/main/LICENSE
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from einops import repeat
|
| 10 |
+
|
| 11 |
+
from perceiver_pytorch.perceiver_pytorch import cache_fn
|
| 12 |
+
from perceiver_pytorch.perceiver_pytorch import PreNorm, FeedForward, Attention
|
| 13 |
+
|
| 14 |
+
from helpers.network_utils import (
|
| 15 |
+
DenseBlock,
|
| 16 |
+
SpatialSoftmax3D,
|
| 17 |
+
Conv3DBlock,
|
| 18 |
+
Conv3DUpsampleBlock,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# PerceiverIO adapted for 6-DoF manipulation
|
| 23 |
+
class PerceiverVoxelLangEncoder(nn.Module):
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
depth, # number of self-attention layers
|
| 27 |
+
iterations, # number cross-attention iterations (PerceiverIO uses just 1)
|
| 28 |
+
voxel_size, # N voxels per side (size: N*N*N)
|
| 29 |
+
initial_dim, # 10 dimensions - dimension of the input sequence to be encoded
|
| 30 |
+
low_dim_size, # 4 dimensions - proprioception: {gripper_open, left_finger, right_finger, timestep}
|
| 31 |
+
layer=0,
|
| 32 |
+
num_rotation_classes=72, # 5 degree increments (5*72=360) for each of the 3-axis
|
| 33 |
+
num_grip_classes=2, # open or not open
|
| 34 |
+
num_collision_classes=2, # collisions allowed or not allowed
|
| 35 |
+
input_axis=3, # 3D tensors have 3 axes
|
| 36 |
+
num_latents=512, # number of latent vectors
|
| 37 |
+
im_channels=64, # intermediate channel size
|
| 38 |
+
latent_dim=512, # dimensions of latent vectors
|
| 39 |
+
cross_heads=1, # number of cross-attention heads
|
| 40 |
+
latent_heads=8, # number of latent heads
|
| 41 |
+
cross_dim_head=64,
|
| 42 |
+
latent_dim_head=64,
|
| 43 |
+
activation="relu",
|
| 44 |
+
weight_tie_layers=False,
|
| 45 |
+
pos_encoding_with_lang=True,
|
| 46 |
+
input_dropout=0.1,
|
| 47 |
+
attn_dropout=0.1,
|
| 48 |
+
decoder_dropout=0.0,
|
| 49 |
+
lang_fusion_type="seq",
|
| 50 |
+
voxel_patch_size=9,
|
| 51 |
+
voxel_patch_stride=8,
|
| 52 |
+
no_skip_connection=False,
|
| 53 |
+
no_perceiver=False,
|
| 54 |
+
no_language=False,
|
| 55 |
+
final_dim=64,
|
| 56 |
+
):
|
| 57 |
+
super().__init__()
|
| 58 |
+
self.depth = depth
|
| 59 |
+
self.layer = layer
|
| 60 |
+
self.init_dim = int(initial_dim)
|
| 61 |
+
self.iterations = iterations
|
| 62 |
+
self.input_axis = input_axis
|
| 63 |
+
self.voxel_size = voxel_size
|
| 64 |
+
self.low_dim_size = low_dim_size
|
| 65 |
+
self.im_channels = im_channels
|
| 66 |
+
self.pos_encoding_with_lang = pos_encoding_with_lang
|
| 67 |
+
self.lang_fusion_type = lang_fusion_type
|
| 68 |
+
self.voxel_patch_size = voxel_patch_size
|
| 69 |
+
self.voxel_patch_stride = voxel_patch_stride
|
| 70 |
+
self.num_rotation_classes = num_rotation_classes
|
| 71 |
+
self.num_grip_classes = num_grip_classes
|
| 72 |
+
self.num_collision_classes = num_collision_classes
|
| 73 |
+
self.final_dim = final_dim
|
| 74 |
+
self.input_dropout = input_dropout
|
| 75 |
+
self.attn_dropout = attn_dropout
|
| 76 |
+
self.decoder_dropout = decoder_dropout
|
| 77 |
+
self.no_skip_connection = no_skip_connection
|
| 78 |
+
self.no_perceiver = no_perceiver
|
| 79 |
+
self.no_language = no_language
|
| 80 |
+
|
| 81 |
+
# patchified input dimensions
|
| 82 |
+
spatial_size = voxel_size // self.voxel_patch_stride # 100/5 = 20
|
| 83 |
+
|
| 84 |
+
# 64 voxel features + 64 proprio features (+ 64 lang goal features if concattenated)
|
| 85 |
+
self.input_dim_before_seq = (
|
| 86 |
+
self.im_channels * 3
|
| 87 |
+
if self.lang_fusion_type == "concat"
|
| 88 |
+
else self.im_channels * 2
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# CLIP language feature dimensions
|
| 92 |
+
lang_feat_dim, lang_emb_dim, lang_max_seq_len = 1024, 512, 77
|
| 93 |
+
|
| 94 |
+
# learnable positional encoding
|
| 95 |
+
if self.pos_encoding_with_lang:
|
| 96 |
+
self.pos_encoding = nn.Parameter(
|
| 97 |
+
torch.randn(
|
| 98 |
+
1, lang_max_seq_len + spatial_size**3, self.input_dim_before_seq
|
| 99 |
+
)
|
| 100 |
+
)
|
| 101 |
+
else:
|
| 102 |
+
# assert self.lang_fusion_type == 'concat', 'Only concat is supported for pos encoding without lang.'
|
| 103 |
+
self.pos_encoding = nn.Parameter(
|
| 104 |
+
torch.randn(
|
| 105 |
+
1,
|
| 106 |
+
spatial_size,
|
| 107 |
+
spatial_size,
|
| 108 |
+
spatial_size,
|
| 109 |
+
self.input_dim_before_seq,
|
| 110 |
+
)
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
# voxel input preprocessing 1x1 conv encoder
|
| 114 |
+
self.input_preprocess = Conv3DBlock(
|
| 115 |
+
self.init_dim,
|
| 116 |
+
self.im_channels,
|
| 117 |
+
kernel_sizes=1,
|
| 118 |
+
strides=1,
|
| 119 |
+
norm=None,
|
| 120 |
+
activation=activation,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# patchify conv
|
| 124 |
+
self.patchify = Conv3DBlock(
|
| 125 |
+
self.input_preprocess.out_channels,
|
| 126 |
+
self.im_channels,
|
| 127 |
+
kernel_sizes=self.voxel_patch_size,
|
| 128 |
+
strides=self.voxel_patch_stride,
|
| 129 |
+
norm=None,
|
| 130 |
+
activation=activation,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# language preprocess
|
| 134 |
+
if self.lang_fusion_type == "concat":
|
| 135 |
+
self.lang_preprocess = nn.Linear(lang_feat_dim, self.im_channels)
|
| 136 |
+
elif self.lang_fusion_type == "seq":
|
| 137 |
+
self.lang_preprocess = nn.Linear(lang_emb_dim, self.im_channels * 2)
|
| 138 |
+
|
| 139 |
+
# proprioception
|
| 140 |
+
if self.low_dim_size > 0:
|
| 141 |
+
self.proprio_preprocess = DenseBlock(
|
| 142 |
+
self.low_dim_size,
|
| 143 |
+
self.im_channels,
|
| 144 |
+
norm=None,
|
| 145 |
+
activation=activation,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# pooling functions
|
| 149 |
+
self.local_maxp = nn.MaxPool3d(3, 2, padding=1)
|
| 150 |
+
self.global_maxp = nn.AdaptiveMaxPool3d(1)
|
| 151 |
+
|
| 152 |
+
# 1st 3D softmax
|
| 153 |
+
self.ss0 = SpatialSoftmax3D(
|
| 154 |
+
self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
|
| 155 |
+
)
|
| 156 |
+
flat_size = self.im_channels * 4
|
| 157 |
+
|
| 158 |
+
# latent vectors (that are randomly initialized)
|
| 159 |
+
self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
|
| 160 |
+
|
| 161 |
+
# encoder cross attention
|
| 162 |
+
self.cross_attend_blocks = nn.ModuleList(
|
| 163 |
+
[
|
| 164 |
+
PreNorm(
|
| 165 |
+
latent_dim,
|
| 166 |
+
Attention(
|
| 167 |
+
latent_dim,
|
| 168 |
+
self.input_dim_before_seq,
|
| 169 |
+
heads=cross_heads,
|
| 170 |
+
dim_head=cross_dim_head,
|
| 171 |
+
dropout=input_dropout,
|
| 172 |
+
),
|
| 173 |
+
context_dim=self.input_dim_before_seq,
|
| 174 |
+
),
|
| 175 |
+
PreNorm(latent_dim, FeedForward(latent_dim)),
|
| 176 |
+
]
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
get_latent_attn = lambda: PreNorm(
|
| 180 |
+
latent_dim,
|
| 181 |
+
Attention(
|
| 182 |
+
latent_dim,
|
| 183 |
+
heads=latent_heads,
|
| 184 |
+
dim_head=latent_dim_head,
|
| 185 |
+
dropout=attn_dropout,
|
| 186 |
+
),
|
| 187 |
+
)
|
| 188 |
+
get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim))
|
| 189 |
+
get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))
|
| 190 |
+
|
| 191 |
+
# self attention layers
|
| 192 |
+
self.layers = nn.ModuleList([])
|
| 193 |
+
cache_args = {"_cache": weight_tie_layers}
|
| 194 |
+
|
| 195 |
+
for i in range(depth):
|
| 196 |
+
self.layers.append(
|
| 197 |
+
nn.ModuleList(
|
| 198 |
+
[get_latent_attn(**cache_args), get_latent_ff(**cache_args)]
|
| 199 |
+
)
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# decoder cross attention
|
| 203 |
+
self.decoder_cross_attn = PreNorm(
|
| 204 |
+
self.input_dim_before_seq,
|
| 205 |
+
Attention(
|
| 206 |
+
self.input_dim_before_seq,
|
| 207 |
+
latent_dim,
|
| 208 |
+
heads=cross_heads,
|
| 209 |
+
dim_head=cross_dim_head,
|
| 210 |
+
dropout=decoder_dropout,
|
| 211 |
+
),
|
| 212 |
+
context_dim=latent_dim,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
# upsample conv
|
| 216 |
+
self.up0 = Conv3DUpsampleBlock(
|
| 217 |
+
self.input_dim_before_seq,
|
| 218 |
+
self.final_dim,
|
| 219 |
+
kernel_sizes=self.voxel_patch_size,
|
| 220 |
+
strides=self.voxel_patch_stride,
|
| 221 |
+
norm=None,
|
| 222 |
+
activation=activation,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
# 2nd 3D softmax
|
| 226 |
+
self.ss1 = SpatialSoftmax3D(
|
| 227 |
+
spatial_size, spatial_size, spatial_size, self.input_dim_before_seq
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
flat_size += self.input_dim_before_seq * 4
|
| 231 |
+
|
| 232 |
+
# final 3D softmax
|
| 233 |
+
self.final = Conv3DBlock(
|
| 234 |
+
self.im_channels
|
| 235 |
+
if (self.no_perceiver or self.no_skip_connection)
|
| 236 |
+
else self.im_channels * 2,
|
| 237 |
+
self.im_channels,
|
| 238 |
+
kernel_sizes=3,
|
| 239 |
+
strides=1,
|
| 240 |
+
norm=None,
|
| 241 |
+
activation=activation,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
self.trans_decoder = Conv3DBlock(
|
| 245 |
+
self.final_dim,
|
| 246 |
+
1,
|
| 247 |
+
kernel_sizes=3,
|
| 248 |
+
strides=1,
|
| 249 |
+
norm=None,
|
| 250 |
+
activation=None,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
# rotation, gripper, and collision MLP layers
|
| 254 |
+
if self.num_rotation_classes > 0:
|
| 255 |
+
self.ss_final = SpatialSoftmax3D(
|
| 256 |
+
self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
flat_size += self.im_channels * 4
|
| 260 |
+
|
| 261 |
+
self.dense0 = DenseBlock(flat_size, 256, None, activation)
|
| 262 |
+
self.dense1 = DenseBlock(256, self.final_dim, None, activation)
|
| 263 |
+
|
| 264 |
+
self.rot_grip_collision_ff = DenseBlock(
|
| 265 |
+
self.final_dim,
|
| 266 |
+
self.num_rotation_classes * 3
|
| 267 |
+
+ self.num_grip_classes
|
| 268 |
+
+ self.num_collision_classes,
|
| 269 |
+
None,
|
| 270 |
+
None,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
def encode_text(self, x):
|
| 274 |
+
with torch.no_grad():
|
| 275 |
+
text_feat, text_emb = self._clip_rn50.encode_text_with_embeddings(x)
|
| 276 |
+
|
| 277 |
+
text_feat = text_feat.detach()
|
| 278 |
+
text_emb = text_emb.detach()
|
| 279 |
+
text_mask = torch.where(x == 0, x, 1) # [1, max_token_len]
|
| 280 |
+
return text_feat, text_emb
|
| 281 |
+
|
| 282 |
+
def forward(
|
| 283 |
+
self,
|
| 284 |
+
ins,
|
| 285 |
+
proprio,
|
| 286 |
+
lang_goal_emb,
|
| 287 |
+
lang_token_embs,
|
| 288 |
+
prev_layer_voxel_grid,
|
| 289 |
+
bounds,
|
| 290 |
+
prev_layer_bounds,
|
| 291 |
+
mask=None,
|
| 292 |
+
):
|
| 293 |
+
# preprocess input
|
| 294 |
+
d0 = self.input_preprocess(ins) # [B,10,100,100,100] -> [B,64,100,100,100]
|
| 295 |
+
|
| 296 |
+
# aggregated features from 1st softmax and maxpool for MLP decoders
|
| 297 |
+
feats = [self.ss0(d0.contiguous()), self.global_maxp(d0).view(ins.shape[0], -1)]
|
| 298 |
+
|
| 299 |
+
# patchify input (5x5x5 patches)
|
| 300 |
+
ins = self.patchify(d0) # [B,64,100,100,100] -> [B,64,20,20,20]
|
| 301 |
+
|
| 302 |
+
b, c, d, h, w, device = *ins.shape, ins.device
|
| 303 |
+
axis = [d, h, w]
|
| 304 |
+
assert (
|
| 305 |
+
len(axis) == self.input_axis
|
| 306 |
+
), "input must have the same number of axis as input_axis"
|
| 307 |
+
|
| 308 |
+
# concat proprio
|
| 309 |
+
if self.low_dim_size > 0:
|
| 310 |
+
p = self.proprio_preprocess(proprio) # [B,4] -> [B,64]
|
| 311 |
+
p = p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
|
| 312 |
+
ins = torch.cat([ins, p], dim=1) # [B,128,20,20,20]
|
| 313 |
+
|
| 314 |
+
# language ablation
|
| 315 |
+
if self.no_language:
|
| 316 |
+
lang_goal_emb = torch.zeros_like(lang_goal_emb)
|
| 317 |
+
lang_token_embs = torch.zeros_like(lang_token_embs)
|
| 318 |
+
|
| 319 |
+
# option 1: tile and concat lang goal to input
|
| 320 |
+
if self.lang_fusion_type == "concat":
|
| 321 |
+
lang_emb = lang_goal_emb
|
| 322 |
+
lang_emb = lang_emb.to(dtype=ins.dtype)
|
| 323 |
+
l = self.lang_preprocess(lang_emb)
|
| 324 |
+
l = l.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
|
| 325 |
+
ins = torch.cat([ins, l], dim=1)
|
| 326 |
+
|
| 327 |
+
# channel last
|
| 328 |
+
ins = rearrange(ins, "b d ... -> b ... d") # [B,20,20,20,128]
|
| 329 |
+
|
| 330 |
+
# add pos encoding to grid
|
| 331 |
+
if not self.pos_encoding_with_lang:
|
| 332 |
+
ins = ins + self.pos_encoding
|
| 333 |
+
|
| 334 |
+
######################## NOTE #############################
|
| 335 |
+
# NOTE: If you add positional encodings ^here the lang embs
|
| 336 |
+
# won't have positional encodings. I accidently forgot
|
| 337 |
+
# to turn this off for all the experiments in the paper.
|
| 338 |
+
# So I guess those models were using language embs
|
| 339 |
+
# as a bag of words :( But it doesn't matter much for
|
| 340 |
+
# RLBench tasks since we don't test for novel instructions
|
| 341 |
+
# at test time anyway. The recommend way is to add
|
| 342 |
+
# positional encodings to the final input sequence
|
| 343 |
+
# fed into the Perceiver Transformer, as done below
|
| 344 |
+
# (and also in the Colab tutorial).
|
| 345 |
+
###########################################################
|
| 346 |
+
|
| 347 |
+
# concat to channels of and flatten axis
|
| 348 |
+
queries_orig_shape = ins.shape
|
| 349 |
+
|
| 350 |
+
# rearrange input to be channel last
|
| 351 |
+
ins = rearrange(ins, "b ... d -> b (...) d") # [B,8000,128]
|
| 352 |
+
ins_wo_prev_layers = ins
|
| 353 |
+
|
| 354 |
+
# option 2: add lang token embs as a sequence
|
| 355 |
+
if self.lang_fusion_type == "seq":
|
| 356 |
+
l = self.lang_preprocess(lang_token_embs) # [B,77,1024] -> [B,77,128]
|
| 357 |
+
ins = torch.cat((l, ins), dim=1) # [B,8077,128]
|
| 358 |
+
|
| 359 |
+
# add pos encoding to language + flattened grid (the recommended way)
|
| 360 |
+
if self.pos_encoding_with_lang:
|
| 361 |
+
ins = ins + self.pos_encoding
|
| 362 |
+
|
| 363 |
+
# batchify latents
|
| 364 |
+
x = repeat(self.latents, "n d -> b n d", b=b)
|
| 365 |
+
|
| 366 |
+
cross_attn, cross_ff = self.cross_attend_blocks
|
| 367 |
+
|
| 368 |
+
for it in range(self.iterations):
|
| 369 |
+
# encoder cross attention
|
| 370 |
+
x = cross_attn(x, context=ins, mask=mask) + x
|
| 371 |
+
x = cross_ff(x) + x
|
| 372 |
+
|
| 373 |
+
# self-attention layers
|
| 374 |
+
for self_attn, self_ff in self.layers:
|
| 375 |
+
x = self_attn(x) + x
|
| 376 |
+
x = self_ff(x) + x
|
| 377 |
+
|
| 378 |
+
# decoder cross attention
|
| 379 |
+
latents = self.decoder_cross_attn(ins, context=x)
|
| 380 |
+
|
| 381 |
+
# crop out the language part of the output sequence
|
| 382 |
+
if self.lang_fusion_type == "seq":
|
| 383 |
+
latents = latents[:, l.shape[1] :]
|
| 384 |
+
|
| 385 |
+
# reshape back to voxel grid
|
| 386 |
+
latents = latents.view(
|
| 387 |
+
b, *queries_orig_shape[1:-1], latents.shape[-1]
|
| 388 |
+
) # [B,20,20,20,64]
|
| 389 |
+
latents = rearrange(latents, "b ... d -> b d ...") # [B,64,20,20,20]
|
| 390 |
+
|
| 391 |
+
# aggregated features from 2nd softmax and maxpool for MLP decoders
|
| 392 |
+
feats.extend(
|
| 393 |
+
[self.ss1(latents.contiguous()), self.global_maxp(latents).view(b, -1)]
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
# upsample
|
| 397 |
+
u0 = self.up0(latents)
|
| 398 |
+
|
| 399 |
+
# ablations
|
| 400 |
+
if self.no_skip_connection:
|
| 401 |
+
u = self.final(u0)
|
| 402 |
+
elif self.no_perceiver:
|
| 403 |
+
u = self.final(d0)
|
| 404 |
+
else:
|
| 405 |
+
u = self.final(torch.cat([d0, u0], dim=1))
|
| 406 |
+
|
| 407 |
+
# translation decoder
|
| 408 |
+
trans = self.trans_decoder(u)
|
| 409 |
+
|
| 410 |
+
# rotation, gripper, and collision MLPs
|
| 411 |
+
rot_and_grip_out = None
|
| 412 |
+
if self.num_rotation_classes > 0:
|
| 413 |
+
feats.extend(
|
| 414 |
+
[self.ss_final(u.contiguous()), self.global_maxp(u).view(b, -1)]
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
dense0 = self.dense0(torch.cat(feats, dim=1))
|
| 418 |
+
dense1 = self.dense1(dense0) # [B,72*3+2+2]
|
| 419 |
+
|
| 420 |
+
rot_and_grip_collision_out = self.rot_grip_collision_ff(dense1)
|
| 421 |
+
rot_and_grip_out = rot_and_grip_collision_out[
|
| 422 |
+
:, : -self.num_collision_classes
|
| 423 |
+
]
|
| 424 |
+
collision_out = rot_and_grip_collision_out[:, -self.num_collision_classes :]
|
| 425 |
+
|
| 426 |
+
return trans, rot_and_grip_out, collision_out
|
external/peract_bimanual/agents/peract_bc/qattention_peract_bc_agent.py
ADDED
|
@@ -0,0 +1,808 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
from pytorch3d import transforms as torch3d_tf
|
| 12 |
+
from yarr.agents.agent import (
|
| 13 |
+
Agent,
|
| 14 |
+
ActResult,
|
| 15 |
+
ScalarSummary,
|
| 16 |
+
HistogramSummary,
|
| 17 |
+
ImageSummary,
|
| 18 |
+
Summary,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
from helpers import utils
|
| 22 |
+
from helpers.utils import visualise_voxel, stack_on_channel
|
| 23 |
+
from voxel.voxel_grid import VoxelGrid
|
| 24 |
+
from voxel.augmentation import apply_se3_augmentation
|
| 25 |
+
from einops import rearrange
|
| 26 |
+
from helpers.clip.core.clip import build_model, load_clip
|
| 27 |
+
|
| 28 |
+
import transformers
|
| 29 |
+
from helpers.optim.lamb import Lamb
|
| 30 |
+
|
| 31 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class QFunction(nn.Module):
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
perceiver_encoder: nn.Module,
|
| 38 |
+
voxelizer: VoxelGrid,
|
| 39 |
+
bounds_offset: float,
|
| 40 |
+
rotation_resolution: float,
|
| 41 |
+
device,
|
| 42 |
+
training,
|
| 43 |
+
):
|
| 44 |
+
super(QFunction, self).__init__()
|
| 45 |
+
self._rotation_resolution = rotation_resolution
|
| 46 |
+
self._voxelizer = voxelizer
|
| 47 |
+
self._bounds_offset = bounds_offset
|
| 48 |
+
self._qnet = perceiver_encoder.to(device)
|
| 49 |
+
|
| 50 |
+
# distributed training
|
| 51 |
+
if training:
|
| 52 |
+
self._qnet = DDP(self._qnet, device_ids=[device])
|
| 53 |
+
|
| 54 |
+
def _argmax_3d(self, tensor_orig):
|
| 55 |
+
b, c, d, h, w = tensor_orig.shape # c will be one
|
| 56 |
+
idxs = tensor_orig.view(b, c, -1).argmax(-1)
|
| 57 |
+
indices = torch.cat([((idxs // h) // d), (idxs // h) % w, idxs % w], 1)
|
| 58 |
+
return indices
|
| 59 |
+
|
| 60 |
+
def choose_highest_action(self, q_trans, q_rot_grip, q_collision):
|
| 61 |
+
coords = self._argmax_3d(q_trans)
|
| 62 |
+
rot_and_grip_indicies = None
|
| 63 |
+
ignore_collision = None
|
| 64 |
+
if q_rot_grip is not None:
|
| 65 |
+
q_rot = torch.stack(
|
| 66 |
+
torch.split(
|
| 67 |
+
q_rot_grip[:, :-2], int(360 // self._rotation_resolution), dim=1
|
| 68 |
+
),
|
| 69 |
+
dim=1,
|
| 70 |
+
)
|
| 71 |
+
rot_and_grip_indicies = torch.cat(
|
| 72 |
+
[
|
| 73 |
+
q_rot[:, 0:1].argmax(-1),
|
| 74 |
+
q_rot[:, 1:2].argmax(-1),
|
| 75 |
+
q_rot[:, 2:3].argmax(-1),
|
| 76 |
+
q_rot_grip[:, -2:].argmax(-1, keepdim=True),
|
| 77 |
+
],
|
| 78 |
+
-1,
|
| 79 |
+
)
|
| 80 |
+
ignore_collision = q_collision[:, -2:].argmax(-1, keepdim=True)
|
| 81 |
+
return coords, rot_and_grip_indicies, ignore_collision
|
| 82 |
+
|
| 83 |
+
def forward(
|
| 84 |
+
self,
|
| 85 |
+
rgb_pcd,
|
| 86 |
+
proprio,
|
| 87 |
+
pcd,
|
| 88 |
+
lang_goal_emb,
|
| 89 |
+
lang_token_embs,
|
| 90 |
+
bounds=None,
|
| 91 |
+
prev_bounds=None,
|
| 92 |
+
prev_layer_voxel_grid=None,
|
| 93 |
+
):
|
| 94 |
+
# rgb_pcd will be list of list (list of [rgb, pcd])
|
| 95 |
+
b = rgb_pcd[0][0].shape[0]
|
| 96 |
+
pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(b, -1, 3) for p in pcd], 1)
|
| 97 |
+
|
| 98 |
+
# flatten RGBs and Pointclouds
|
| 99 |
+
rgb = [rp[0] for rp in rgb_pcd]
|
| 100 |
+
feat_size = rgb[0].shape[1]
|
| 101 |
+
flat_imag_features = torch.cat(
|
| 102 |
+
[p.permute(0, 2, 3, 1).reshape(b, -1, feat_size) for p in rgb], 1
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# construct voxel grid
|
| 106 |
+
voxel_grid = self._voxelizer.coords_to_bounding_voxel_grid(
|
| 107 |
+
pcd_flat, coord_features=flat_imag_features, coord_bounds=bounds
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# swap to channels fist
|
| 111 |
+
voxel_grid = voxel_grid.permute(0, 4, 1, 2, 3).detach()
|
| 112 |
+
|
| 113 |
+
# batch bounds if necessary
|
| 114 |
+
if bounds.shape[0] != b:
|
| 115 |
+
bounds = bounds.repeat(b, 1)
|
| 116 |
+
|
| 117 |
+
# forward pass
|
| 118 |
+
q_trans, q_rot_and_grip, q_ignore_collisions = self._qnet(
|
| 119 |
+
voxel_grid,
|
| 120 |
+
proprio,
|
| 121 |
+
lang_goal_emb,
|
| 122 |
+
lang_token_embs,
|
| 123 |
+
prev_layer_voxel_grid,
|
| 124 |
+
bounds,
|
| 125 |
+
prev_bounds,
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
return q_trans, q_rot_and_grip, q_ignore_collisions, voxel_grid
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class QAttentionPerActBCAgent(Agent):
|
| 132 |
+
def __init__(
|
| 133 |
+
self,
|
| 134 |
+
layer: int,
|
| 135 |
+
coordinate_bounds: list,
|
| 136 |
+
perceiver_encoder: nn.Module,
|
| 137 |
+
camera_names: list,
|
| 138 |
+
batch_size: int,
|
| 139 |
+
voxel_size: int,
|
| 140 |
+
bounds_offset: float,
|
| 141 |
+
voxel_feature_size: int,
|
| 142 |
+
image_crop_size: int,
|
| 143 |
+
num_rotation_classes: int,
|
| 144 |
+
rotation_resolution: float,
|
| 145 |
+
lr: float = 0.0001,
|
| 146 |
+
lr_scheduler: bool = False,
|
| 147 |
+
training_iterations: int = 100000,
|
| 148 |
+
num_warmup_steps: int = 20000,
|
| 149 |
+
trans_loss_weight: float = 1.0,
|
| 150 |
+
rot_loss_weight: float = 1.0,
|
| 151 |
+
grip_loss_weight: float = 1.0,
|
| 152 |
+
collision_loss_weight: float = 1.0,
|
| 153 |
+
include_low_dim_state: bool = False,
|
| 154 |
+
image_resolution: list = None,
|
| 155 |
+
lambda_weight_l2: float = 0.0,
|
| 156 |
+
transform_augmentation: bool = True,
|
| 157 |
+
transform_augmentation_xyz: list = [0.0, 0.0, 0.0],
|
| 158 |
+
transform_augmentation_rpy: list = [0.0, 0.0, 180.0],
|
| 159 |
+
transform_augmentation_rot_resolution: int = 5,
|
| 160 |
+
optimizer_type: str = "adam",
|
| 161 |
+
num_devices: int = 1,
|
| 162 |
+
checkpoint_name_prefix=None,
|
| 163 |
+
):
|
| 164 |
+
self._layer = layer
|
| 165 |
+
self._coordinate_bounds = coordinate_bounds
|
| 166 |
+
self._perceiver_encoder = perceiver_encoder
|
| 167 |
+
self._voxel_feature_size = voxel_feature_size
|
| 168 |
+
self._bounds_offset = bounds_offset
|
| 169 |
+
self._image_crop_size = image_crop_size
|
| 170 |
+
self._lr = lr
|
| 171 |
+
self._lr_scheduler = lr_scheduler
|
| 172 |
+
self._training_iterations = training_iterations
|
| 173 |
+
self._num_warmup_steps = num_warmup_steps
|
| 174 |
+
self._trans_loss_weight = trans_loss_weight
|
| 175 |
+
self._rot_loss_weight = rot_loss_weight
|
| 176 |
+
self._grip_loss_weight = grip_loss_weight
|
| 177 |
+
self._collision_loss_weight = collision_loss_weight
|
| 178 |
+
self._include_low_dim_state = include_low_dim_state
|
| 179 |
+
self._image_resolution = image_resolution or [128, 128]
|
| 180 |
+
self._voxel_size = voxel_size
|
| 181 |
+
self._camera_names = camera_names
|
| 182 |
+
self._num_cameras = len(camera_names)
|
| 183 |
+
self._batch_size = batch_size
|
| 184 |
+
self._lambda_weight_l2 = lambda_weight_l2
|
| 185 |
+
self._transform_augmentation = transform_augmentation
|
| 186 |
+
self._transform_augmentation_xyz = torch.from_numpy(
|
| 187 |
+
np.array(transform_augmentation_xyz)
|
| 188 |
+
)
|
| 189 |
+
self._transform_augmentation_rpy = transform_augmentation_rpy
|
| 190 |
+
self._transform_augmentation_rot_resolution = (
|
| 191 |
+
transform_augmentation_rot_resolution
|
| 192 |
+
)
|
| 193 |
+
self._optimizer_type = optimizer_type
|
| 194 |
+
self._num_devices = num_devices
|
| 195 |
+
self._num_rotation_classes = num_rotation_classes
|
| 196 |
+
self._rotation_resolution = rotation_resolution
|
| 197 |
+
|
| 198 |
+
self._cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
|
| 199 |
+
checkpoint_name_prefix = checkpoint_name_prefix or "QAttentionAgent"
|
| 200 |
+
self._name = f"{checkpoint_name_prefix}_layer_{self._layer}"
|
| 201 |
+
|
| 202 |
+
def build(self, training: bool, device: torch.device = None):
|
| 203 |
+
self._training = training
|
| 204 |
+
|
| 205 |
+
if device is None:
|
| 206 |
+
device = torch.device("cpu")
|
| 207 |
+
|
| 208 |
+
self._device = device
|
| 209 |
+
|
| 210 |
+
self._voxelizer = VoxelGrid(
|
| 211 |
+
coord_bounds=self._coordinate_bounds,
|
| 212 |
+
voxel_size=self._voxel_size,
|
| 213 |
+
device=device,
|
| 214 |
+
batch_size=self._batch_size if training else 1,
|
| 215 |
+
feature_size=self._voxel_feature_size,
|
| 216 |
+
max_num_coords=np.prod(self._image_resolution) * self._num_cameras,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
self._q = (
|
| 220 |
+
QFunction(
|
| 221 |
+
self._perceiver_encoder,
|
| 222 |
+
self._voxelizer,
|
| 223 |
+
self._bounds_offset,
|
| 224 |
+
self._rotation_resolution,
|
| 225 |
+
device,
|
| 226 |
+
training,
|
| 227 |
+
)
|
| 228 |
+
.to(device)
|
| 229 |
+
.train(training)
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
grid_for_crop = (
|
| 233 |
+
torch.arange(0, self._image_crop_size, device=device)
|
| 234 |
+
.unsqueeze(0)
|
| 235 |
+
.repeat(self._image_crop_size, 1)
|
| 236 |
+
.unsqueeze(-1)
|
| 237 |
+
)
|
| 238 |
+
self._grid_for_crop = torch.cat(
|
| 239 |
+
[grid_for_crop.transpose(1, 0), grid_for_crop], dim=2
|
| 240 |
+
).unsqueeze(0)
|
| 241 |
+
|
| 242 |
+
self._coordinate_bounds = torch.tensor(
|
| 243 |
+
self._coordinate_bounds, device=device
|
| 244 |
+
).unsqueeze(0)
|
| 245 |
+
|
| 246 |
+
if self._training:
|
| 247 |
+
# optimizer
|
| 248 |
+
if self._optimizer_type == "lamb":
|
| 249 |
+
self._optimizer = Lamb(
|
| 250 |
+
self._q.parameters(),
|
| 251 |
+
lr=self._lr,
|
| 252 |
+
weight_decay=self._lambda_weight_l2,
|
| 253 |
+
betas=(0.9, 0.999),
|
| 254 |
+
adam=False,
|
| 255 |
+
)
|
| 256 |
+
elif self._optimizer_type == "adam":
|
| 257 |
+
self._optimizer = torch.optim.Adam(
|
| 258 |
+
self._q.parameters(),
|
| 259 |
+
lr=self._lr,
|
| 260 |
+
weight_decay=self._lambda_weight_l2,
|
| 261 |
+
)
|
| 262 |
+
else:
|
| 263 |
+
raise Exception("Unknown optimizer type")
|
| 264 |
+
|
| 265 |
+
# learning rate scheduler
|
| 266 |
+
if self._lr_scheduler:
|
| 267 |
+
self._scheduler = (
|
| 268 |
+
transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
|
| 269 |
+
self._optimizer,
|
| 270 |
+
num_warmup_steps=self._num_warmup_steps,
|
| 271 |
+
num_training_steps=self._training_iterations,
|
| 272 |
+
num_cycles=self._training_iterations // 10000,
|
| 273 |
+
)
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
# one-hot zero tensors
|
| 277 |
+
self._action_trans_one_hot_zeros = torch.zeros(
|
| 278 |
+
(
|
| 279 |
+
self._batch_size,
|
| 280 |
+
1,
|
| 281 |
+
self._voxel_size,
|
| 282 |
+
self._voxel_size,
|
| 283 |
+
self._voxel_size,
|
| 284 |
+
),
|
| 285 |
+
dtype=int,
|
| 286 |
+
device=device,
|
| 287 |
+
)
|
| 288 |
+
self._action_rot_x_one_hot_zeros = torch.zeros(
|
| 289 |
+
(self._batch_size, self._num_rotation_classes), dtype=int, device=device
|
| 290 |
+
)
|
| 291 |
+
self._action_rot_y_one_hot_zeros = torch.zeros(
|
| 292 |
+
(self._batch_size, self._num_rotation_classes), dtype=int, device=device
|
| 293 |
+
)
|
| 294 |
+
self._action_rot_z_one_hot_zeros = torch.zeros(
|
| 295 |
+
(self._batch_size, self._num_rotation_classes), dtype=int, device=device
|
| 296 |
+
)
|
| 297 |
+
self._action_grip_one_hot_zeros = torch.zeros(
|
| 298 |
+
(self._batch_size, 2), dtype=int, device=device
|
| 299 |
+
)
|
| 300 |
+
self._action_ignore_collisions_one_hot_zeros = torch.zeros(
|
| 301 |
+
(self._batch_size, 2), dtype=int, device=device
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# print total params
|
| 305 |
+
logging.info(
|
| 306 |
+
"# Q Params: %d"
|
| 307 |
+
% sum(
|
| 308 |
+
p.numel()
|
| 309 |
+
for name, p in self._q.named_parameters()
|
| 310 |
+
if p.requires_grad and "clip" not in name
|
| 311 |
+
)
|
| 312 |
+
)
|
| 313 |
+
else:
|
| 314 |
+
for param in self._q.parameters():
|
| 315 |
+
param.requires_grad = False
|
| 316 |
+
|
| 317 |
+
# load CLIP for encoding language goals during evaluation
|
| 318 |
+
model, _ = load_clip("RN50", jit=False)
|
| 319 |
+
self._clip_rn50 = build_model(model.state_dict())
|
| 320 |
+
self._clip_rn50 = self._clip_rn50.float().to(device)
|
| 321 |
+
self._clip_rn50.eval()
|
| 322 |
+
del model
|
| 323 |
+
|
| 324 |
+
self._voxelizer.to(device)
|
| 325 |
+
self._q.to(device)
|
| 326 |
+
|
| 327 |
+
def _extract_crop(self, pixel_action, observation):
|
| 328 |
+
# Pixel action will now be (B, 2)
|
| 329 |
+
# observation = stack_on_channel(observation)
|
| 330 |
+
h = observation.shape[-1]
|
| 331 |
+
top_left_corner = torch.clamp(
|
| 332 |
+
pixel_action - self._image_crop_size // 2, 0, h - self._image_crop_size
|
| 333 |
+
)
|
| 334 |
+
grid = self._grid_for_crop + top_left_corner.unsqueeze(1)
|
| 335 |
+
grid = ((grid / float(h)) * 2.0) - 1.0 # between -1 and 1
|
| 336 |
+
# Used for cropping the images across a batch
|
| 337 |
+
# swap fro y x, to x, y
|
| 338 |
+
grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1)
|
| 339 |
+
crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True)
|
| 340 |
+
return crop
|
| 341 |
+
|
| 342 |
+
def _preprocess_inputs(self, replay_sample):
|
| 343 |
+
obs = []
|
| 344 |
+
pcds = []
|
| 345 |
+
self._crop_summary = []
|
| 346 |
+
for n in self._camera_names:
|
| 347 |
+
rgb = replay_sample["%s_rgb" % n]
|
| 348 |
+
pcd = replay_sample["%s_point_cloud" % n]
|
| 349 |
+
|
| 350 |
+
obs.append([rgb, pcd])
|
| 351 |
+
pcds.append(pcd)
|
| 352 |
+
return obs, pcds
|
| 353 |
+
|
| 354 |
+
def _act_preprocess_inputs(self, observation):
|
| 355 |
+
obs, pcds = [], []
|
| 356 |
+
for n in self._camera_names:
|
| 357 |
+
rgb = observation["%s_rgb" % n]
|
| 358 |
+
pcd = observation["%s_point_cloud" % n]
|
| 359 |
+
|
| 360 |
+
obs.append([rgb, pcd])
|
| 361 |
+
pcds.append(pcd)
|
| 362 |
+
return obs, pcds
|
| 363 |
+
|
| 364 |
+
def _get_value_from_voxel_index(self, q, voxel_idx):
|
| 365 |
+
b, c, d, h, w = q.shape
|
| 366 |
+
q_trans_flat = q.view(b, c, d * h * w)
|
| 367 |
+
flat_indicies = (
|
| 368 |
+
voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2]
|
| 369 |
+
)[:, None].int()
|
| 370 |
+
highest_idxs = flat_indicies.unsqueeze(-1).repeat(1, c, 1)
|
| 371 |
+
chosen_voxel_values = q_trans_flat.gather(2, highest_idxs)[
|
| 372 |
+
..., 0
|
| 373 |
+
] # (B, trans + rot + grip)
|
| 374 |
+
return chosen_voxel_values
|
| 375 |
+
|
| 376 |
+
def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx):
|
| 377 |
+
q_rot = torch.stack(
|
| 378 |
+
torch.split(
|
| 379 |
+
rot_grip_q[:, :-2], int(360 // self._rotation_resolution), dim=1
|
| 380 |
+
),
|
| 381 |
+
dim=1,
|
| 382 |
+
) # B, 3, 72
|
| 383 |
+
q_grip = rot_grip_q[:, -2:]
|
| 384 |
+
rot_and_grip_values = torch.cat(
|
| 385 |
+
[
|
| 386 |
+
q_rot[:, 0].gather(1, rot_and_grip_idx[:, 0:1]),
|
| 387 |
+
q_rot[:, 1].gather(1, rot_and_grip_idx[:, 1:2]),
|
| 388 |
+
q_rot[:, 2].gather(1, rot_and_grip_idx[:, 2:3]),
|
| 389 |
+
q_grip.gather(1, rot_and_grip_idx[:, 3:4]),
|
| 390 |
+
],
|
| 391 |
+
-1,
|
| 392 |
+
)
|
| 393 |
+
return rot_and_grip_values
|
| 394 |
+
|
| 395 |
+
def _celoss(self, pred, labels):
|
| 396 |
+
return self._cross_entropy_loss(pred, labels.argmax(-1))
|
| 397 |
+
|
| 398 |
+
def _softmax_q_trans(self, q):
|
| 399 |
+
q_shape = q.shape
|
| 400 |
+
return F.softmax(q.reshape(q_shape[0], -1), dim=1).reshape(q_shape)
|
| 401 |
+
|
| 402 |
+
def _softmax_q_rot_grip(self, q_rot_grip):
|
| 403 |
+
q_rot_x_flat = q_rot_grip[
|
| 404 |
+
:, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
|
| 405 |
+
]
|
| 406 |
+
q_rot_y_flat = q_rot_grip[
|
| 407 |
+
:, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
|
| 408 |
+
]
|
| 409 |
+
q_rot_z_flat = q_rot_grip[
|
| 410 |
+
:, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
|
| 411 |
+
]
|
| 412 |
+
q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
|
| 413 |
+
|
| 414 |
+
q_rot_x_flat_softmax = F.softmax(q_rot_x_flat, dim=1)
|
| 415 |
+
q_rot_y_flat_softmax = F.softmax(q_rot_y_flat, dim=1)
|
| 416 |
+
q_rot_z_flat_softmax = F.softmax(q_rot_z_flat, dim=1)
|
| 417 |
+
q_grip_flat_softmax = F.softmax(q_grip_flat, dim=1)
|
| 418 |
+
|
| 419 |
+
return torch.cat(
|
| 420 |
+
[
|
| 421 |
+
q_rot_x_flat_softmax,
|
| 422 |
+
q_rot_y_flat_softmax,
|
| 423 |
+
q_rot_z_flat_softmax,
|
| 424 |
+
q_grip_flat_softmax,
|
| 425 |
+
],
|
| 426 |
+
dim=1,
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
def _softmax_ignore_collision(self, q_collision):
|
| 430 |
+
q_collision_softmax = F.softmax(q_collision, dim=1)
|
| 431 |
+
return q_collision_softmax
|
| 432 |
+
|
| 433 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 434 |
+
action_trans = replay_sample["trans_action_indicies"][
|
| 435 |
+
:, self._layer * 3 : self._layer * 3 + 3
|
| 436 |
+
].int()
|
| 437 |
+
action_rot_grip = replay_sample["rot_grip_action_indicies"].int()
|
| 438 |
+
action_gripper_pose = replay_sample["gripper_pose"]
|
| 439 |
+
action_ignore_collisions = replay_sample["ignore_collisions"].int()
|
| 440 |
+
lang_goal_emb = replay_sample["lang_goal_emb"].float()
|
| 441 |
+
lang_token_embs = replay_sample["lang_token_embs"].float()
|
| 442 |
+
prev_layer_voxel_grid = replay_sample.get("prev_layer_voxel_grid", None)
|
| 443 |
+
prev_layer_bounds = replay_sample.get("prev_layer_bounds", None)
|
| 444 |
+
device = self._device
|
| 445 |
+
|
| 446 |
+
bounds = self._coordinate_bounds.to(device)
|
| 447 |
+
if self._layer > 0:
|
| 448 |
+
cp = replay_sample["attention_coordinate_layer_%d" % (self._layer - 1)]
|
| 449 |
+
bounds = torch.cat(
|
| 450 |
+
[cp - self._bounds_offset, cp + self._bounds_offset], dim=1
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
proprio = None
|
| 454 |
+
if self._include_low_dim_state:
|
| 455 |
+
proprio = replay_sample["low_dim_state"]
|
| 456 |
+
|
| 457 |
+
obs, pcd = self._preprocess_inputs(replay_sample)
|
| 458 |
+
|
| 459 |
+
# batch size
|
| 460 |
+
bs = pcd[0].shape[0]
|
| 461 |
+
|
| 462 |
+
# SE(3) augmentation of point clouds and actions
|
| 463 |
+
if self._transform_augmentation:
|
| 464 |
+
action_trans, action_rot_grip, pcd = apply_se3_augmentation(
|
| 465 |
+
pcd,
|
| 466 |
+
action_gripper_pose,
|
| 467 |
+
action_trans,
|
| 468 |
+
action_rot_grip,
|
| 469 |
+
bounds,
|
| 470 |
+
self._layer,
|
| 471 |
+
self._transform_augmentation_xyz,
|
| 472 |
+
self._transform_augmentation_rpy,
|
| 473 |
+
self._transform_augmentation_rot_resolution,
|
| 474 |
+
self._voxel_size,
|
| 475 |
+
self._rotation_resolution,
|
| 476 |
+
self._device,
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
# forward pass
|
| 480 |
+
q_trans, q_rot_grip, q_collision, voxel_grid = self._q(
|
| 481 |
+
obs,
|
| 482 |
+
proprio,
|
| 483 |
+
pcd,
|
| 484 |
+
lang_goal_emb,
|
| 485 |
+
lang_token_embs,
|
| 486 |
+
bounds,
|
| 487 |
+
prev_layer_bounds,
|
| 488 |
+
prev_layer_voxel_grid,
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
# argmax to choose best action
|
| 492 |
+
(
|
| 493 |
+
coords,
|
| 494 |
+
rot_and_grip_indicies,
|
| 495 |
+
ignore_collision_indicies,
|
| 496 |
+
) = self._q.choose_highest_action(q_trans, q_rot_grip, q_collision)
|
| 497 |
+
|
| 498 |
+
q_trans_loss, q_rot_loss, q_grip_loss, q_collision_loss = 0.0, 0.0, 0.0, 0.0
|
| 499 |
+
|
| 500 |
+
# translation one-hot
|
| 501 |
+
action_trans_one_hot = self._action_trans_one_hot_zeros.clone()
|
| 502 |
+
for b in range(bs):
|
| 503 |
+
gt_coord = action_trans[b, :].int()
|
| 504 |
+
action_trans_one_hot[b, :, gt_coord[0], gt_coord[1], gt_coord[2]] = 1
|
| 505 |
+
|
| 506 |
+
# translation loss
|
| 507 |
+
q_trans_flat = q_trans.view(bs, -1)
|
| 508 |
+
action_trans_one_hot_flat = action_trans_one_hot.view(bs, -1)
|
| 509 |
+
q_trans_loss = self._celoss(q_trans_flat, action_trans_one_hot_flat)
|
| 510 |
+
|
| 511 |
+
with_rot_and_grip = rot_and_grip_indicies is not None
|
| 512 |
+
if with_rot_and_grip:
|
| 513 |
+
# rotation, gripper, and collision one-hots
|
| 514 |
+
action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone()
|
| 515 |
+
action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone()
|
| 516 |
+
action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone()
|
| 517 |
+
action_grip_one_hot = self._action_grip_one_hot_zeros.clone()
|
| 518 |
+
action_ignore_collisions_one_hot = (
|
| 519 |
+
self._action_ignore_collisions_one_hot_zeros.clone()
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
for b in range(bs):
|
| 523 |
+
gt_rot_grip = action_rot_grip[b, :].int()
|
| 524 |
+
action_rot_x_one_hot[b, gt_rot_grip[0]] = 1
|
| 525 |
+
action_rot_y_one_hot[b, gt_rot_grip[1]] = 1
|
| 526 |
+
action_rot_z_one_hot[b, gt_rot_grip[2]] = 1
|
| 527 |
+
action_grip_one_hot[b, gt_rot_grip[3]] = 1
|
| 528 |
+
|
| 529 |
+
gt_ignore_collisions = action_ignore_collisions[b, :].int()
|
| 530 |
+
action_ignore_collisions_one_hot[b, gt_ignore_collisions[0]] = 1
|
| 531 |
+
|
| 532 |
+
# flatten predictions
|
| 533 |
+
q_rot_x_flat = q_rot_grip[
|
| 534 |
+
:, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
|
| 535 |
+
]
|
| 536 |
+
q_rot_y_flat = q_rot_grip[
|
| 537 |
+
:, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
|
| 538 |
+
]
|
| 539 |
+
q_rot_z_flat = q_rot_grip[
|
| 540 |
+
:, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
|
| 541 |
+
]
|
| 542 |
+
q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
|
| 543 |
+
q_ignore_collisions_flat = q_collision
|
| 544 |
+
|
| 545 |
+
# rotation loss
|
| 546 |
+
q_rot_loss += self._celoss(q_rot_x_flat, action_rot_x_one_hot)
|
| 547 |
+
q_rot_loss += self._celoss(q_rot_y_flat, action_rot_y_one_hot)
|
| 548 |
+
q_rot_loss += self._celoss(q_rot_z_flat, action_rot_z_one_hot)
|
| 549 |
+
|
| 550 |
+
# gripper loss
|
| 551 |
+
q_grip_loss += self._celoss(q_grip_flat, action_grip_one_hot)
|
| 552 |
+
|
| 553 |
+
# collision loss
|
| 554 |
+
q_collision_loss += self._celoss(
|
| 555 |
+
q_ignore_collisions_flat, action_ignore_collisions_one_hot
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
combined_losses = (
|
| 559 |
+
(q_trans_loss * self._trans_loss_weight)
|
| 560 |
+
+ (q_rot_loss * self._rot_loss_weight)
|
| 561 |
+
+ (q_grip_loss * self._grip_loss_weight)
|
| 562 |
+
+ (q_collision_loss * self._collision_loss_weight)
|
| 563 |
+
)
|
| 564 |
+
total_loss = combined_losses.mean()
|
| 565 |
+
|
| 566 |
+
self._optimizer.zero_grad()
|
| 567 |
+
total_loss.backward()
|
| 568 |
+
self._optimizer.step()
|
| 569 |
+
|
| 570 |
+
self._summaries = {
|
| 571 |
+
"losses/total_loss": total_loss,
|
| 572 |
+
"losses/trans_loss": q_trans_loss.mean(),
|
| 573 |
+
"losses/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
|
| 574 |
+
"losses/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
|
| 575 |
+
"losses/collision_loss": q_collision_loss.mean()
|
| 576 |
+
if with_rot_and_grip
|
| 577 |
+
else 0.0,
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
if self._lr_scheduler:
|
| 581 |
+
self._scheduler.step()
|
| 582 |
+
self._summaries["learning_rate"] = self._scheduler.get_last_lr()[0]
|
| 583 |
+
|
| 584 |
+
self._vis_voxel_grid = voxel_grid[0]
|
| 585 |
+
self._vis_translation_qvalue = self._softmax_q_trans(q_trans[0])
|
| 586 |
+
self._vis_max_coordinate = coords[0]
|
| 587 |
+
self._vis_gt_coordinate = action_trans[0]
|
| 588 |
+
|
| 589 |
+
# Note: PerAct doesn't use multi-layer voxel grids like C2FARM
|
| 590 |
+
# stack prev_layer_voxel_grid(s) from previous layers into a list
|
| 591 |
+
if prev_layer_voxel_grid is None:
|
| 592 |
+
prev_layer_voxel_grid = [voxel_grid]
|
| 593 |
+
else:
|
| 594 |
+
prev_layer_voxel_grid = prev_layer_voxel_grid + [voxel_grid]
|
| 595 |
+
|
| 596 |
+
# stack prev_layer_bound(s) from previous layers into a list
|
| 597 |
+
if prev_layer_bounds is None:
|
| 598 |
+
prev_layer_bounds = [self._coordinate_bounds.repeat(bs, 1)]
|
| 599 |
+
else:
|
| 600 |
+
prev_layer_bounds = prev_layer_bounds + [bounds]
|
| 601 |
+
|
| 602 |
+
return {
|
| 603 |
+
"total_loss": total_loss,
|
| 604 |
+
"prev_layer_voxel_grid": prev_layer_voxel_grid,
|
| 605 |
+
"prev_layer_bounds": prev_layer_bounds,
|
| 606 |
+
}
|
| 607 |
+
|
| 608 |
+
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
|
| 609 |
+
deterministic = True
|
| 610 |
+
bounds = self._coordinate_bounds
|
| 611 |
+
prev_layer_voxel_grid = observation.get("prev_layer_voxel_grid", None)
|
| 612 |
+
prev_layer_bounds = observation.get("prev_layer_bounds", None)
|
| 613 |
+
lang_goal_tokens = observation.get("lang_goal_tokens", None).long()
|
| 614 |
+
|
| 615 |
+
# extract CLIP language embs
|
| 616 |
+
with torch.no_grad():
|
| 617 |
+
lang_goal_tokens = lang_goal_tokens.to(device=self._device)
|
| 618 |
+
(
|
| 619 |
+
lang_goal_emb,
|
| 620 |
+
lang_token_embs,
|
| 621 |
+
) = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0])
|
| 622 |
+
|
| 623 |
+
# voxelization resolution
|
| 624 |
+
res = (bounds[:, 3:] - bounds[:, :3]) / self._voxel_size
|
| 625 |
+
max_rot_index = int(360 // self._rotation_resolution)
|
| 626 |
+
proprio = None
|
| 627 |
+
|
| 628 |
+
if self._include_low_dim_state:
|
| 629 |
+
proprio = observation["low_dim_state"]
|
| 630 |
+
proprio = proprio[0].to(self._device)
|
| 631 |
+
|
| 632 |
+
obs, pcd = self._act_preprocess_inputs(observation)
|
| 633 |
+
|
| 634 |
+
# correct batch size and device
|
| 635 |
+
obs = [[o[0][0].to(self._device), o[1][0].to(self._device)] for o in obs]
|
| 636 |
+
pcd = [p[0].to(self._device) for p in pcd]
|
| 637 |
+
lang_goal_emb = lang_goal_emb.to(self._device)
|
| 638 |
+
lang_token_embs = lang_token_embs.to(self._device)
|
| 639 |
+
bounds = torch.as_tensor(bounds, device=self._device)
|
| 640 |
+
prev_layer_voxel_grid = (
|
| 641 |
+
prev_layer_voxel_grid.to(self._device)
|
| 642 |
+
if prev_layer_voxel_grid is not None
|
| 643 |
+
else None
|
| 644 |
+
)
|
| 645 |
+
prev_layer_bounds = (
|
| 646 |
+
prev_layer_bounds.to(self._device)
|
| 647 |
+
if prev_layer_bounds is not None
|
| 648 |
+
else None
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
# inference
|
| 652 |
+
q_trans, q_rot_grip, q_ignore_collisions, vox_grid = self._q(
|
| 653 |
+
obs,
|
| 654 |
+
proprio,
|
| 655 |
+
pcd,
|
| 656 |
+
lang_goal_emb,
|
| 657 |
+
lang_token_embs,
|
| 658 |
+
bounds,
|
| 659 |
+
prev_layer_bounds,
|
| 660 |
+
prev_layer_voxel_grid,
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
# softmax Q predictions
|
| 664 |
+
q_trans = self._softmax_q_trans(q_trans)
|
| 665 |
+
q_rot_grip = (
|
| 666 |
+
self._softmax_q_rot_grip(q_rot_grip)
|
| 667 |
+
if q_rot_grip is not None
|
| 668 |
+
else q_rot_grip
|
| 669 |
+
)
|
| 670 |
+
q_ignore_collisions = (
|
| 671 |
+
self._softmax_ignore_collision(q_ignore_collisions)
|
| 672 |
+
if q_ignore_collisions is not None
|
| 673 |
+
else q_ignore_collisions
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
# argmax Q predictions
|
| 677 |
+
(
|
| 678 |
+
coords,
|
| 679 |
+
rot_and_grip_indicies,
|
| 680 |
+
ignore_collisions,
|
| 681 |
+
) = self._q.choose_highest_action(q_trans, q_rot_grip, q_ignore_collisions)
|
| 682 |
+
|
| 683 |
+
rot_grip_action = rot_and_grip_indicies if q_rot_grip is not None else None
|
| 684 |
+
ignore_collisions_action = (
|
| 685 |
+
ignore_collisions.int() if ignore_collisions is not None else None
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
+
coords = coords.int()
|
| 689 |
+
attention_coordinate = bounds[:, :3] + res * coords + res / 2
|
| 690 |
+
|
| 691 |
+
# stack prev_layer_voxel_grid(s) into a list
|
| 692 |
+
# NOTE: PerAct doesn't used multi-layer voxel grids like C2FARM
|
| 693 |
+
if prev_layer_voxel_grid is None:
|
| 694 |
+
prev_layer_voxel_grid = [vox_grid]
|
| 695 |
+
else:
|
| 696 |
+
prev_layer_voxel_grid = prev_layer_voxel_grid + [vox_grid]
|
| 697 |
+
|
| 698 |
+
if prev_layer_bounds is None:
|
| 699 |
+
prev_layer_bounds = [bounds]
|
| 700 |
+
else:
|
| 701 |
+
prev_layer_bounds = prev_layer_bounds + [bounds]
|
| 702 |
+
|
| 703 |
+
observation_elements = {
|
| 704 |
+
"attention_coordinate": attention_coordinate,
|
| 705 |
+
"prev_layer_voxel_grid": prev_layer_voxel_grid,
|
| 706 |
+
"prev_layer_bounds": prev_layer_bounds,
|
| 707 |
+
}
|
| 708 |
+
info = {
|
| 709 |
+
"voxel_grid_depth%d" % self._layer: vox_grid,
|
| 710 |
+
"q_depth%d" % self._layer: q_trans,
|
| 711 |
+
"voxel_idx_depth%d" % self._layer: coords,
|
| 712 |
+
}
|
| 713 |
+
self._act_voxel_grid = vox_grid[0]
|
| 714 |
+
self._act_max_coordinate = coords[0]
|
| 715 |
+
self._act_qvalues = q_trans[0].detach()
|
| 716 |
+
return ActResult(
|
| 717 |
+
(coords, rot_grip_action, ignore_collisions_action),
|
| 718 |
+
observation_elements=observation_elements,
|
| 719 |
+
info=info,
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
def update_summaries(self) -> List[Summary]:
|
| 723 |
+
summaries = [
|
| 724 |
+
ImageSummary(
|
| 725 |
+
"%s/update_qattention" % self._name,
|
| 726 |
+
transforms.ToTensor()(
|
| 727 |
+
visualise_voxel(
|
| 728 |
+
self._vis_voxel_grid.detach().cpu().numpy(),
|
| 729 |
+
self._vis_translation_qvalue.detach().cpu().numpy(),
|
| 730 |
+
self._vis_max_coordinate.detach().cpu().numpy(),
|
| 731 |
+
self._vis_gt_coordinate.detach().cpu().numpy(),
|
| 732 |
+
)
|
| 733 |
+
),
|
| 734 |
+
)
|
| 735 |
+
]
|
| 736 |
+
|
| 737 |
+
for n, v in self._summaries.items():
|
| 738 |
+
summaries.append(ScalarSummary("%s/%s" % (self._name, n), v))
|
| 739 |
+
|
| 740 |
+
for name, crop in self._crop_summary:
|
| 741 |
+
crops = (torch.cat(torch.split(crop, 3, dim=1), dim=3) + 1.0) / 2.0
|
| 742 |
+
summaries.extend([ImageSummary("%s/crops/%s" % (self._name, name), crops)])
|
| 743 |
+
|
| 744 |
+
for tag, param in self._q.named_parameters():
|
| 745 |
+
# assert not torch.isnan(param.grad.abs() <= 1.0).all()
|
| 746 |
+
summaries.append(
|
| 747 |
+
HistogramSummary("%s/gradient/%s" % (self._name, tag), param.grad)
|
| 748 |
+
)
|
| 749 |
+
summaries.append(
|
| 750 |
+
HistogramSummary("%s/weight/%s" % (self._name, tag), param.data)
|
| 751 |
+
)
|
| 752 |
+
|
| 753 |
+
return summaries
|
| 754 |
+
|
| 755 |
+
def act_summaries(self) -> List[Summary]:
|
| 756 |
+
return [
|
| 757 |
+
ImageSummary(
|
| 758 |
+
"%s/act_Qattention" % self._name,
|
| 759 |
+
transforms.ToTensor()(
|
| 760 |
+
visualise_voxel(
|
| 761 |
+
self._act_voxel_grid.cpu().numpy(),
|
| 762 |
+
self._act_qvalues.cpu().numpy(),
|
| 763 |
+
self._act_max_coordinate.cpu().numpy(),
|
| 764 |
+
)
|
| 765 |
+
),
|
| 766 |
+
)
|
| 767 |
+
]
|
| 768 |
+
|
| 769 |
+
def load_weights(self, savedir: str):
|
| 770 |
+
device = (
|
| 771 |
+
self._device
|
| 772 |
+
if not self._training
|
| 773 |
+
else torch.device("cuda:%d" % self._device)
|
| 774 |
+
)
|
| 775 |
+
weight_file = os.path.join(savedir, "%s.pt" % self._name)
|
| 776 |
+
state_dict = torch.load(weight_file, map_location=device)
|
| 777 |
+
|
| 778 |
+
# load only keys that are in the current model
|
| 779 |
+
merged_state_dict = self._q.state_dict()
|
| 780 |
+
for k, v in state_dict.items():
|
| 781 |
+
if not self._training:
|
| 782 |
+
k = k.replace("_qnet.module", "_qnet")
|
| 783 |
+
if k in merged_state_dict:
|
| 784 |
+
merged_state_dict[k] = v
|
| 785 |
+
else:
|
| 786 |
+
if "_voxelizer" not in k:
|
| 787 |
+
logging.warning("key %s not found in checkpoint" % k)
|
| 788 |
+
if not self._training:
|
| 789 |
+
# reshape voxelizer weights
|
| 790 |
+
b = merged_state_dict["_voxelizer._ones_max_coords"].shape[0]
|
| 791 |
+
merged_state_dict["_voxelizer._ones_max_coords"] = merged_state_dict[
|
| 792 |
+
"_voxelizer._ones_max_coords"
|
| 793 |
+
][0:1]
|
| 794 |
+
flat_shape = merged_state_dict["_voxelizer._flat_output"].shape[0]
|
| 795 |
+
merged_state_dict["_voxelizer._flat_output"] = merged_state_dict[
|
| 796 |
+
"_voxelizer._flat_output"
|
| 797 |
+
][0 : flat_shape // b]
|
| 798 |
+
merged_state_dict["_voxelizer._tiled_batch_indices"] = merged_state_dict[
|
| 799 |
+
"_voxelizer._tiled_batch_indices"
|
| 800 |
+
][0:1]
|
| 801 |
+
merged_state_dict["_voxelizer._index_grid"] = merged_state_dict[
|
| 802 |
+
"_voxelizer._index_grid"
|
| 803 |
+
][0:1]
|
| 804 |
+
self._q.load_state_dict(merged_state_dict)
|
| 805 |
+
print("loaded weights from %s" % weight_file)
|
| 806 |
+
|
| 807 |
+
def save_weights(self, savedir: str):
|
| 808 |
+
torch.save(self._q.state_dict(), os.path.join(savedir, "%s.pt" % self._name))
|
external/peract_bimanual/agents/peract_bc/qattention_stack_agent.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from yarr.agents.agent import Agent, ActResult, Summary
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from helpers import utils
|
| 9 |
+
from agents.peract_bc.qattention_peract_bc_agent import QAttentionPerActBCAgent
|
| 10 |
+
|
| 11 |
+
NAME = "QAttentionStackAgent"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class QAttentionStackAgent(Agent):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
qattention_agents: List[QAttentionPerActBCAgent],
|
| 18 |
+
rotation_resolution: float,
|
| 19 |
+
camera_names: List[str],
|
| 20 |
+
rotation_prediction_depth: int = 0,
|
| 21 |
+
):
|
| 22 |
+
super(QAttentionStackAgent, self).__init__()
|
| 23 |
+
self._qattention_agents = qattention_agents
|
| 24 |
+
self._rotation_resolution = rotation_resolution
|
| 25 |
+
self._camera_names = camera_names
|
| 26 |
+
self._rotation_prediction_depth = rotation_prediction_depth
|
| 27 |
+
|
| 28 |
+
def build(self, training: bool, device=None) -> None:
|
| 29 |
+
self._device = device
|
| 30 |
+
if self._device is None:
|
| 31 |
+
self._device = torch.device("cpu")
|
| 32 |
+
for qa in self._qattention_agents:
|
| 33 |
+
qa.build(training, device)
|
| 34 |
+
|
| 35 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 36 |
+
priorities = 0
|
| 37 |
+
total_losses = 0.0
|
| 38 |
+
for qa in self._qattention_agents:
|
| 39 |
+
update_dict = qa.update(step, replay_sample)
|
| 40 |
+
replay_sample.update(update_dict)
|
| 41 |
+
total_losses += update_dict["total_loss"]
|
| 42 |
+
return {
|
| 43 |
+
"total_losses": total_losses,
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
|
| 47 |
+
observation_elements = {}
|
| 48 |
+
translation_results, rot_grip_results, ignore_collisions_results = [], [], []
|
| 49 |
+
infos = {}
|
| 50 |
+
for depth, qagent in enumerate(self._qattention_agents):
|
| 51 |
+
act_results = qagent.act(step, observation, deterministic)
|
| 52 |
+
attention_coordinate = (
|
| 53 |
+
act_results.observation_elements["attention_coordinate"].cpu().numpy()
|
| 54 |
+
)
|
| 55 |
+
observation_elements[
|
| 56 |
+
"attention_coordinate_layer_%d" % depth
|
| 57 |
+
] = attention_coordinate[0]
|
| 58 |
+
|
| 59 |
+
translation_idxs, rot_grip_idxs, ignore_collisions_idxs = act_results.action
|
| 60 |
+
translation_results.append(translation_idxs)
|
| 61 |
+
if rot_grip_idxs is not None:
|
| 62 |
+
rot_grip_results.append(rot_grip_idxs)
|
| 63 |
+
if ignore_collisions_idxs is not None:
|
| 64 |
+
ignore_collisions_results.append(ignore_collisions_idxs)
|
| 65 |
+
|
| 66 |
+
observation["attention_coordinate"] = act_results.observation_elements[
|
| 67 |
+
"attention_coordinate"
|
| 68 |
+
]
|
| 69 |
+
observation["prev_layer_voxel_grid"] = act_results.observation_elements[
|
| 70 |
+
"prev_layer_voxel_grid"
|
| 71 |
+
]
|
| 72 |
+
observation["prev_layer_bounds"] = act_results.observation_elements[
|
| 73 |
+
"prev_layer_bounds"
|
| 74 |
+
]
|
| 75 |
+
|
| 76 |
+
for n in self._camera_names:
|
| 77 |
+
px, py = utils.point_to_pixel_index(
|
| 78 |
+
attention_coordinate[0],
|
| 79 |
+
observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy(),
|
| 80 |
+
observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy(),
|
| 81 |
+
)
|
| 82 |
+
pc_t = torch.tensor(
|
| 83 |
+
[[[py, px]]], dtype=torch.float32, device=self._device
|
| 84 |
+
)
|
| 85 |
+
observation["%s_pixel_coord" % n] = pc_t
|
| 86 |
+
observation_elements["%s_pixel_coord" % n] = [py, px]
|
| 87 |
+
|
| 88 |
+
infos.update(act_results.info)
|
| 89 |
+
|
| 90 |
+
rgai = torch.cat(rot_grip_results, 1)[0].cpu().numpy()
|
| 91 |
+
ignore_collisions = float(
|
| 92 |
+
torch.cat(ignore_collisions_results, 1)[0].cpu().numpy()
|
| 93 |
+
)
|
| 94 |
+
observation_elements["trans_action_indicies"] = (
|
| 95 |
+
torch.cat(translation_results, 1)[0].cpu().numpy()
|
| 96 |
+
)
|
| 97 |
+
observation_elements["rot_grip_action_indicies"] = rgai
|
| 98 |
+
continuous_action = np.concatenate(
|
| 99 |
+
[
|
| 100 |
+
act_results.observation_elements["attention_coordinate"]
|
| 101 |
+
.cpu()
|
| 102 |
+
.numpy()[0],
|
| 103 |
+
utils.discrete_euler_to_quaternion(
|
| 104 |
+
rgai[-4:-1], self._rotation_resolution
|
| 105 |
+
),
|
| 106 |
+
rgai[-1:],
|
| 107 |
+
[ignore_collisions],
|
| 108 |
+
]
|
| 109 |
+
)
|
| 110 |
+
return ActResult(
|
| 111 |
+
continuous_action, observation_elements=observation_elements, info=infos
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
def update_summaries(self) -> List[Summary]:
|
| 115 |
+
summaries = []
|
| 116 |
+
for qa in self._qattention_agents:
|
| 117 |
+
summaries.extend(qa.update_summaries())
|
| 118 |
+
return summaries
|
| 119 |
+
|
| 120 |
+
def act_summaries(self) -> List[Summary]:
|
| 121 |
+
s = []
|
| 122 |
+
for qa in self._qattention_agents:
|
| 123 |
+
s.extend(qa.act_summaries())
|
| 124 |
+
return s
|
| 125 |
+
|
| 126 |
+
def load_weights(self, savedir: str):
|
| 127 |
+
for qa in self._qattention_agents:
|
| 128 |
+
qa.load_weights(savedir)
|
| 129 |
+
|
| 130 |
+
def save_weights(self, savedir: str):
|
| 131 |
+
for qa in self._qattention_agents:
|
| 132 |
+
qa.save_weights(savedir)
|
external/peract_bimanual/agents/replay_utils.py
ADDED
|
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from rlbench.backend.observation import Observation
|
| 6 |
+
from rlbench.observation_config import ObservationConfig
|
| 7 |
+
import rlbench.utils as rlbench_utils
|
| 8 |
+
from rlbench.demo import Demo
|
| 9 |
+
from yarr.replay_buffer.replay_buffer import ReplayBuffer
|
| 10 |
+
|
| 11 |
+
from helpers import demo_loading_utils, utils
|
| 12 |
+
from helpers import observation_utils
|
| 13 |
+
from helpers.clip.core.clip import tokenize
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from yarr.replay_buffer.prioritized_replay_buffer import ObservationElement
|
| 17 |
+
from yarr.replay_buffer.replay_buffer import ReplayElement
|
| 18 |
+
from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from torch.multiprocessing import Process, Value, Manager
|
| 23 |
+
from helpers.clip.core.clip import build_model, load_clip
|
| 24 |
+
from omegaconf import DictConfig
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
REWARD_SCALE = 100.0
|
| 28 |
+
LOW_DIM_SIZE = 4
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def create_replay(cfg, replay_path):
|
| 32 |
+
if cfg.method.robot_name == "bimanual":
|
| 33 |
+
return create_bimanual_replay(
|
| 34 |
+
cfg.replay.batch_size,
|
| 35 |
+
cfg.replay.timesteps,
|
| 36 |
+
cfg.replay.prioritisation,
|
| 37 |
+
cfg.replay.task_uniform,
|
| 38 |
+
replay_path if cfg.replay.use_disk else None,
|
| 39 |
+
cfg.rlbench.cameras,
|
| 40 |
+
cfg.method.voxel_sizes,
|
| 41 |
+
cfg.rlbench.camera_resolution,
|
| 42 |
+
)
|
| 43 |
+
else:
|
| 44 |
+
return create_unimanual_replay(
|
| 45 |
+
cfg.replay.batch_size,
|
| 46 |
+
cfg.replay.timesteps,
|
| 47 |
+
cfg.replay.prioritisation,
|
| 48 |
+
cfg.replay.task_uniform,
|
| 49 |
+
replay_path if cfg.replay.use_disk else None,
|
| 50 |
+
cfg.rlbench.cameras,
|
| 51 |
+
cfg.method.voxel_sizes,
|
| 52 |
+
cfg.rlbench.camera_resolution,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def create_bimanual_replay(
|
| 57 |
+
batch_size: int,
|
| 58 |
+
timesteps: int,
|
| 59 |
+
prioritisation: bool,
|
| 60 |
+
task_uniform: bool,
|
| 61 |
+
save_dir: str,
|
| 62 |
+
cameras: list,
|
| 63 |
+
voxel_sizes,
|
| 64 |
+
image_size=[128, 128],
|
| 65 |
+
replay_size=3e5,
|
| 66 |
+
):
|
| 67 |
+
trans_indicies_size = 3 * len(voxel_sizes)
|
| 68 |
+
rot_and_grip_indicies_size = 3 + 1
|
| 69 |
+
gripper_pose_size = 7
|
| 70 |
+
ignore_collisions_size = 1
|
| 71 |
+
max_token_seq_len = 77
|
| 72 |
+
lang_feat_dim = 1024
|
| 73 |
+
lang_emb_dim = 512
|
| 74 |
+
|
| 75 |
+
# low_dim_state
|
| 76 |
+
observation_elements = []
|
| 77 |
+
observation_elements.append(
|
| 78 |
+
ObservationElement("right_low_dim_state", (LOW_DIM_SIZE,), np.float32)
|
| 79 |
+
)
|
| 80 |
+
observation_elements.append(
|
| 81 |
+
ObservationElement("left_low_dim_state", (LOW_DIM_SIZE,), np.float32)
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# rgb, depth, point cloud, intrinsics, extrinsics
|
| 85 |
+
for cname in cameras:
|
| 86 |
+
observation_elements.append(
|
| 87 |
+
# color, height, width
|
| 88 |
+
ObservationElement(
|
| 89 |
+
"%s_rgb" % cname,
|
| 90 |
+
(
|
| 91 |
+
3,
|
| 92 |
+
image_size[1],
|
| 93 |
+
image_size[0],
|
| 94 |
+
),
|
| 95 |
+
np.float32,
|
| 96 |
+
)
|
| 97 |
+
)
|
| 98 |
+
observation_elements.append(
|
| 99 |
+
ObservationElement(
|
| 100 |
+
"%s_point_cloud" % cname, (3, image_size[1], image_size[0]), np.float16
|
| 101 |
+
)
|
| 102 |
+
) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
|
| 103 |
+
observation_elements.append(
|
| 104 |
+
ObservationElement(
|
| 105 |
+
"%s_camera_extrinsics" % cname,
|
| 106 |
+
(
|
| 107 |
+
4,
|
| 108 |
+
4,
|
| 109 |
+
),
|
| 110 |
+
np.float32,
|
| 111 |
+
)
|
| 112 |
+
)
|
| 113 |
+
observation_elements.append(
|
| 114 |
+
ObservationElement(
|
| 115 |
+
"%s_camera_intrinsics" % cname,
|
| 116 |
+
(
|
| 117 |
+
3,
|
| 118 |
+
3,
|
| 119 |
+
),
|
| 120 |
+
np.float32,
|
| 121 |
+
)
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# discretized translation, discretized rotation, discrete ignore collision, 6-DoF gripper pose, and pre-trained language embeddings
|
| 125 |
+
for robot_name in ["right", "left"]:
|
| 126 |
+
observation_elements.extend(
|
| 127 |
+
[
|
| 128 |
+
ReplayElement(
|
| 129 |
+
f"{robot_name}_trans_action_indicies",
|
| 130 |
+
(trans_indicies_size,),
|
| 131 |
+
np.int32,
|
| 132 |
+
),
|
| 133 |
+
ReplayElement(
|
| 134 |
+
f"{robot_name}_rot_grip_action_indicies",
|
| 135 |
+
(rot_and_grip_indicies_size,),
|
| 136 |
+
np.int32,
|
| 137 |
+
),
|
| 138 |
+
ReplayElement(
|
| 139 |
+
f"{robot_name}_ignore_collisions",
|
| 140 |
+
(ignore_collisions_size,),
|
| 141 |
+
np.int32,
|
| 142 |
+
),
|
| 143 |
+
ReplayElement(
|
| 144 |
+
f"{robot_name}_gripper_pose", (gripper_pose_size,), np.float32
|
| 145 |
+
),
|
| 146 |
+
]
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
observation_elements.extend(
|
| 150 |
+
[
|
| 151 |
+
ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
|
| 152 |
+
ReplayElement(
|
| 153 |
+
"lang_token_embs",
|
| 154 |
+
(
|
| 155 |
+
max_token_seq_len,
|
| 156 |
+
lang_emb_dim,
|
| 157 |
+
),
|
| 158 |
+
np.float32,
|
| 159 |
+
), # extracted from CLIP's language encoder
|
| 160 |
+
ReplayElement("task", (), str),
|
| 161 |
+
ReplayElement(
|
| 162 |
+
"lang_goal", (1,), object
|
| 163 |
+
), # language goal string for debugging and visualization
|
| 164 |
+
]
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
extra_replay_elements = [
|
| 168 |
+
ReplayElement("demo", (), bool),
|
| 169 |
+
]
|
| 170 |
+
|
| 171 |
+
replay_buffer = TaskUniformReplayBuffer(
|
| 172 |
+
save_dir=save_dir,
|
| 173 |
+
batch_size=batch_size,
|
| 174 |
+
timesteps=timesteps,
|
| 175 |
+
replay_capacity=int(replay_size),
|
| 176 |
+
action_shape=(8 * 2,),
|
| 177 |
+
action_dtype=np.float32,
|
| 178 |
+
reward_shape=(),
|
| 179 |
+
reward_dtype=np.float32,
|
| 180 |
+
update_horizon=1,
|
| 181 |
+
observation_elements=observation_elements,
|
| 182 |
+
extra_replay_elements=extra_replay_elements,
|
| 183 |
+
)
|
| 184 |
+
return replay_buffer
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def create_unimanual_replay(
|
| 188 |
+
batch_size: int,
|
| 189 |
+
timesteps: int,
|
| 190 |
+
prioritisation: bool,
|
| 191 |
+
task_uniform: bool,
|
| 192 |
+
save_dir: str,
|
| 193 |
+
cameras: list,
|
| 194 |
+
voxel_sizes,
|
| 195 |
+
image_size=[128, 128],
|
| 196 |
+
replay_size=3e5,
|
| 197 |
+
):
|
| 198 |
+
trans_indicies_size = 3 * len(voxel_sizes)
|
| 199 |
+
rot_and_grip_indicies_size = 3 + 1
|
| 200 |
+
gripper_pose_size = 7
|
| 201 |
+
ignore_collisions_size = 1
|
| 202 |
+
max_token_seq_len = 77
|
| 203 |
+
lang_feat_dim = 1024
|
| 204 |
+
lang_emb_dim = 512
|
| 205 |
+
|
| 206 |
+
# low_dim_state
|
| 207 |
+
observation_elements = []
|
| 208 |
+
observation_elements.append(
|
| 209 |
+
ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# rgb, depth, point cloud, intrinsics, extrinsics
|
| 213 |
+
for cname in cameras:
|
| 214 |
+
observation_elements.append(
|
| 215 |
+
ObservationElement(
|
| 216 |
+
"%s_rgb" % cname,
|
| 217 |
+
(
|
| 218 |
+
3,
|
| 219 |
+
*image_size,
|
| 220 |
+
),
|
| 221 |
+
np.float32,
|
| 222 |
+
)
|
| 223 |
+
)
|
| 224 |
+
observation_elements.append(
|
| 225 |
+
ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32)
|
| 226 |
+
) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
|
| 227 |
+
observation_elements.append(
|
| 228 |
+
ObservationElement(
|
| 229 |
+
"%s_camera_extrinsics" % cname,
|
| 230 |
+
(
|
| 231 |
+
4,
|
| 232 |
+
4,
|
| 233 |
+
),
|
| 234 |
+
np.float32,
|
| 235 |
+
)
|
| 236 |
+
)
|
| 237 |
+
observation_elements.append(
|
| 238 |
+
ObservationElement(
|
| 239 |
+
"%s_camera_intrinsics" % cname,
|
| 240 |
+
(
|
| 241 |
+
3,
|
| 242 |
+
3,
|
| 243 |
+
),
|
| 244 |
+
np.float32,
|
| 245 |
+
)
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# discretized translation, discretized rotation, discrete ignore collision, 6-DoF gripper pose, and pre-trained language embeddings
|
| 249 |
+
observation_elements.extend(
|
| 250 |
+
[
|
| 251 |
+
ReplayElement("trans_action_indicies", (trans_indicies_size,), np.int32),
|
| 252 |
+
ReplayElement(
|
| 253 |
+
"rot_grip_action_indicies", (rot_and_grip_indicies_size,), np.int32
|
| 254 |
+
),
|
| 255 |
+
ReplayElement("ignore_collisions", (ignore_collisions_size,), np.int32),
|
| 256 |
+
ReplayElement("gripper_pose", (gripper_pose_size,), np.float32),
|
| 257 |
+
ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
|
| 258 |
+
ReplayElement(
|
| 259 |
+
"lang_token_embs",
|
| 260 |
+
(
|
| 261 |
+
max_token_seq_len,
|
| 262 |
+
lang_emb_dim,
|
| 263 |
+
),
|
| 264 |
+
np.float32,
|
| 265 |
+
), # extracted from CLIP's language encoder
|
| 266 |
+
ReplayElement("task", (), str),
|
| 267 |
+
ReplayElement(
|
| 268 |
+
"lang_goal", (1,), object
|
| 269 |
+
), # language goal string for debugging and visualization
|
| 270 |
+
]
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
extra_replay_elements = [
|
| 274 |
+
ReplayElement("demo", (), bool),
|
| 275 |
+
]
|
| 276 |
+
|
| 277 |
+
replay_buffer = TaskUniformReplayBuffer(
|
| 278 |
+
save_dir=save_dir,
|
| 279 |
+
batch_size=batch_size,
|
| 280 |
+
timesteps=timesteps,
|
| 281 |
+
replay_capacity=int(replay_size),
|
| 282 |
+
action_shape=(8,),
|
| 283 |
+
action_dtype=np.float32,
|
| 284 |
+
reward_shape=(),
|
| 285 |
+
reward_dtype=np.float32,
|
| 286 |
+
update_horizon=1,
|
| 287 |
+
observation_elements=observation_elements,
|
| 288 |
+
extra_replay_elements=extra_replay_elements,
|
| 289 |
+
)
|
| 290 |
+
return replay_buffer
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def _get_action(
|
| 294 |
+
obs_tp1: Observation,
|
| 295 |
+
obs_tm1: Observation,
|
| 296 |
+
rlbench_scene_bounds: List[float], # metric 3D bounds of the scene
|
| 297 |
+
voxel_sizes: List[int],
|
| 298 |
+
bounds_offset: List[float],
|
| 299 |
+
rotation_resolution: int,
|
| 300 |
+
crop_augmentation: bool,
|
| 301 |
+
):
|
| 302 |
+
quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
|
| 303 |
+
if quat[-1] < 0:
|
| 304 |
+
quat = -quat
|
| 305 |
+
disc_rot = utils.quaternion_to_discrete_euler(quat, rotation_resolution)
|
| 306 |
+
disc_rot = utils.correct_rotation_instability(disc_rot, rotation_resolution)
|
| 307 |
+
|
| 308 |
+
attention_coordinate = obs_tp1.gripper_pose[:3]
|
| 309 |
+
trans_indicies, attention_coordinates = [], []
|
| 310 |
+
bounds = np.array(rlbench_scene_bounds)
|
| 311 |
+
ignore_collisions = int(obs_tm1.ignore_collisions)
|
| 312 |
+
for depth, vox_size in enumerate(
|
| 313 |
+
voxel_sizes
|
| 314 |
+
): # only single voxelization-level is used in PerAct
|
| 315 |
+
if depth > 0:
|
| 316 |
+
if crop_augmentation:
|
| 317 |
+
shift = bounds_offset[depth - 1] * 0.75
|
| 318 |
+
attention_coordinate += np.random.uniform(-shift, shift, size=(3,))
|
| 319 |
+
bounds = np.concatenate(
|
| 320 |
+
[
|
| 321 |
+
attention_coordinate - bounds_offset[depth - 1],
|
| 322 |
+
attention_coordinate + bounds_offset[depth - 1],
|
| 323 |
+
]
|
| 324 |
+
)
|
| 325 |
+
index = utils.point_to_voxel_index(obs_tp1.gripper_pose[:3], vox_size, bounds)
|
| 326 |
+
trans_indicies.extend(index.tolist())
|
| 327 |
+
res = (bounds[3:] - bounds[:3]) / vox_size
|
| 328 |
+
attention_coordinate = bounds[:3] + res * index
|
| 329 |
+
attention_coordinates.append(attention_coordinate)
|
| 330 |
+
|
| 331 |
+
rot_and_grip_indicies = disc_rot.tolist()
|
| 332 |
+
grip = float(obs_tp1.gripper_open)
|
| 333 |
+
rot_and_grip_indicies.extend([int(obs_tp1.gripper_open)])
|
| 334 |
+
return (
|
| 335 |
+
trans_indicies,
|
| 336 |
+
rot_and_grip_indicies,
|
| 337 |
+
ignore_collisions,
|
| 338 |
+
np.concatenate([obs_tp1.gripper_pose, np.array([grip])]),
|
| 339 |
+
attention_coordinates,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def _add_keypoints_to_replay(
|
| 344 |
+
cfg: DictConfig,
|
| 345 |
+
task: str,
|
| 346 |
+
replay: ReplayBuffer,
|
| 347 |
+
inital_obs: Observation,
|
| 348 |
+
demo: Demo,
|
| 349 |
+
episode_keypoints: List[int],
|
| 350 |
+
description: str = "",
|
| 351 |
+
clip_model=None,
|
| 352 |
+
device="cpu",
|
| 353 |
+
):
|
| 354 |
+
cameras = cfg.rlbench.cameras
|
| 355 |
+
rlbench_scene_bounds = cfg.rlbench.scene_bounds
|
| 356 |
+
voxel_sizes = cfg.method.voxel_sizes
|
| 357 |
+
bounds_offset = cfg.method.bounds_offset
|
| 358 |
+
rotation_resolution = cfg.method.rotation_resolution
|
| 359 |
+
crop_augmentation = cfg.method.crop_augmentation
|
| 360 |
+
robot_name = cfg.method.robot_name
|
| 361 |
+
|
| 362 |
+
prev_action = None
|
| 363 |
+
obs = inital_obs
|
| 364 |
+
|
| 365 |
+
for k, keypoint in enumerate(episode_keypoints):
|
| 366 |
+
obs_tp1 = demo[keypoint]
|
| 367 |
+
obs_tm1 = demo[max(0, keypoint - 1)]
|
| 368 |
+
|
| 369 |
+
if obs_tp1.is_bimanual and robot_name == "bimanual":
|
| 370 |
+
# assert isinstance(obs_tp1, BimanualObservation)
|
| 371 |
+
(
|
| 372 |
+
right_trans_indicies,
|
| 373 |
+
right_rot_grip_indicies,
|
| 374 |
+
right_ignore_collisions,
|
| 375 |
+
right_action,
|
| 376 |
+
right_attention_coordinates,
|
| 377 |
+
) = _get_action(
|
| 378 |
+
obs_tp1.right,
|
| 379 |
+
obs_tm1.right,
|
| 380 |
+
rlbench_scene_bounds,
|
| 381 |
+
voxel_sizes,
|
| 382 |
+
bounds_offset,
|
| 383 |
+
rotation_resolution,
|
| 384 |
+
crop_augmentation,
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
(
|
| 388 |
+
left_trans_indicies,
|
| 389 |
+
left_rot_grip_indicies,
|
| 390 |
+
left_ignore_collisions,
|
| 391 |
+
left_action,
|
| 392 |
+
left_attention_coordinates,
|
| 393 |
+
) = _get_action(
|
| 394 |
+
obs_tp1.left,
|
| 395 |
+
obs_tm1.left,
|
| 396 |
+
rlbench_scene_bounds,
|
| 397 |
+
voxel_sizes,
|
| 398 |
+
bounds_offset,
|
| 399 |
+
rotation_resolution,
|
| 400 |
+
crop_augmentation,
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
action = np.append(right_action, left_action)
|
| 404 |
+
|
| 405 |
+
right_ignore_collisions = np.array([right_ignore_collisions])
|
| 406 |
+
left_ignore_collisions = np.array([left_ignore_collisions])
|
| 407 |
+
|
| 408 |
+
elif robot_name == "unimanual":
|
| 409 |
+
(
|
| 410 |
+
trans_indicies,
|
| 411 |
+
rot_grip_indicies,
|
| 412 |
+
ignore_collisions,
|
| 413 |
+
action,
|
| 414 |
+
attention_coordinates,
|
| 415 |
+
) = _get_action(
|
| 416 |
+
obs_tp1,
|
| 417 |
+
obs_tm1,
|
| 418 |
+
rlbench_scene_bounds,
|
| 419 |
+
voxel_sizes,
|
| 420 |
+
bounds_offset,
|
| 421 |
+
rotation_resolution,
|
| 422 |
+
crop_augmentation,
|
| 423 |
+
)
|
| 424 |
+
gripper_pose = obs_tp1.gripper_pose
|
| 425 |
+
elif obs_tp1.is_bimanual and robot_name == "right":
|
| 426 |
+
(
|
| 427 |
+
trans_indicies,
|
| 428 |
+
rot_grip_indicies,
|
| 429 |
+
ignore_collisions,
|
| 430 |
+
action,
|
| 431 |
+
attention_coordinates,
|
| 432 |
+
) = _get_action(
|
| 433 |
+
obs_tp1.right,
|
| 434 |
+
obs_tm1.right,
|
| 435 |
+
rlbench_scene_bounds,
|
| 436 |
+
voxel_sizes,
|
| 437 |
+
bounds_offset,
|
| 438 |
+
rotation_resolution,
|
| 439 |
+
crop_augmentation,
|
| 440 |
+
)
|
| 441 |
+
gripper_pose = obs_tp1.right.gripper_pose
|
| 442 |
+
elif obs_tp1.is_bimanual and robot_name == "left":
|
| 443 |
+
(
|
| 444 |
+
trans_indicies,
|
| 445 |
+
rot_grip_indicies,
|
| 446 |
+
ignore_collisions,
|
| 447 |
+
action,
|
| 448 |
+
attention_coordinates,
|
| 449 |
+
) = _get_action(
|
| 450 |
+
obs_tp1.left,
|
| 451 |
+
obs_tm1.left,
|
| 452 |
+
rlbench_scene_bounds,
|
| 453 |
+
voxel_sizes,
|
| 454 |
+
bounds_offset,
|
| 455 |
+
rotation_resolution,
|
| 456 |
+
crop_augmentation,
|
| 457 |
+
)
|
| 458 |
+
gripper_pose = obs_tp1.left.gripper_pose
|
| 459 |
+
else:
|
| 460 |
+
logging.error("Invalid robot name %s", cfg.method.robot_name)
|
| 461 |
+
raise Exception("Invalid robot name.")
|
| 462 |
+
|
| 463 |
+
terminal = k == len(episode_keypoints) - 1
|
| 464 |
+
reward = float(terminal) * REWARD_SCALE if terminal else 0
|
| 465 |
+
|
| 466 |
+
obs_dict = observation_utils.extract_obs(
|
| 467 |
+
obs,
|
| 468 |
+
t=k,
|
| 469 |
+
prev_action=prev_action,
|
| 470 |
+
cameras=cameras,
|
| 471 |
+
episode_length=cfg.rlbench.episode_length,
|
| 472 |
+
robot_name=robot_name,
|
| 473 |
+
)
|
| 474 |
+
tokens = tokenize([description]).numpy()
|
| 475 |
+
token_tensor = torch.from_numpy(tokens).to(device)
|
| 476 |
+
sentence_emb, token_embs = clip_model.encode_text_with_embeddings(token_tensor)
|
| 477 |
+
obs_dict["lang_goal_emb"] = sentence_emb[0].float().detach().cpu().numpy()
|
| 478 |
+
obs_dict["lang_token_embs"] = token_embs[0].float().detach().cpu().numpy()
|
| 479 |
+
|
| 480 |
+
prev_action = np.copy(action)
|
| 481 |
+
|
| 482 |
+
others = {"demo": True}
|
| 483 |
+
if robot_name == "bimanual":
|
| 484 |
+
final_obs = {
|
| 485 |
+
"right_trans_action_indicies": right_trans_indicies,
|
| 486 |
+
"right_rot_grip_action_indicies": right_rot_grip_indicies,
|
| 487 |
+
"right_gripper_pose": obs_tp1.right.gripper_pose,
|
| 488 |
+
"left_trans_action_indicies": left_trans_indicies,
|
| 489 |
+
"left_rot_grip_action_indicies": left_rot_grip_indicies,
|
| 490 |
+
"left_gripper_pose": obs_tp1.left.gripper_pose,
|
| 491 |
+
"task": task,
|
| 492 |
+
"lang_goal": np.array([description], dtype=object),
|
| 493 |
+
}
|
| 494 |
+
else:
|
| 495 |
+
final_obs = {
|
| 496 |
+
"trans_action_indicies": trans_indicies,
|
| 497 |
+
"rot_grip_action_indicies": rot_grip_indicies,
|
| 498 |
+
"gripper_pose": gripper_pose,
|
| 499 |
+
"task": task,
|
| 500 |
+
"lang_goal": np.array([description], dtype=object),
|
| 501 |
+
}
|
| 502 |
+
|
| 503 |
+
others.update(final_obs)
|
| 504 |
+
others.update(obs_dict)
|
| 505 |
+
|
| 506 |
+
timeout = False
|
| 507 |
+
replay.add(action, reward, terminal, timeout, **others)
|
| 508 |
+
obs = obs_tp1
|
| 509 |
+
|
| 510 |
+
# final step
|
| 511 |
+
obs_dict_tp1 = observation_utils.extract_obs(
|
| 512 |
+
obs_tp1,
|
| 513 |
+
t=k + 1,
|
| 514 |
+
prev_action=prev_action,
|
| 515 |
+
cameras=cameras,
|
| 516 |
+
episode_length=cfg.rlbench.episode_length,
|
| 517 |
+
robot_name=cfg.method.robot_name,
|
| 518 |
+
)
|
| 519 |
+
obs_dict_tp1["lang_goal_emb"] = sentence_emb[0].float().detach().cpu().numpy()
|
| 520 |
+
obs_dict_tp1["lang_token_embs"] = token_embs[0].float().detach().cpu().numpy()
|
| 521 |
+
|
| 522 |
+
obs_dict_tp1.pop("wrist_world_to_cam", None)
|
| 523 |
+
obs_dict_tp1.update(final_obs)
|
| 524 |
+
replay.add_final(**obs_dict_tp1)
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def fill_replay(
|
| 528 |
+
cfg: DictConfig,
|
| 529 |
+
obs_config: ObservationConfig,
|
| 530 |
+
rank: int,
|
| 531 |
+
replay: ReplayBuffer,
|
| 532 |
+
task: str,
|
| 533 |
+
clip_model=None,
|
| 534 |
+
device="cpu",
|
| 535 |
+
):
|
| 536 |
+
num_demos = cfg.rlbench.demos
|
| 537 |
+
demo_augmentation = cfg.method.demo_augmentation
|
| 538 |
+
demo_augmentation_every_n = cfg.method.demo_augmentation_every_n
|
| 539 |
+
keypoint_method = cfg.method.keypoint_method
|
| 540 |
+
|
| 541 |
+
if clip_model is None:
|
| 542 |
+
model, _ = load_clip("RN50", jit=False, device=device)
|
| 543 |
+
clip_model = build_model(model.state_dict())
|
| 544 |
+
clip_model.to(device)
|
| 545 |
+
del model
|
| 546 |
+
|
| 547 |
+
logging.debug("Filling %s replay ..." % task)
|
| 548 |
+
for d_idx in range(num_demos):
|
| 549 |
+
# load demo from disk
|
| 550 |
+
demo = rlbench_utils.get_stored_demos(
|
| 551 |
+
amount=1,
|
| 552 |
+
image_paths=False,
|
| 553 |
+
dataset_root=cfg.rlbench.demo_path,
|
| 554 |
+
variation_number=-1,
|
| 555 |
+
task_name=task,
|
| 556 |
+
obs_config=obs_config,
|
| 557 |
+
random_selection=False,
|
| 558 |
+
from_episode_number=d_idx,
|
| 559 |
+
)[0]
|
| 560 |
+
|
| 561 |
+
descs = demo._observations[0].misc["descriptions"]
|
| 562 |
+
|
| 563 |
+
# extract keypoints (a.k.a keyframes)
|
| 564 |
+
episode_keypoints = demo_loading_utils.keypoint_discovery(
|
| 565 |
+
demo, method=keypoint_method
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
if rank == 0:
|
| 569 |
+
logging.info(
|
| 570 |
+
f"Loading Demo({d_idx}) - found {len(episode_keypoints)} keypoints - {task}"
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
for i in range(len(demo) - 1):
|
| 574 |
+
if not demo_augmentation and i > 0:
|
| 575 |
+
break
|
| 576 |
+
if i % demo_augmentation_every_n != 0:
|
| 577 |
+
continue
|
| 578 |
+
|
| 579 |
+
obs = demo[i]
|
| 580 |
+
desc = descs[0]
|
| 581 |
+
# if our starting point is past one of the keypoints, then remove it
|
| 582 |
+
while len(episode_keypoints) > 0 and i >= episode_keypoints[0]:
|
| 583 |
+
episode_keypoints = episode_keypoints[1:]
|
| 584 |
+
if len(episode_keypoints) == 0:
|
| 585 |
+
break
|
| 586 |
+
_add_keypoints_to_replay(
|
| 587 |
+
cfg,
|
| 588 |
+
task,
|
| 589 |
+
replay,
|
| 590 |
+
obs,
|
| 591 |
+
demo,
|
| 592 |
+
episode_keypoints,
|
| 593 |
+
description=desc,
|
| 594 |
+
clip_model=clip_model,
|
| 595 |
+
device=device,
|
| 596 |
+
)
|
| 597 |
+
logging.debug("Replay %s filled with demos." % task)
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
def fill_multi_task_replay(
|
| 601 |
+
cfg: DictConfig,
|
| 602 |
+
obs_config: ObservationConfig,
|
| 603 |
+
rank: int,
|
| 604 |
+
replay: ReplayBuffer,
|
| 605 |
+
tasks: List[str],
|
| 606 |
+
clip_model=None,
|
| 607 |
+
):
|
| 608 |
+
tasks = cfg.rlbench.tasks
|
| 609 |
+
|
| 610 |
+
manager = Manager()
|
| 611 |
+
store = manager.dict()
|
| 612 |
+
|
| 613 |
+
# create a MP dict for storing indicies
|
| 614 |
+
# TODO(mohit): this shouldn't be initialized here
|
| 615 |
+
del replay._task_idxs
|
| 616 |
+
task_idxs = manager.dict()
|
| 617 |
+
replay._task_idxs = task_idxs
|
| 618 |
+
replay._create_storage(store)
|
| 619 |
+
replay.add_count = Value("i", 0)
|
| 620 |
+
|
| 621 |
+
# fill replay buffer in parallel across tasks
|
| 622 |
+
max_parallel_processes = cfg.replay.max_parallel_processes
|
| 623 |
+
processes = []
|
| 624 |
+
n = np.arange(len(tasks))
|
| 625 |
+
split_n = utils.split_list(n, max_parallel_processes)
|
| 626 |
+
for split in split_n:
|
| 627 |
+
for e_idx, task_idx in enumerate(split):
|
| 628 |
+
task = tasks[int(task_idx)]
|
| 629 |
+
model_device = torch.device(
|
| 630 |
+
"cuda:%s" % (e_idx % torch.cuda.device_count())
|
| 631 |
+
if torch.cuda.is_available()
|
| 632 |
+
else "cpu"
|
| 633 |
+
)
|
| 634 |
+
p = Process(
|
| 635 |
+
target=fill_replay,
|
| 636 |
+
args=(cfg, obs_config, rank, replay, task, clip_model, model_device),
|
| 637 |
+
)
|
| 638 |
+
|
| 639 |
+
p.start()
|
| 640 |
+
processes.append(p)
|
| 641 |
+
|
| 642 |
+
for p in processes:
|
| 643 |
+
p.join()
|
external/peract_bimanual/agents/rvt/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
import agents.rvt.launch_utils
|
external/peract_bimanual/agents/rvt/launch_utils.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from omegaconf import DictConfig
|
| 7 |
+
|
| 8 |
+
from yarr.agents.agent import Agent
|
| 9 |
+
from yarr.agents.agent import ActResult
|
| 10 |
+
from yarr.agents.agent import Summary
|
| 11 |
+
from yarr.agents.agent import ScalarSummary
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 15 |
+
|
| 16 |
+
from helpers.preprocess_agent import PreprocessAgent
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from rvt.mvt.mvt import MVT
|
| 20 |
+
from rvt.models import rvt_agent
|
| 21 |
+
from rvt.utils.peract_utils import (
|
| 22 |
+
CAMERAS,
|
| 23 |
+
SCENE_BOUNDS,
|
| 24 |
+
IMAGE_SIZE,
|
| 25 |
+
DATA_FOLDER,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
import rvt.config as exp_cfg_mod
|
| 30 |
+
import rvt.models.rvt_agent as rvt_agent
|
| 31 |
+
import rvt.mvt.config as mvt_cfg_mod
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def create_agent(cfg: DictConfig):
|
| 35 |
+
exp_cfg = exp_cfg_mod.get_cfg_defaults()
|
| 36 |
+
exp_cfg.bs = cfg.replay.batch_size
|
| 37 |
+
exp_cfg.tasks = ",".join(cfg.rlbench.tasks)
|
| 38 |
+
|
| 39 |
+
exp_cfg.freeze()
|
| 40 |
+
|
| 41 |
+
mvt_cfg = mvt_cfg_mod.get_cfg_defaults()
|
| 42 |
+
mvt_cfg.proprio_dim = cfg.method.low_dim_size
|
| 43 |
+
mvt_cfg.freeze()
|
| 44 |
+
|
| 45 |
+
agent = RVTAgentWrapper(
|
| 46 |
+
cfg.framework.checkpoint_name_prefix, cfg.rlbench, mvt_cfg, exp_cfg
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
preprocess_agent = PreprocessAgent(pose_agent=agent)
|
| 50 |
+
return preprocess_agent
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class RVTAgentWrapper(Agent):
|
| 54 |
+
def __init__(self, checkpoint_name_prefix, rlbench_cfg, mvt_cfg, exp_cfg):
|
| 55 |
+
self._checkpoint_filename = f"{checkpoint_name_prefix}.pt"
|
| 56 |
+
self.rvt_agent = None
|
| 57 |
+
self.rlbench_cfg = rlbench_cfg
|
| 58 |
+
self.mvt_cfg = mvt_cfg
|
| 59 |
+
self.exp_cfg = exp_cfg
|
| 60 |
+
self._summaries = {}
|
| 61 |
+
|
| 62 |
+
def build(self, training: bool, device=None) -> None:
|
| 63 |
+
import torch
|
| 64 |
+
|
| 65 |
+
torch.cuda.set_device(device)
|
| 66 |
+
torch.cuda.empty_cache()
|
| 67 |
+
|
| 68 |
+
if isinstance(device, int):
|
| 69 |
+
device = f"cuda:{device}"
|
| 70 |
+
|
| 71 |
+
rvt = MVT(
|
| 72 |
+
renderer_device=device,
|
| 73 |
+
**self.mvt_cfg,
|
| 74 |
+
)
|
| 75 |
+
rvt = rvt.to(device)
|
| 76 |
+
|
| 77 |
+
if training:
|
| 78 |
+
rvt = DDP(rvt, device_ids=[device])
|
| 79 |
+
|
| 80 |
+
self.rvt_agent = rvt_agent.RVTAgent(
|
| 81 |
+
network=rvt,
|
| 82 |
+
# image_resolution=self.rlbench_cfg.camera_resolution,
|
| 83 |
+
add_lang=self.mvt_cfg.add_lang,
|
| 84 |
+
scene_bounds=self.rlbench_cfg.scene_bounds,
|
| 85 |
+
cameras=self.rlbench_cfg.cameras,
|
| 86 |
+
log_dir="/tmp/eval_run",
|
| 87 |
+
**self.exp_cfg.peract,
|
| 88 |
+
**self.exp_cfg.rvt,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
self.rvt_agent.build(training, device)
|
| 92 |
+
|
| 93 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 94 |
+
for k, v in replay_sample.items():
|
| 95 |
+
replay_sample[k] = v.unsqueeze(1)
|
| 96 |
+
# RVT is based on the PerAct's Colab version.
|
| 97 |
+
replay_sample["lang_goal_embs"] = replay_sample["lang_token_embs"]
|
| 98 |
+
replay_sample["tasks"] = self.exp_cfg.tasks.split(",")
|
| 99 |
+
|
| 100 |
+
update_dict = self.rvt_agent.update(step, replay_sample)
|
| 101 |
+
|
| 102 |
+
for key, val in self.rvt_agent.loss_log.items():
|
| 103 |
+
self._summaries[key] = np.mean(np.array(val))
|
| 104 |
+
|
| 105 |
+
return {
|
| 106 |
+
"total_losses": update_dict["total_loss"],
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
return result
|
| 110 |
+
|
| 111 |
+
def act(self, step: int, observation: dict, deterministic: bool) -> ActResult:
|
| 112 |
+
return self.rvt_agent.act(step, observation, deterministic)
|
| 113 |
+
|
| 114 |
+
def reset(self) -> None:
|
| 115 |
+
self.rvt_agent.reset()
|
| 116 |
+
|
| 117 |
+
def update_summaries(self) -> List[Summary]:
|
| 118 |
+
summaries = []
|
| 119 |
+
for k, v in self._summaries.items():
|
| 120 |
+
summaries.append(ScalarSummary(f"RVT/{k}", v))
|
| 121 |
+
return summaries
|
| 122 |
+
|
| 123 |
+
def act_summaries(self) -> List[Summary]:
|
| 124 |
+
return []
|
| 125 |
+
|
| 126 |
+
def load_weights(self, savedir: str) -> None:
|
| 127 |
+
"""
|
| 128 |
+
copied from RVT
|
| 129 |
+
"""
|
| 130 |
+
device = torch.device("cuda:0")
|
| 131 |
+
weight_file = os.path.join(savedir, self._checkpoint_filename)
|
| 132 |
+
state_dict = torch.load(weight_file, map_location=device)
|
| 133 |
+
|
| 134 |
+
model = self.rvt_agent._network
|
| 135 |
+
optimizer = self.rvt_agent._optimizer
|
| 136 |
+
lr_sched = self.rvt_agent._lr_sched
|
| 137 |
+
|
| 138 |
+
if isinstance(model, DDP):
|
| 139 |
+
model = model.module
|
| 140 |
+
|
| 141 |
+
model.load_state_dict(state_dict["model_state"])
|
| 142 |
+
optimizer.load_state_dict(state_dict["optimizer_state"])
|
| 143 |
+
lr_sched.load_state_dict(state_dict["lr_sched_state"])
|
| 144 |
+
|
| 145 |
+
return self.rvt_agent.load_clip()
|
| 146 |
+
|
| 147 |
+
def save_weights(self, savedir: str) -> None:
|
| 148 |
+
os.makedirs(savedir, exist_ok=True)
|
| 149 |
+
|
| 150 |
+
weight_file = os.path.join(savedir, self._checkpoint_filename)
|
| 151 |
+
|
| 152 |
+
model = self.rvt_agent._network
|
| 153 |
+
optimizer = self.rvt_agent._optimizer
|
| 154 |
+
lr_sched = self.rvt_agent._lr_sched
|
| 155 |
+
|
| 156 |
+
if isinstance(model, DDP):
|
| 157 |
+
model = model.module
|
| 158 |
+
|
| 159 |
+
model_state = model.state_dict()
|
| 160 |
+
|
| 161 |
+
torch.save(
|
| 162 |
+
{
|
| 163 |
+
"model_state": model_state,
|
| 164 |
+
"optimizer_state": optimizer.state_dict(),
|
| 165 |
+
"lr_sched_state": lr_sched.state_dict(),
|
| 166 |
+
},
|
| 167 |
+
weight_file,
|
| 168 |
+
)
|
external/peract_bimanual/conf/config.yaml
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ddp:
|
| 2 |
+
master_addr: "localhost"
|
| 3 |
+
master_port: "0"
|
| 4 |
+
num_devices: 1
|
| 5 |
+
|
| 6 |
+
rlbench:
|
| 7 |
+
task_name: "multi"
|
| 8 |
+
tasks: [open_drawer,slide_block_to_color_target]
|
| 9 |
+
demos: 100
|
| 10 |
+
demo_path: /my/demo/path
|
| 11 |
+
episode_length: 25
|
| 12 |
+
cameras: ["over_shoulder_left", "over_shoulder_right", "overhead", "wrist_right", "wrist_left", "front"]
|
| 13 |
+
camera_resolution: [128, 128]
|
| 14 |
+
scene_bounds: [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6]
|
| 15 |
+
include_lang_goal_in_obs: True
|
| 16 |
+
|
| 17 |
+
replay:
|
| 18 |
+
batch_size: 8
|
| 19 |
+
timesteps: 1
|
| 20 |
+
prioritisation: False
|
| 21 |
+
task_uniform: True # uniform sampling of tasks for multi-task buffers
|
| 22 |
+
use_disk: True
|
| 23 |
+
path: '/tmp/arm/replay' # only used when use_disk is True.
|
| 24 |
+
max_parallel_processes: 32
|
| 25 |
+
|
| 26 |
+
framework:
|
| 27 |
+
log_freq: 100
|
| 28 |
+
save_freq: 100
|
| 29 |
+
train_envs: 1
|
| 30 |
+
replay_ratio: ${replay.batch_size}
|
| 31 |
+
transitions_before_train: 200
|
| 32 |
+
tensorboard_logging: True
|
| 33 |
+
csv_logging: True
|
| 34 |
+
training_iterations: 40000
|
| 35 |
+
gpu: 0
|
| 36 |
+
env_gpu: 0
|
| 37 |
+
logdir: '/tmp/arm_test/'
|
| 38 |
+
logging_level: 20 # https://docs.python.org/3/library/logging.html#levels
|
| 39 |
+
seeds: 1
|
| 40 |
+
start_seed: 0
|
| 41 |
+
load_existing_weights: True
|
| 42 |
+
num_weights_to_keep: 60 # older checkpoints will be deleted chronologically
|
| 43 |
+
num_workers: 0
|
| 44 |
+
record_every_n: 5
|
| 45 |
+
checkpoint_name_prefix: "checkpoint"
|
| 46 |
+
|
| 47 |
+
defaults:
|
| 48 |
+
- method: PERACT_BC
|
| 49 |
+
|
| 50 |
+
hydra:
|
| 51 |
+
run:
|
| 52 |
+
dir: ${framework.logdir}/${rlbench.task_name}/${method.name}
|
external/peract_bimanual/conf/eval.yaml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- method: PERACT_BC
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
rlbench:
|
| 6 |
+
task_name: "multi"
|
| 7 |
+
tasks: [open_drawer,slide_block_to_color_target]
|
| 8 |
+
demo_path: /my/demo/path
|
| 9 |
+
episode_length: 25
|
| 10 |
+
cameras: ["over_shoulder_left", "over_shoulder_right", "overhead", "wrist_right", "wrist_left", "front"]
|
| 11 |
+
camera_resolution: [128, 128]
|
| 12 |
+
scene_bounds: [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6]
|
| 13 |
+
include_lang_goal_in_obs: True
|
| 14 |
+
time_in_state: True
|
| 15 |
+
headless: True
|
| 16 |
+
gripper_mode: 'Discrete'
|
| 17 |
+
arm_action_mode: 'EndEffectorPoseViaPlanning'
|
| 18 |
+
action_mode: 'MoveArmThenGripper'
|
| 19 |
+
|
| 20 |
+
framework:
|
| 21 |
+
tensorboard_logging: True
|
| 22 |
+
csv_logging: True
|
| 23 |
+
gpu: 0
|
| 24 |
+
logdir: '/tmp/arm_test/'
|
| 25 |
+
start_seed: 0
|
| 26 |
+
record_every_n: 5
|
| 27 |
+
|
| 28 |
+
eval_envs: 1
|
| 29 |
+
eval_from_eps_number: 0
|
| 30 |
+
eval_episodes: 5
|
| 31 |
+
eval_type: 'last' # or 'best', 'missing', or 'last'
|
| 32 |
+
eval_save_metrics: True
|
| 33 |
+
|
| 34 |
+
cinematic_recorder:
|
| 35 |
+
enabled: False
|
| 36 |
+
camera_resolution: [1280, 720]
|
| 37 |
+
fps: 30
|
| 38 |
+
rotate_speed: 0.005
|
| 39 |
+
save_path: '/tmp/videos/'
|
external/peract_bimanual/conf/hydra/job_logging/custom.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: 1
|
| 2 |
+
formatters:
|
| 3 |
+
simple:
|
| 4 |
+
format: '[%(levelname)s] - %(message)s'
|
| 5 |
+
handlers:
|
| 6 |
+
rich_console:
|
| 7 |
+
class: rich.logging.RichHandler
|
| 8 |
+
root:
|
| 9 |
+
handlers: [rich_console]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
disable_existing_loggers: false
|
external/peract_bimanual/conf/method/ACT_BC_LANG.yaml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
name: 'ACT_BC_LANG'
|
| 4 |
+
|
| 5 |
+
# Agent
|
| 6 |
+
robot_name: 'bimanual'
|
| 7 |
+
agent_type: 'bimanual'
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
train_demo_path: "/home/markus/rlbench_data_v2_128/train/"
|
| 11 |
+
|
| 12 |
+
activation: lrelu
|
| 13 |
+
lr: 1e-4
|
| 14 |
+
weight_decay: 0.000001
|
| 15 |
+
grad_clip: 0.1
|
| 16 |
+
demo_augmentation: True
|
| 17 |
+
demo_augmentation_every_n: 10
|
| 18 |
+
|
| 19 |
+
prev_action_horizon: 1
|
| 20 |
+
next_action_horizon: 10
|
| 21 |
+
|
| 22 |
+
# hyperparameters
|
| 23 |
+
lr_backbone: 1e-5
|
| 24 |
+
backbone: resnet18
|
| 25 |
+
dilation: False
|
| 26 |
+
position_embedding: sine
|
| 27 |
+
kl_weight: 100
|
| 28 |
+
chunk_size: ${method.next_action_horizon}
|
| 29 |
+
|
| 30 |
+
# transformer
|
| 31 |
+
input_dim: 16 # 7 revolute joints + 1 gripper joints
|
| 32 |
+
enc_layers: 4
|
| 33 |
+
dec_layers: 7
|
| 34 |
+
dim_feedforward: 3200
|
| 35 |
+
hidden_dim: 512
|
| 36 |
+
dropout: 0.1
|
| 37 |
+
nheads: 8
|
| 38 |
+
num_queries: ${method.next_action_horizon}
|
| 39 |
+
pre_norm: False
|
| 40 |
+
|
| 41 |
+
# unused
|
| 42 |
+
masks: False
|
| 43 |
+
|
| 44 |
+
# legacy
|
| 45 |
+
camera_names: ${rlbench.cameras}
|
| 46 |
+
|
| 47 |
+
# ..todo:: also set the following
|
| 48 |
+
|
| 49 |
+
+rlbench.episode_length: 400
|
| 50 |
+
+rlbench.arm_action_mode: JointPosition
|
| 51 |
+
+rlbench.action_mode: JointPositionActionMode
|
external/peract_bimanual/conf/method/ARM.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
name: 'ARM'
|
| 4 |
+
activation: lrelu
|
| 5 |
+
q_conf: True
|
| 6 |
+
alpha: 0.05
|
| 7 |
+
alpha_lr: 0.0001
|
| 8 |
+
alpha_auto_tune: False
|
| 9 |
+
next_best_pose_critic_lr: 0.0025
|
| 10 |
+
next_best_pose_actor_lr: 0.001
|
| 11 |
+
next_best_pose_critic_weight_decay: 0.00001
|
| 12 |
+
next_best_pose_actor_weight_decay: 0.00001
|
| 13 |
+
crop_shape: [16, 16]
|
| 14 |
+
next_best_pose_tau: 0.005
|
| 15 |
+
next_best_pose_critic_grad_clip: 5
|
| 16 |
+
next_best_pose_actor_grad_clip: 5
|
| 17 |
+
qattention_grad_clip: 5
|
| 18 |
+
qattention_tau: 0.005
|
| 19 |
+
qattention_lr: 0.0005
|
| 20 |
+
qattention_weight_decay: 0.00001
|
| 21 |
+
qattention_lambda_qreg: 0.0000001
|
| 22 |
+
|
| 23 |
+
demo_augmentation: True
|
| 24 |
+
demo_augmentation_every_n: 10
|
external/peract_bimanual/conf/method/BC_LANG.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
name: 'BC_LANG'
|
| 4 |
+
activation: lrelu
|
| 5 |
+
lr: 0.0005
|
| 6 |
+
weight_decay: 0.000001
|
| 7 |
+
grad_clip: 0.1
|
| 8 |
+
demo_augmentation: True
|
| 9 |
+
demo_augmentation_every_n: 10
|
external/peract_bimanual/conf/method/BIMANUAL_PERACT.yaml
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
name: 'BIMANUAL_PERACT'
|
| 4 |
+
|
| 5 |
+
# Agent
|
| 6 |
+
robot_name: 'bimanual'
|
| 7 |
+
agent_type: 'bimanual'
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Voxelization
|
| 11 |
+
image_crop_size: 64
|
| 12 |
+
bounds_offset: [0.15]
|
| 13 |
+
voxel_sizes: [100]
|
| 14 |
+
include_prev_layer: False
|
| 15 |
+
|
| 16 |
+
# Perceiver
|
| 17 |
+
num_latents: 2048
|
| 18 |
+
latent_dim: 512
|
| 19 |
+
transformer_depth: 6
|
| 20 |
+
transformer_iterations: 1
|
| 21 |
+
cross_heads: 1
|
| 22 |
+
cross_dim_head: 64
|
| 23 |
+
latent_heads: 8
|
| 24 |
+
latent_dim_head: 64
|
| 25 |
+
pos_encoding_with_lang: True
|
| 26 |
+
conv_downsample: True
|
| 27 |
+
lang_fusion_type: 'seq' # or 'concat'
|
| 28 |
+
voxel_patch_size: 5
|
| 29 |
+
voxel_patch_stride: 5
|
| 30 |
+
final_dim: 64
|
| 31 |
+
low_dim_size: 8
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Training
|
| 35 |
+
input_dropout: 0.1
|
| 36 |
+
attn_dropout: 0.1
|
| 37 |
+
decoder_dropout: 0.0
|
| 38 |
+
|
| 39 |
+
lr: 0.0005
|
| 40 |
+
lr_scheduler: False
|
| 41 |
+
num_warmup_steps: 3000
|
| 42 |
+
optimizer: 'lamb' # or 'adam'
|
| 43 |
+
|
| 44 |
+
lambda_weight_l2: 0.000001
|
| 45 |
+
trans_loss_weight: 1.0
|
| 46 |
+
rot_loss_weight: 1.0
|
| 47 |
+
grip_loss_weight: 1.0
|
| 48 |
+
collision_loss_weight: 1.0
|
| 49 |
+
rotation_resolution: 5
|
| 50 |
+
|
| 51 |
+
# Network
|
| 52 |
+
activation: lrelu
|
| 53 |
+
norm: None
|
| 54 |
+
|
| 55 |
+
# Augmentation
|
| 56 |
+
crop_augmentation: True
|
| 57 |
+
transform_augmentation:
|
| 58 |
+
apply_se3: True
|
| 59 |
+
aug_xyz: [0.125, 0.125, 0.125]
|
| 60 |
+
aug_rpy: [0.0, 0.0, 45.0]
|
| 61 |
+
aug_rot_resolution: ${method.rotation_resolution}
|
| 62 |
+
|
| 63 |
+
demo_augmentation: True
|
| 64 |
+
demo_augmentation_every_n: 10
|
| 65 |
+
|
| 66 |
+
# Ablations
|
| 67 |
+
no_skip_connection: False
|
| 68 |
+
no_perceiver: False
|
| 69 |
+
no_language: False
|
| 70 |
+
keypoint_method: 'heuristic'
|
external/peract_bimanual/conf/method/C2FARM_LINGUNET_BC.yaml
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
name: 'C2FARM_LINGUNET_BC'
|
| 4 |
+
|
| 5 |
+
# Voxelization
|
| 6 |
+
image_crop_size: 64
|
| 7 |
+
bounds_offset: [0.15]
|
| 8 |
+
voxel_sizes: [32, 32]
|
| 9 |
+
include_prev_layer: False
|
| 10 |
+
|
| 11 |
+
# Training
|
| 12 |
+
lr: 0.0005
|
| 13 |
+
lr_scheduler: False
|
| 14 |
+
num_warmup_steps: 10000
|
| 15 |
+
|
| 16 |
+
lambda_weight_l2: 0.000001
|
| 17 |
+
trans_loss_weight: 1.0
|
| 18 |
+
rot_loss_weight: 1.0
|
| 19 |
+
grip_loss_weight: 1.0
|
| 20 |
+
collision_loss_weight: 1.0
|
| 21 |
+
rotation_resolution: 5
|
| 22 |
+
|
| 23 |
+
# Network
|
| 24 |
+
activation: lrelu
|
| 25 |
+
norm: None
|
| 26 |
+
|
| 27 |
+
# Augmentation
|
| 28 |
+
crop_augmentation: True
|
| 29 |
+
transform_augmentation:
|
| 30 |
+
apply_se3: True
|
| 31 |
+
aug_xyz: [0.125, 0.125, 0.125]
|
| 32 |
+
aug_rpy: [0.0, 0.0, 45.0]
|
| 33 |
+
aug_rot_resolution: ${method.rotation_resolution}
|
| 34 |
+
|
| 35 |
+
demo_augmentation: True
|
| 36 |
+
demo_augmentation_every_n: 10
|
| 37 |
+
exploration_strategy: gaussian
|
| 38 |
+
|
| 39 |
+
# Ablations
|
| 40 |
+
keypoint_method: 'heuristic'
|
external/peract_bimanual/conf/method/PERACT_BC.yaml
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
name: 'PERACT_BC'
|
| 4 |
+
|
| 5 |
+
# Agent
|
| 6 |
+
agent_type: 'leader_follower'
|
| 7 |
+
robot_name: 'bimanual'
|
| 8 |
+
|
| 9 |
+
# Voxelization
|
| 10 |
+
image_crop_size: 64
|
| 11 |
+
bounds_offset: [0.15]
|
| 12 |
+
voxel_sizes: [100]
|
| 13 |
+
include_prev_layer: False
|
| 14 |
+
|
| 15 |
+
# Perceiver
|
| 16 |
+
num_latents: 2048
|
| 17 |
+
latent_dim: 512
|
| 18 |
+
transformer_depth: 6
|
| 19 |
+
transformer_iterations: 1
|
| 20 |
+
cross_heads: 1
|
| 21 |
+
cross_dim_head: 64
|
| 22 |
+
latent_heads: 8
|
| 23 |
+
latent_dim_head: 64
|
| 24 |
+
pos_encoding_with_lang: True
|
| 25 |
+
conv_downsample: True
|
| 26 |
+
lang_fusion_type: 'seq' # or 'concat'
|
| 27 |
+
voxel_patch_size: 5
|
| 28 |
+
voxel_patch_stride: 5
|
| 29 |
+
final_dim: 64
|
| 30 |
+
low_dim_size: 4
|
| 31 |
+
|
| 32 |
+
# Training
|
| 33 |
+
input_dropout: 0.1
|
| 34 |
+
attn_dropout: 0.1
|
| 35 |
+
decoder_dropout: 0.0
|
| 36 |
+
|
| 37 |
+
lr: 0.0005
|
| 38 |
+
lr_scheduler: False
|
| 39 |
+
num_warmup_steps: 3000
|
| 40 |
+
optimizer: 'lamb' # or 'adam'
|
| 41 |
+
|
| 42 |
+
lambda_weight_l2: 0.000001
|
| 43 |
+
trans_loss_weight: 1.0
|
| 44 |
+
rot_loss_weight: 1.0
|
| 45 |
+
grip_loss_weight: 1.0
|
| 46 |
+
collision_loss_weight: 1.0
|
| 47 |
+
rotation_resolution: 5
|
| 48 |
+
|
| 49 |
+
# Network
|
| 50 |
+
activation: lrelu
|
| 51 |
+
norm: None
|
| 52 |
+
|
| 53 |
+
# Augmentation
|
| 54 |
+
crop_augmentation: True
|
| 55 |
+
transform_augmentation:
|
| 56 |
+
apply_se3: True
|
| 57 |
+
aug_xyz: [0.125, 0.125, 0.125]
|
| 58 |
+
aug_rpy: [0.0, 0.0, 45.0]
|
| 59 |
+
aug_rot_resolution: ${method.rotation_resolution}
|
| 60 |
+
|
| 61 |
+
demo_augmentation: True
|
| 62 |
+
demo_augmentation_every_n: 10
|
| 63 |
+
|
| 64 |
+
# Ablations
|
| 65 |
+
no_skip_connection: False
|
| 66 |
+
no_perceiver: False
|
| 67 |
+
no_language: False
|
| 68 |
+
keypoint_method: 'heuristic'
|