lsnu commited on
Commit
0d89eb9
·
verified ·
1 Parent(s): d72206d

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. external/peract_bimanual/.gitignore +160 -0
  2. external/peract_bimanual/ARM_LICENSE +196 -0
  3. external/peract_bimanual/Dockerfile +68 -0
  4. external/peract_bimanual/INSTALLATION.md +87 -0
  5. external/peract_bimanual/agents/__init__.py +0 -0
  6. external/peract_bimanual/agents/act_bc_lang/__init__.py +1 -0
  7. external/peract_bimanual/agents/act_bc_lang/act_bc_lang_agent.py +381 -0
  8. external/peract_bimanual/agents/act_bc_lang/act_policy.py +135 -0
  9. external/peract_bimanual/agents/act_bc_lang/detr/__init__.py +0 -0
  10. external/peract_bimanual/agents/act_bc_lang/detr/build.py +41 -0
  11. external/peract_bimanual/agents/act_bc_lang/detr/util/__init__.py +1 -0
  12. external/peract_bimanual/agents/act_bc_lang/launch_utils.py +456 -0
  13. external/peract_bimanual/agents/agent_factory.py +111 -0
  14. external/peract_bimanual/agents/arm/launch_utils.py +441 -0
  15. external/peract_bimanual/agents/arm/next_best_pose_agent.py +526 -0
  16. external/peract_bimanual/agents/arm/qattention_agent.py +247 -0
  17. external/peract_bimanual/agents/baselines/__init__.py +0 -0
  18. external/peract_bimanual/agents/baselines/bc_lang/__init__.py +1 -0
  19. external/peract_bimanual/agents/baselines/bc_lang/bc_lang_agent.py +148 -0
  20. external/peract_bimanual/agents/baselines/bc_lang/launch_utils.py +368 -0
  21. external/peract_bimanual/agents/baselines/vit_bc_lang/__init__.py +1 -0
  22. external/peract_bimanual/agents/baselines/vit_bc_lang/launch_utils.py +372 -0
  23. external/peract_bimanual/agents/baselines/vit_bc_lang/vit_bc_lang_agent.py +148 -0
  24. external/peract_bimanual/agents/bimanual_peract/__init__.py +1 -0
  25. external/peract_bimanual/agents/bimanual_peract/launch_utils.py +93 -0
  26. external/peract_bimanual/agents/bimanual_peract/perceiver_lang_io.py +549 -0
  27. external/peract_bimanual/agents/bimanual_peract/qattention_peract_bc_agent.py +1063 -0
  28. external/peract_bimanual/agents/bimanual_peract/qattention_stack_agent.py +202 -0
  29. external/peract_bimanual/agents/c2farm_lingunet_bc/__init__.py +1 -0
  30. external/peract_bimanual/agents/c2farm_lingunet_bc/launch_utils.py +519 -0
  31. external/peract_bimanual/agents/c2farm_lingunet_bc/networks.py +301 -0
  32. external/peract_bimanual/agents/c2farm_lingunet_bc/qattention_lingunet_bc_agent.py +790 -0
  33. external/peract_bimanual/agents/c2farm_lingunet_bc/qattention_stack_agent.py +136 -0
  34. external/peract_bimanual/agents/peract_bc/__init__.py +1 -0
  35. external/peract_bimanual/agents/peract_bc/launch_utils.py +94 -0
  36. external/peract_bimanual/agents/peract_bc/perceiver_lang_io.py +426 -0
  37. external/peract_bimanual/agents/peract_bc/qattention_peract_bc_agent.py +808 -0
  38. external/peract_bimanual/agents/peract_bc/qattention_stack_agent.py +132 -0
  39. external/peract_bimanual/agents/replay_utils.py +643 -0
  40. external/peract_bimanual/agents/rvt/__init__.py +1 -0
  41. external/peract_bimanual/agents/rvt/launch_utils.py +168 -0
  42. external/peract_bimanual/conf/config.yaml +52 -0
  43. external/peract_bimanual/conf/eval.yaml +39 -0
  44. external/peract_bimanual/conf/hydra/job_logging/custom.yaml +12 -0
  45. external/peract_bimanual/conf/method/ACT_BC_LANG.yaml +51 -0
  46. external/peract_bimanual/conf/method/ARM.yaml +24 -0
  47. external/peract_bimanual/conf/method/BC_LANG.yaml +9 -0
  48. external/peract_bimanual/conf/method/BIMANUAL_PERACT.yaml +70 -0
  49. external/peract_bimanual/conf/method/C2FARM_LINGUNET_BC.yaml +40 -0
  50. external/peract_bimanual/conf/method/PERACT_BC.yaml +68 -0
external/peract_bimanual/.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
external/peract_bimanual/ARM_LICENSE ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Q-attention: Enabling Efficient Learning for Vision-based Robotic Manipulation
2
+
3
+ LICENCE AGREEMENT
4
+
5
+ WE (Imperial College of Science, Technology and Medicine, (“Imperial College London”))
6
+ ARE WILLING TO LICENSE THIS SOFTWARE TO YOU (a licensee “You”) ONLY ON THE
7
+ CONDITION THAT YOU ACCEPT ALL OF THE TERMS CONTAINED IN THE
8
+ FOLLOWING AGREEMENT. PLEASE READ THE AGREEMENT CAREFULLY BEFORE
9
+ DOWNLOADING THE SOFTWARE. BY EXERCISING THE OPTION TO DOWNLOAD
10
+ THE SOFTWARE YOU AGREE TO BE BOUND BY THE TERMS OF THE AGREEMENT.
11
+ SOFTWARE LICENCE AGREEMENT (EXCLUDING BSD COMPONENTS)
12
+
13
+ 1.This Agreement pertains to a worldwide, non-exclusive, temporary, fully paid-up, royalty
14
+ free, non-transferable, non-sub- licensable licence (the “Licence”) to use the Q-attention
15
+ source code, including any modification, part or derivative (the “Software”).
16
+ Ownership and Licence. Your rights to use and download the Software onto your computer,
17
+ and all other copies that You are authorised to make, are specified in this Agreement.
18
+ However, we (or our licensors) retain all rights, including but not limited to all copyright and
19
+ other intellectual property rights anywhere in the world, in the Software not expressly
20
+ granted to You in this Agreement.
21
+
22
+ 2. Permitted use of the Licence:
23
+
24
+ (a) You may download and install the Software onto one computer or server for use in
25
+ accordance with Clause 2(b) of this Agreement provided that You ensure that the Software is
26
+ not accessible by other users unless they have themselves accepted the terms of this licence
27
+ agreement.
28
+
29
+ (b) You may use the Software solely for non-commercial, internal or academic research
30
+ purposes and only in accordance with the terms of this Agreement. You may not use the
31
+ Software for commercial purposes, including but not limited to (1) integration of all or part of
32
+ the source code or the Software into a product for sale or licence by or on behalf of You to
33
+ third parties or (2) use of the Software or any derivative of it for research to develop software
34
+ products for sale or licence to a third party or (3) use of the Software or any derivative of it
35
+ for research to develop non-software products for sale or licence to a third party, or (4) use of
36
+ the Software to provide any service to an external organisation for which payment is
37
+ received.
38
+
39
+ Should You wish to use the Software for commercial purposes, You shall
40
+ email researchcontracts.engineering@imperial.ac.uk .
41
+
42
+ (c) Right to Copy. You may copy the Software for back-up and archival purposes, provided
43
+ that each copy is kept in your possession and provided You reproduce our copyright notice
44
+ (set out in Schedule 1) on each copy.
45
+
46
+ (d) Transfer and sub-licensing. You may not rent, lend, or lease the Software and You may
47
+ not transmit, transfer or sub-license this licence to use the Software or any of your rights or
48
+ obligations under this Agreement to another party.
49
+
50
+ (e) Identity of Licensee. The licence granted herein is personal to You. You shall not permit
51
+ any third party to access, modify or otherwise use the Software nor shall You access modify
52
+ or otherwise use the Software on behalf of any third party. If You wish to obtain a licence for
53
+ mutiple users or a site licence for the Software please contact us
54
+ at researchcontracts.engineering@imperial.ac.uk .
55
+
56
+ (f) Publications and presentations. You may make public, results or data obtained from,
57
+ dependent on or arising from research carried out using the Software, provided that any such
58
+ presentation or publication identifies the Software as the source of the results or the data,
59
+ including the Copyright Notice given in each element of the Software, and stating that the
60
+ Software has been made available for use by You under licence from Imperial College London
61
+ and You provide a copy of any such publication to Imperial College London.
62
+
63
+ 3. Prohibited Uses. You may not, without written permission from us
64
+ at researchcontracts.engineering@imperial.ac.uk :
65
+
66
+ (a) Use, copy, modify, merge, or transfer copies of the Software or any documentation
67
+ provided by us which relates to the Software except as provided in this Agreement;
68
+
69
+ (b) Use any back-up or archival copies of the Software (or allow anyone else to use such
70
+ copies) for any purpose other than to replace the original copy in the event it is destroyed or
71
+ becomes defective; or
72
+
73
+ (c) Disassemble, decompile or "unlock", reverse translate, or in any manner decode the
74
+ Software for any reason.
75
+
76
+ 4. Warranty Disclaimer
77
+
78
+ (a) Disclaimer. The Software has been developed for research purposes only. You
79
+ acknowledge that we are providing the Software to You under this licence agreement free of
80
+ charge and on condition that the disclaimer set out below shall apply. We do not represent or
81
+ warrant that the Software as to: (i) the quality, accuracy or reliability of the Software; (ii) the
82
+ suitability of the Software for any particular use or for use under any specific conditions; and
83
+ (iii) whether use of the Software will infringe third-party rights.
84
+ You acknowledge that You have reviewed and evaluated the Software to determine that it
85
+ meets your needs and that You assume all responsibility and liability for determining the
86
+ suitability of the Software as fit for your particular purposes and requirements. Subject to
87
+ Clause 4(b), we exclude and expressly disclaim all express and implied representations,
88
+ warranties, conditions and terms not stated herein (including the implied conditions or
89
+ warranties of satisfactory quality, merchantable quality, merchantability and fitness for
90
+ purpose).
91
+
92
+ (b) Savings. Some jurisdictions may imply warranties, conditions or terms or impose
93
+ obligations upon us which cannot, in whole or in part, be excluded, restricted or modified or
94
+ otherwise do not allow the exclusion of implied warranties, conditions or terms, in which
95
+ case the above warranty disclaimer and exclusion will only apply to You to the extent
96
+ permitted in the relevant jurisdiction and does not in any event exclude any implied
97
+ warranties, conditions or terms which may not under applicable law be excluded.
98
+
99
+ (c) Imperial College London disclaims all responsibility for the use which is made of the
100
+ Software and any liability for the outcomes arising from using the Software.
101
+
102
+ 5. Limitation of Liability
103
+
104
+ (a) You acknowledge that we are providing the Software to You under this licence agreement
105
+ free of charge and on condition that the limitation of liability set out below shall apply.
106
+ Accordingly, subject to Clause 5(b), we exclude all liability whether in contract, tort,
107
+ negligence or otherwise, in respect of the Software and/or any related documentation
108
+ provided to You by us including, but not limited to, liability for loss or corruption of data,
109
+ loss of contracts, loss of income, loss of profits, loss of cover and any consequential or indirect
110
+ loss or damage of any kind arising out of or in connection with this licence agreement,
111
+ however caused. This exclusion shall apply even if we have been advised of the possibility of
112
+ such loss or damage.
113
+
114
+ (b) You agree to indemnify Imperial College London and hold it harmless from and against
115
+ any and all claims, damages and liabilities asserted by third parties (including claims for
116
+ negligence) which arise directly or indirectly from the use of the Software or any derivative
117
+ of it or the sale of any products based on the Software. You undertake to make no liability
118
+ claim against any employee, student, agent or appointee of Imperial College London, in
119
+ connection with this Licence or the Software.
120
+
121
+ (c) Nothing in this Agreement shall have the effect of excluding or limiting our statutory
122
+ liability.
123
+
124
+ (d) Some jurisdictions do not allow these limitations or exclusions either wholly or in part,
125
+ and, to that extent, they may not apply to you. Nothing in this licence agreement will affect
126
+ your statutory rights or other relevant statutory provisions which cannot be excluded,
127
+ restricted or modified, and its terms and conditions must be read and construed subject to any
128
+ such statutory rights and/or provisions.
129
+
130
+ 6. Confidentiality. You agree not to disclose any confidential information provided to You by
131
+ us pursuant to this Agreement to any third party without our prior written consent. The
132
+ obligations in this Clause 6 shall survive the termination of this Agreement for any reason.
133
+
134
+ 7. Termination.
135
+
136
+ (a) We may terminate this licence agreement and your right to use the Software at any time
137
+ with immediate effect upon written notice to You.
138
+
139
+ (b) This licence agreement and your right to use the Software automatically terminate if You:
140
+ (i) fail to comply with any provisions of this Agreement; or
141
+ (ii) destroy the copies of the Software in your possession, or voluntarily return the Software
142
+ to us.
143
+
144
+ (c) Upon termination You will destroy all copies of the Software.
145
+
146
+ (d) Otherwise, the restrictions on your rights to use the Software will expire 10 (ten) years
147
+ after first use of the Software under this licence agreement.
148
+
149
+ 8. Miscellaneous Provisions.
150
+
151
+ (a) This Agreement will be governed by and construed in accordance with the substantive
152
+ laws of England and Wales whose courts shall have exclusive jurisdiction over all disputes
153
+ which may arise between us.
154
+
155
+ (b) This is the entire agreement between us relating to the Software, and supersedes any prior
156
+ purchase order, communications, advertising or representations concerning the Software.
157
+
158
+ (c) No change or modification of this Agreement will be valid unless it is in writing, and is
159
+ signed by us.
160
+
161
+ (d) The unenforceability or invalidity of any part of this Agreement will not affect the
162
+ enforceability or validity of the remaining parts.
163
+
164
+ BSD Elements of the Software
165
+
166
+ For BSD elements of the Software, the following terms shall apply:
167
+
168
+ Copyright as indicated in the header of the individual element of the Software.
169
+
170
+ All rights reserved.
171
+
172
+ Redistribution and use in source and binary forms, with or without modification, are
173
+ permitted provided that the following conditions are met:
174
+
175
+ 1. Redistributions of source code must retain the above copyright notice, this list of
176
+ conditions and the following disclaimer.
177
+
178
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of
179
+ conditions and the following disclaimer in the documentation and/or other materials
180
+ provided with the distribution.
181
+
182
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to
183
+ endorse or promote products derived from this software without specific prior written
184
+ permission.
185
+
186
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
187
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
188
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
189
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
190
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
191
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
192
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
193
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
194
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
195
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
196
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
external/peract_bimanual/Dockerfile ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use the NVIDIA base image for CUDA
2
+ FROM nvcr.io/nvidia/cuda:12.3.2-cudnn9-devel-ubuntu20.04
3
+
4
+ # Set environment variables
5
+ ENV COPPELIASIM_ROOT=${HOME}/code/coppelia_sim
6
+ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT
7
+ ENV QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT
8
+ ENV DEBIAN_FRONTEND=noninteractive
9
+ ENV TZ=America/Los_Angeles
10
+ ENV CONDA_ALWAYS_YES=true
11
+ ENV FORCE_CUDA=1
12
+ ENV TORCH_CUDA_ARCH_LIST="5.0;5.2;5.3;6.0;6.1;6.2;7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0+PTX"
13
+
14
+ # Create necessary directories
15
+ RUN mkdir -p ${HOME}/code
16
+
17
+ # Install dependencies and essential tools
18
+ RUN apt-get update && apt-get install -y \
19
+ tzdata sudo curl git vim htop tar bzip2 pigz rsync less mlocate \
20
+ build-essential gdb ca-certificates stress sysstat itop \
21
+ xauth xvfb mesa-utils mesa-utils-extra x11-apps \
22
+ xorg xserver-xorg-core libxv1 x11-xserver-utils libxcb-randr0-dev \
23
+ libxrender-dev libxkbcommon-dev libxkbcommon-x11-0 libavcodec-dev \
24
+ libavformat-dev libswscale-dev '^libxcb.*-dev' libx11-xcb-dev \
25
+ libglu1-mesa-dev libxrender-dev libxi-dev libxkbcommon-dev \
26
+ libxkbcommon-x11-dev libegl1-mesa libarchive-dev libarchive13 \
27
+ && rm -rf /var/lib/apt/lists/*
28
+
29
+ # Install VirtualGL
30
+ RUN TEMP_DIR=$(mktemp -d -p /) && cd $TEMP_DIR && \
31
+ curl -L -o virtualgl.deb https://sourceforge.net/projects/virtualgl/files/3.1/virtualgl_3.1_amd64.deb/download && \
32
+ dpkg -i virtualgl.deb && \
33
+ /opt/VirtualGL/bin/vglserver_config +glx +egl +s +f +t && \
34
+ rm -rf $TEMP_DIR
35
+
36
+ RUN mkdir ${HOME}/.ssh && chmod -R 700 ${HOME}/.ssh
37
+
38
+ RUN ssh-keyscan github.com >> ${HOME}/.ssh/known_hosts
39
+
40
+ RUN curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
41
+ RUN bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda
42
+ RUN export PATH=/opt/conda/bin:${PATH}
43
+
44
+ # Install code and dependencies
45
+
46
+ WORKDIR ${HOME}/code
47
+
48
+ RUN eval "$(/opt/conda/bin/conda shell.bash hook)" && conda init bash
49
+ RUN eval "$(/opt/conda/bin/conda shell.bash hook)" && conda install mamba -c conda-forge
50
+ #RUN conda config --set auto_activate_base false
51
+
52
+
53
+ RUN git clone https://github.com/markusgrotz/peract_bimanual.git ${HOME}/code/peract_bimanual
54
+
55
+
56
+ RUN eval "$(/opt/conda/bin/conda shell.bash hook)" && ${HOME}/code/peract_bimanual/scripts/install_dependencies.sh
57
+
58
+
59
+ # Activate the environment by default
60
+ RUN echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
61
+ echo "conda activate rlbench" >> ~/.bashrc
62
+
63
+
64
+ WORKDIR /root/code/peract_bimanual
65
+
66
+ # Default command
67
+ CMD ["/bin/bash"]
68
+
external/peract_bimanual/INSTALLATION.md ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # INSTALLATION
2
+
3
+ To install the dependencies execute the `scripts/install_dependencies.sh`
4
+
5
+ ```bash
6
+ scripts/install_conda.sh # Skip this step if you already have conda installed.
7
+ scripts/install_dependencies.sh
8
+ ```
9
+
10
+ Please see the [README](README.md) for a quick start instruction.
11
+
12
+
13
+ Alternatively, you can follow the detailed instructions to setup the software from scratch
14
+
15
+ #### 2. PyRep and Coppelia Simulator
16
+
17
+ Follow instructions from my [PyRep fork](https://github.com/markusgrotz/PyRep); reproduced here for convenience:
18
+
19
+ PyRep requires version **4.1** of CoppeliaSim. Download:
20
+ - [Ubuntu 20.04](https://www.coppeliarobotics.com/files/V4_1_0/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz)
21
+
22
+ Once you have downloaded CoppeliaSim, you can pull PyRep from git:
23
+
24
+ ```bash
25
+ cd <install_dir>
26
+ git clone https://github.com/markusgrotz/PyRep.git
27
+ cd PyRep
28
+ ```
29
+
30
+ Add the following to your *~/.bashrc* file: (__NOTE__: the 'EDIT ME' in the first line)
31
+
32
+ ```bash
33
+ export COPPELIASIM_ROOT=<EDIT ME>/PATH/TO/COPPELIASIM/INSTALL/DIR
34
+ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT
35
+ export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT
36
+ ```
37
+
38
+ Remember to source your bashrc (`source ~/.bashrc`) or
39
+ zshrc (`source ~/.zshrc`) after this.
40
+
41
+ **Warning**: CoppeliaSim might cause conflicts with ROS workspaces.
42
+
43
+ Finally install the python library:
44
+
45
+ ```bash
46
+ pip install -e .
47
+ ```
48
+
49
+ You should be good to go!
50
+ You could try running one of the examples in the *examples/* folder.
51
+
52
+ #### 3. RLBench
53
+
54
+ PerAct uses my [RLBench fork](https://github.com/markusgrotz/RLBench/tree/peract).
55
+
56
+ ```bash
57
+ cd <install_dir>
58
+ git clone https://github.com/markusgrotz/RLBench.git
59
+
60
+ cd RLBench
61
+ pip install -e .
62
+ ```
63
+
64
+ For [running in headless mode](https://github.com/MohitShridhar/RLBench/tree/peract#running-headless), tasks setups, and other issues, please refer to the [official repo](https://github.com/stepjam/RLBench).
65
+
66
+ #### 4. YARR
67
+
68
+ PerAct uses my [YARR fork](https://github.com/markusgrotz/YARR/).
69
+
70
+ ```bash
71
+ cd <install_dir>
72
+ git clone https://github.com/markusgrotz/YARR.git
73
+
74
+ cd YARR
75
+ pip install -e .
76
+ ```
77
+
78
+
79
+
80
+ #### RVT baseline
81
+
82
+ pip install git+https://github.com/NVlabs/RVT.git
83
+ pip install -e .
84
+
85
+
86
+
87
+
external/peract_bimanual/agents/__init__.py ADDED
File without changes
external/peract_bimanual/agents/act_bc_lang/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ import agents.act_bc_lang.launch_utils
external/peract_bimanual/agents/act_bc_lang/act_bc_lang_agent.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ from functools import lru_cache
4
+ import pickle
5
+ import os
6
+ from typing import List
7
+ import numpy as np
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from yarr.agents.agent import Agent, Summary, ActResult, ScalarSummary, HistogramSummary
13
+
14
+ from helpers import utils
15
+ from helpers.utils import stack_on_channel
16
+
17
+ from helpers.clip.core.clip import build_model, load_clip
18
+
19
+ NAME = "ActBCLangAgent"
20
+
21
+
22
+ class ActBCLangAgent(Agent):
23
+ def __init__(
24
+ self,
25
+ actor_network: nn.Module,
26
+ camera_names: List[str],
27
+ lr: float = 0.01,
28
+ weight_decay: float = 1e-5,
29
+ grad_clip: float = 20.0,
30
+ episode_length: int = 400,
31
+ train_demo_path=None,
32
+ task_name=None,
33
+ ):
34
+ self._camera_names = camera_names
35
+ self._actor = actor_network
36
+ self._lr = lr
37
+ self._weight_decay = weight_decay
38
+ self._grad_clip = grad_clip
39
+ self._episode_length = episode_length
40
+ self.train_demo_path = train_demo_path
41
+ self.task_name = task_name
42
+
43
+ def build(self, training: bool, device: torch.device = None):
44
+ if device is None:
45
+ device = torch.device("cpu")
46
+ self._actor = self._actor.to(device).train(training)
47
+ self._actor_optimizer = self._actor.configure_optimizers()
48
+
49
+ self._device = device
50
+
51
+ def reset(self):
52
+ super(ActBCLangAgent, self).reset()
53
+
54
+ self._timestep = 0
55
+ # .. input_dim = input_dim * 2 for bimanual
56
+ self._all_time_actions = torch.zeros(
57
+ [
58
+ self._episode_length,
59
+ self._episode_length + self._actor.model.num_queries,
60
+ self._actor.model.input_dim,
61
+ ]
62
+ ).to(self._device)
63
+ self._all_actions = None
64
+
65
+ def _grad_step(self, loss, opt, model_params=None, clip=None):
66
+ opt.zero_grad()
67
+ loss.backward()
68
+ if clip is not None and model_params is not None:
69
+ nn.utils.clip_grad_value_(model_params, clip)
70
+ opt.step()
71
+
72
+ @lru_cache()
73
+ def train_stats(self):
74
+ right_joint_positions = []
75
+ left_joint_positions = []
76
+
77
+ right_gripper_positions = []
78
+ left_gripper_positions = []
79
+
80
+ episodes_dir = (
81
+ f"{self.train_demo_path}/{self.task_name}/all_variations/episodes/"
82
+ )
83
+
84
+ for episode in os.listdir(episodes_dir):
85
+ with open(
86
+ os.path.join(episodes_dir, episode, "low_dim_obs.pkl"), "br"
87
+ ) as f:
88
+ d = pickle.load(f)
89
+
90
+ for o in d:
91
+ right_joint_positions.append(o.right.joint_positions)
92
+ left_joint_positions.append(o.left.joint_positions)
93
+
94
+ right_gripper_positions.append([o.right.gripper_joint_positions[0]])
95
+ left_gripper_positions.append([o.left.gripper_joint_positions[0]])
96
+
97
+ right_joint_positions = np.asarray(right_joint_positions, dtype=np.float32)
98
+ left_joint_positions = np.asarray(left_joint_positions, dtype=np.float32)
99
+
100
+ right_gripper_positions = np.asarray(right_gripper_positions, dtype=np.float32)
101
+ left_gripper_positions = np.asarray(left_gripper_positions, dtype=np.float32)
102
+
103
+ stats = {
104
+ "right_joints_mean": right_joint_positions.mean(axis=0),
105
+ "right_joints_std": right_joint_positions.std(axis=0),
106
+ "left_joints_mean": left_joint_positions.mean(axis=0),
107
+ "left_joints_std": left_joint_positions.std(axis=0),
108
+ "right_gripper_mean": right_gripper_positions.mean(axis=0),
109
+ "right_gripper_std": right_gripper_positions.std(axis=0),
110
+ "left_gripper_mean": left_gripper_positions.mean(axis=0),
111
+ "left_gripper_std": left_gripper_positions.std(axis=0),
112
+ }
113
+
114
+ return {k: torch.from_numpy(v).to(self._device) for k, v in stats.items()}
115
+
116
+ def normalize_z(self, data, mean, std):
117
+ return (data - mean) / std
118
+
119
+ def unnormalize_z(self, data, mean, std):
120
+ return data * std + mean
121
+
122
+ def preprocess_qpos(self, observation: dict):
123
+ stats = self.train_stats()
124
+
125
+ right_qrev = self.normalize_z(
126
+ observation["right_joint_positions"][:, 0],
127
+ stats["right_joints_mean"],
128
+ stats["right_joints_std"],
129
+ )
130
+ right_qgripper = self.normalize_z(
131
+ observation["right_gripper_joint_positions"][:, 0],
132
+ stats["right_gripper_mean"],
133
+ stats["right_gripper_std"],
134
+ )
135
+ left_qrev = self.normalize_z(
136
+ observation["left_joint_positions"][:, 0],
137
+ stats["left_joints_mean"],
138
+ stats["left_joints_std"],
139
+ )
140
+ left_qgripper = self.normalize_z(
141
+ observation["left_gripper_joint_positions"][:, 0],
142
+ stats["left_gripper_mean"],
143
+ stats["left_gripper_std"],
144
+ )
145
+ qpos = torch.cat(
146
+ [
147
+ right_qrev,
148
+ right_qgripper[:, 0].unsqueeze(-1),
149
+ left_qrev,
150
+ left_qgripper[:, 0].unsqueeze(-1),
151
+ ],
152
+ dim=-1,
153
+ )
154
+
155
+ return qpos
156
+
157
+ def preprocess_action(self, replay_sample: dict):
158
+ stats = self.train_stats()
159
+
160
+ right_qrev = self.normalize_z(
161
+ replay_sample["right_prev_joint_positions"][:, 0],
162
+ stats["right_joints_mean"],
163
+ stats["right_joints_std"],
164
+ )
165
+ right_qgripper = self.normalize_z(
166
+ replay_sample["right_prev_gripper_joint_positions"][:, 0],
167
+ stats["right_gripper_mean"],
168
+ stats["right_gripper_std"],
169
+ )
170
+ left_qrev = self.normalize_z(
171
+ replay_sample["left_prev_joint_positions"][:, 0],
172
+ stats["left_joints_mean"],
173
+ stats["left_joints_std"],
174
+ )
175
+ left_qgripper = self.normalize_z(
176
+ replay_sample["left_prev_gripper_joint_positions"][:, 0],
177
+ stats["left_gripper_mean"],
178
+ stats["left_gripper_std"],
179
+ )
180
+ qpos = torch.cat(
181
+ [
182
+ right_qrev,
183
+ right_qgripper[:, 0].unsqueeze(-1),
184
+ left_qrev,
185
+ left_qgripper[:, 0].unsqueeze(-1),
186
+ ],
187
+ dim=-1,
188
+ )
189
+
190
+ right_action_rev = self.normalize_z(
191
+ replay_sample["right_next_joint_positions"],
192
+ stats["right_joints_mean"],
193
+ stats["right_joints_std"],
194
+ )
195
+ right_action_gripper = self.normalize_z(
196
+ replay_sample["right_next_gripper_joint_positions"],
197
+ stats["right_gripper_mean"],
198
+ stats["right_gripper_std"],
199
+ )
200
+ left_action_rev = self.normalize_z(
201
+ replay_sample["left_next_joint_positions"],
202
+ stats["left_joints_mean"],
203
+ stats["left_joints_std"],
204
+ )
205
+ left_action_gripper = self.normalize_z(
206
+ replay_sample["left_next_gripper_joint_positions"],
207
+ stats["left_gripper_mean"],
208
+ stats["left_gripper_std"],
209
+ )
210
+ action_seq = torch.cat(
211
+ [
212
+ right_action_rev,
213
+ right_action_gripper[:, :, 0].unsqueeze(-1),
214
+ left_action_rev,
215
+ left_action_gripper[:, :, 0].unsqueeze(-1),
216
+ ],
217
+ dim=-1,
218
+ )
219
+
220
+ return qpos, action_seq
221
+
222
+ def preprocess_images(self, replay_sample: dict):
223
+ stacked_rgb = []
224
+ stacked_point_cloud = []
225
+
226
+ for camera in self._camera_names:
227
+ rgb = replay_sample["%s_rgb" % camera]
228
+ rgb = rgb if rgb.dim() == 4 else rgb[:, 0]
229
+ stacked_rgb.append(rgb)
230
+
231
+ point_cloud = replay_sample["%s_point_cloud" % camera]
232
+ point_cloud = point_cloud if point_cloud.dim() == 4 else point_cloud[:, 0]
233
+ stacked_point_cloud.append(point_cloud)
234
+
235
+ stacked_rgb = torch.stack(stacked_rgb, dim=1)
236
+ stacked_point_cloud = torch.stack(stacked_point_cloud, dim=1)
237
+
238
+ return stacked_rgb, stacked_point_cloud
239
+
240
+ def update(self, step: int, replay_sample: dict) -> dict:
241
+ lang_goal_emb = replay_sample["lang_goal_emb"] # TODO use language
242
+ robot_state = replay_sample["low_dim_state"]
243
+
244
+ # preprocess input
245
+ qpos, action_seq = self.preprocess_action(replay_sample)
246
+ stacked_rgb, stacked_point_cloud = self.preprocess_images(replay_sample)
247
+ is_pad = replay_sample["is_pad"].bool()
248
+
249
+ # forward pass
250
+ loss_dict = self._actor(qpos, stacked_rgb, action_seq, is_pad)
251
+
252
+ # gradient step
253
+ loss = loss_dict["total_losses"]
254
+ loss.backward()
255
+ self._actor_optimizer.step()
256
+ self._actor_optimizer.zero_grad()
257
+
258
+ self._summaries = {
259
+ "loss": loss_dict["total_losses"],
260
+ "l1": loss_dict["l1"],
261
+ "right_l1": loss_dict["right_l1"],
262
+ "left_l1": loss_dict["left_l1"],
263
+ "kl": loss_dict["kl"],
264
+ }
265
+
266
+ return loss_dict
267
+
268
+ def _normalize_quat(self, x):
269
+ return x / x.square().sum(dim=1).sqrt().unsqueeze(-1)
270
+
271
+ def _normalize_revolute_joints(self, x):
272
+ # normalize joint angles
273
+ # input ranges from -pi to pi
274
+ # out ranges from 0 to 1
275
+ return (x + np.pi) / (2 * np.pi)
276
+
277
+ def _unnormalize_revolute_joints(self, x):
278
+ # map input with range 0 to 1 to -pi to pi
279
+ x = (x - 0.5) * 2.0 * np.pi
280
+ x = torch.clamp(x, -np.pi, np.pi)
281
+ return x
282
+
283
+ def _normalize_gripper_joints(self, x):
284
+ gripper_min = 0
285
+ gripper_max = 0.04
286
+ # normalize gripper joint angles between 0 and 1, the input ranges from 0 to 0.04
287
+ return (x - gripper_min) / (gripper_max - gripper_min)
288
+
289
+ def _unnormalize_gripper_joints(self, x):
290
+ gripper_min = 0
291
+ gripper_max = 0.04
292
+
293
+ x = x * (gripper_max - gripper_min) + gripper_min
294
+ x = torch.clamp(x, gripper_min, gripper_max)
295
+ return torch.unsqueeze(x, dim=0)
296
+
297
+ def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
298
+ # lang_goal_tokens = observation.get('lang_goal_tokens', None).long()
299
+ # with torch.no_grad():
300
+ # lang_goal_tokens = lang_goal_tokens.to(device=self._device)
301
+ # lang_goal_emb, _ = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0])
302
+ # lang_goal_emb = lang_goal_emb.to(device=self._device)
303
+
304
+ action_horizon = self._actor.model.num_queries
305
+ query_freq = 1
306
+
307
+ stats = self.train_stats()
308
+
309
+ if self._timestep % query_freq == 0:
310
+ with torch.no_grad():
311
+ # preprocess input
312
+ qpos = self.preprocess_qpos(observation)
313
+ stacked_rgb, stacked_point_cloud = self.preprocess_images(observation)
314
+
315
+ # forward pass
316
+ self._all_actions = self._actor(
317
+ qpos, stacked_rgb, actions=None, is_pad=None
318
+ )
319
+
320
+ # temporal aggregation
321
+ t = self._timestep
322
+
323
+ self._all_time_actions[[t], t : t + action_horizon] = self._all_actions
324
+ actions_for_curr_step = self._all_time_actions[:, t]
325
+ actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
326
+ actions_for_curr_step = actions_for_curr_step[actions_populated]
327
+ k = 0.01
328
+ exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
329
+ exp_weights = exp_weights / exp_weights.sum()
330
+ exp_weights = torch.from_numpy(exp_weights).to(self._device).unsqueeze(dim=1)
331
+ raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
332
+ raw_action = raw_action[0]
333
+
334
+ right_a_rev = self.unnormalize_z(
335
+ raw_action[0:7], stats["right_joints_mean"], stats["right_joints_std"]
336
+ )
337
+ right_a_gripper = self.unnormalize_z(
338
+ raw_action[7], stats["right_gripper_mean"], stats["right_gripper_std"]
339
+ )
340
+
341
+ left_a_rev = self.unnormalize_z(
342
+ raw_action[8:15], stats["left_joints_mean"], stats["left_joints_std"]
343
+ )
344
+ left_a_gripper = self.unnormalize_z(
345
+ raw_action[15], stats["left_gripper_mean"], stats["left_gripper_std"]
346
+ )
347
+
348
+ raw_action = torch.cat(
349
+ [right_a_rev, right_a_gripper, left_a_rev, left_a_gripper], dim=-1
350
+ )
351
+
352
+ self._timestep += 1
353
+
354
+ return ActResult(raw_action.detach().cpu().numpy())
355
+
356
+ def update_summaries(self) -> List[Summary]:
357
+ summaries = []
358
+ for n, v in self._summaries.items():
359
+ summaries.append(ScalarSummary("%s/%s" % (NAME, n), v))
360
+
361
+ # for tag, param in self._actor.named_parameters():
362
+ # summaries.append(
363
+ #
364
+ # summaries.append(
365
+ # HistogramSummary('%s/weight/%s' % (NAME, tag), param.data))
366
+
367
+ return summaries
368
+
369
+ def act_summaries(self) -> List[Summary]:
370
+ return []
371
+
372
+ def load_weights(self, savedir: str):
373
+ self._actor.load_state_dict(
374
+ torch.load(
375
+ os.path.join(savedir, "bc_actor.pt"), map_location=torch.device("cpu")
376
+ )
377
+ )
378
+ print("Loaded weights from %s" % savedir)
379
+
380
+ def save_weights(self, savedir: str):
381
+ torch.save(self._actor.state_dict(), os.path.join(savedir, "bc_actor.pt"))
external/peract_bimanual/agents/act_bc_lang/act_policy.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ import torchvision.transforms as transforms
5
+
6
+ from agents.act_bc_lang.detr.build import (
7
+ build_ACT_model_and_optimizer,
8
+ build_CNNMLP_model_and_optimizer,
9
+ )
10
+
11
+
12
+ class ACTPolicy(nn.Module):
13
+ def __init__(self, args):
14
+ super().__init__()
15
+ model, optimizer = build_ACT_model_and_optimizer(args)
16
+ self.model = model # CVAE decoder
17
+ self.optimizer = optimizer
18
+ self.kl_weight = args.kl_weight
19
+ print(f"KL Weight {self.kl_weight}")
20
+
21
+ def forward(self, qpos, image, actions=None, is_pad=None):
22
+ env_state = None
23
+
24
+ if actions is not None: # training time
25
+ actions = actions[:, : self.model.num_queries]
26
+ is_pad = is_pad[:, : self.model.num_queries]
27
+
28
+ a_hat, is_pad_hat, (mu, logvar) = self.model(
29
+ qpos, image, env_state, actions, is_pad
30
+ )
31
+ total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
32
+ loss_dict = dict()
33
+
34
+ right_actions_joints, right_a_hat_joints = (
35
+ actions[:, :, 0:8],
36
+ a_hat[:, :, 0:8],
37
+ )
38
+ right_actions_gripper, right_a_hat_gripper = (
39
+ actions[:, :, 7],
40
+ a_hat[:, :, 7],
41
+ )
42
+ left_actions_joints, left_a_hat_joints = (
43
+ actions[:, :, 8:16],
44
+ a_hat[:, :, 8:16],
45
+ )
46
+ left_actions_gripper, left_a_hat_gripper = (
47
+ actions[:, :, 15],
48
+ a_hat[:, :, 15],
49
+ )
50
+
51
+ # use L1 loss for joints
52
+ right_l1_loss = F.l1_loss(
53
+ right_a_hat_joints, right_actions_joints, reduction="none"
54
+ )
55
+ right_l1 = (right_l1_loss * ~is_pad.unsqueeze(-1)).mean()
56
+
57
+ left_l1_loss = F.l1_loss(
58
+ left_a_hat_joints, left_actions_joints, reduction="none"
59
+ )
60
+ left_l1 = (left_l1_loss * ~is_pad.unsqueeze(-1)).mean()
61
+
62
+ l1 = right_l1 + left_l1
63
+
64
+ right_gripper_l1_loss = F.l1_loss(
65
+ right_a_hat_gripper, right_actions_gripper, reduction="none"
66
+ )
67
+ right_gripper_l1_loss = (right_gripper_l1_loss * ~is_pad).mean()
68
+
69
+ left_gripper_l1_loss = F.l1_loss(
70
+ left_a_hat_gripper, left_actions_gripper, reduction="none"
71
+ )
72
+ left_gripper_l1_loss = (left_gripper_l1_loss * ~is_pad).mean()
73
+
74
+ gripper_l1 = right_gripper_l1_loss + left_gripper_l1_loss
75
+ loss_dict["right_l1"] = right_l1
76
+ loss_dict["left_l1"] = left_l1
77
+
78
+ loss_dict["l1"] = l1
79
+ loss_dict["gripper_l1"] = gripper_l1
80
+
81
+ loss_dict["kl"] = total_kld[0]
82
+ loss_dict["total_losses"] = (
83
+ loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
84
+ )
85
+ return loss_dict
86
+ else: # inference time
87
+ a_hat, _, (_, _) = self.model(
88
+ qpos, image, env_state
89
+ ) # no action, sample from prior
90
+ return a_hat
91
+
92
+ def configure_optimizers(self):
93
+ return self.optimizer
94
+
95
+
96
+ class CNNMLPPolicy(nn.Module):
97
+ def __init__(self, args):
98
+ super().__init__()
99
+ model, optimizer = build_CNNMLP_model_and_optimizer(args)
100
+ self.model = model # decoder
101
+ self.optimizer = optimizer
102
+
103
+ def forward(self, qpos, image, actions=None, is_pad=None):
104
+ env_state = None # TODO
105
+
106
+ if actions is not None: # training time
107
+ actions = actions[:, 0]
108
+ a_hat = self.model(qpos, image, env_state, actions)
109
+ mse = F.mse_loss(actions, a_hat)
110
+ loss_dict = dict()
111
+ loss_dict["mse"] = mse
112
+ loss_dict["loss"] = loss_dict["mse"]
113
+ return loss_dict
114
+ else: # inference time
115
+ a_hat = self.model(qpos, image, env_state) # no action, sample from prior
116
+ return a_hat
117
+
118
+ def configure_optimizers(self):
119
+ return self.optimizer
120
+
121
+
122
+ def kl_divergence(mu, logvar):
123
+ batch_size = mu.size(0)
124
+ assert batch_size != 0
125
+ if mu.data.ndimension() == 4:
126
+ mu = mu.view(mu.size(0), mu.size(1))
127
+ if logvar.data.ndimension() == 4:
128
+ logvar = logvar.view(logvar.size(0), logvar.size(1))
129
+
130
+ klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
131
+ total_kld = klds.sum(1).mean(0, True)
132
+ dimension_wise_kld = klds.mean(0)
133
+ mean_kld = klds.mean(1).mean(0, True)
134
+
135
+ return total_kld, dimension_wise_kld, mean_kld
external/peract_bimanual/agents/act_bc_lang/detr/__init__.py ADDED
File without changes
external/peract_bimanual/agents/act_bc_lang/detr/build.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import argparse
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+ from .models import build_ACT_model, build_CNNMLP_model
8
+
9
+
10
+
11
+ def build_ACT_model_and_optimizer(args):
12
+ model = build_ACT_model(args)
13
+
14
+ param_dicts = [
15
+ {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
16
+ {
17
+ "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
18
+ "lr": args.lr_backbone,
19
+ },
20
+ ]
21
+ optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
22
+ weight_decay=args.weight_decay)
23
+
24
+ return model, optimizer
25
+
26
+
27
+ def build_CNNMLP_model_and_optimizer(args):
28
+ model = build_CNNMLP_model(args)
29
+
30
+ param_dicts = [
31
+ {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
32
+ {
33
+ "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
34
+ "lr": args.lr_backbone,
35
+ },
36
+ ]
37
+ optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
38
+ weight_decay=args.weight_decay)
39
+
40
+ return model, optimizer
41
+
external/peract_bimanual/agents/act_bc_lang/detr/util/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
external/peract_bimanual/agents/act_bc_lang/launch_utils.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from ARM
2
+ # Source: https://github.com/stepjam/ARM
3
+ # License: https://github.com/stepjam/ARM/LICENSE
4
+
5
+ import logging
6
+ from typing import List
7
+
8
+ import numpy as np
9
+ from omegaconf import DictConfig
10
+ from rlbench.backend.observation import Observation
11
+ from rlbench.observation_config import ObservationConfig
12
+ import rlbench.utils as rlbench_utils
13
+ from rlbench.demo import Demo
14
+ from yarr.replay_buffer.prioritized_replay_buffer import (
15
+ PrioritizedReplayBuffer,
16
+ ObservationElement,
17
+ )
18
+ from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer
19
+ from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
20
+ from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer
21
+
22
+ from helpers import utils
23
+ from helpers import observation_utils
24
+ from agents.act_bc_lang.act_bc_lang_agent import ActBCLangAgent
25
+ from helpers.custom_rlbench_env import CustomRLBenchEnv
26
+ from helpers.preprocess_agent import PreprocessAgent
27
+ from agents.act_bc_lang.act_policy import ACTPolicy, CNNMLPPolicy
28
+
29
+ import torch
30
+ from torch.multiprocessing import Process, Value, Manager
31
+ from helpers.clip.core.clip import build_model, load_clip, tokenize
32
+
33
+ LOW_DIM_SIZE = 8
34
+
35
+
36
+ def create_replay(
37
+ batch_size: int,
38
+ timesteps: int,
39
+ prioritisation: bool,
40
+ task_uniform: bool,
41
+ save_dir: str,
42
+ cameras: list,
43
+ image_size=[128, 128],
44
+ replay_size=3e5,
45
+ prev_action_horizon: int = 1,
46
+ next_action_horizon: int = 1,
47
+ ):
48
+ lang_feat_dim = 1024
49
+
50
+ # low_dim_state
51
+ observation_elements = []
52
+ observation_elements.append(
53
+ ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
54
+ )
55
+
56
+ # action sequences
57
+ action_seq_sizes = {
58
+ "right_prev_joint_positions": 7,
59
+ "right_prev_gripper_joint_positions": 2,
60
+ "right_prev_gripper_poses": 7,
61
+ "right_next_joint_positions": 7,
62
+ "right_next_gripper_joint_positions": 2,
63
+ "right_next_gripper_poses": 7,
64
+ "left_prev_joint_positions": 7,
65
+ "left_prev_gripper_joint_positions": 2,
66
+ "left_prev_gripper_poses": 7,
67
+ "left_next_joint_positions": 7,
68
+ "left_next_gripper_joint_positions": 2,
69
+ "left_next_gripper_poses": 7,
70
+ }
71
+
72
+ for seq_name, seq_size in action_seq_sizes.items():
73
+ horizon = prev_action_horizon if "prev" in seq_name else next_action_horizon
74
+ observation_elements.append(
75
+ ObservationElement(
76
+ seq_name,
77
+ (
78
+ horizon,
79
+ seq_size,
80
+ ),
81
+ np.float32,
82
+ )
83
+ )
84
+
85
+ # action is_pad
86
+ observation_elements.append(
87
+ ObservationElement("is_pad", (next_action_horizon,), np.int32)
88
+ )
89
+
90
+ # rgb, depth, point cloud, intrinsics, extrinsics
91
+ for cname in cameras:
92
+ observation_elements.append(
93
+ ObservationElement(
94
+ "%s_rgb" % cname,
95
+ (
96
+ 3,
97
+ *image_size,
98
+ ),
99
+ np.float32,
100
+ )
101
+ )
102
+ observation_elements.append(
103
+ ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32)
104
+ ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
105
+ observation_elements.append(
106
+ ObservationElement(
107
+ "%s_camera_extrinsics" % cname,
108
+ (
109
+ 4,
110
+ 4,
111
+ ),
112
+ np.float32,
113
+ )
114
+ )
115
+ observation_elements.append(
116
+ ObservationElement(
117
+ "%s_camera_intrinsics" % cname,
118
+ (
119
+ 3,
120
+ 3,
121
+ ),
122
+ np.float32,
123
+ )
124
+ )
125
+
126
+ observation_elements.extend(
127
+ [
128
+ ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
129
+ ReplayElement("task", (), str),
130
+ ReplayElement(
131
+ "lang_goal", (1,), object
132
+ ), # language goal string for debugging and visualization
133
+ ]
134
+ )
135
+
136
+ extra_replay_elements = [
137
+ ReplayElement("demo", (), bool),
138
+ ]
139
+
140
+ replay_buffer = TaskUniformReplayBuffer(
141
+ save_dir=save_dir,
142
+ batch_size=batch_size,
143
+ timesteps=timesteps,
144
+ replay_capacity=int(replay_size),
145
+ action_shape=(8 * 2,),
146
+ action_dtype=np.float32,
147
+ reward_shape=(),
148
+ reward_dtype=np.float32,
149
+ update_horizon=1,
150
+ observation_elements=observation_elements,
151
+ extra_replay_elements=extra_replay_elements,
152
+ )
153
+ return replay_buffer
154
+
155
+
156
+ def _get_action(obs_tp1: Observation):
157
+ quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
158
+ if quat[-1] < 0:
159
+ quat = -quat
160
+ return np.concatenate(
161
+ [obs_tp1.gripper_pose[:3], quat, [float(obs_tp1.gripper_open)]]
162
+ )
163
+
164
+
165
+ def _get_action_seq(
166
+ demo: Demo,
167
+ timestep: int,
168
+ prev_action_horizon: int,
169
+ next_action_horizon: int,
170
+ robot_name: str,
171
+ ):
172
+ action_seq = {
173
+ "right_prev_joint_positions": [],
174
+ "right_prev_gripper_joint_positions": [],
175
+ "right_prev_gripper_poses": [],
176
+ "left_prev_joint_positions": [],
177
+ "left_prev_gripper_joint_positions": [],
178
+ "left_prev_gripper_poses": [],
179
+ "right_next_joint_positions": [],
180
+ "right_next_gripper_joint_positions": [],
181
+ "right_next_gripper_poses": [],
182
+ "left_next_joint_positions": [],
183
+ "left_next_gripper_joint_positions": [],
184
+ "left_next_gripper_poses": [],
185
+ "is_pad": [],
186
+ }
187
+
188
+ for prev_t in list(reversed(range(prev_action_horizon))):
189
+ t = timestep - prev_t
190
+ t = max(0, t)
191
+ obs = demo[t]
192
+
193
+ action_seq["right_prev_joint_positions"].append(obs.right.joint_positions)
194
+ action_seq["right_prev_gripper_joint_positions"].append(
195
+ obs.right.gripper_joint_positions
196
+ )
197
+ action_seq["right_prev_gripper_poses"].append(obs.right.gripper_pose)
198
+ action_seq["left_prev_joint_positions"].append(obs.left.joint_positions)
199
+ action_seq["left_prev_gripper_joint_positions"].append(
200
+ obs.left.gripper_joint_positions
201
+ )
202
+ action_seq["left_prev_gripper_poses"].append(obs.left.gripper_pose)
203
+
204
+ action_seq["is_pad"] = np.zeros(next_action_horizon)
205
+ for idx, next_t in enumerate(range(0, next_action_horizon)):
206
+ t = timestep + next_t
207
+ t = min(t, len(demo) - 1)
208
+ obs = demo[t]
209
+
210
+ if timestep + next_t > len(demo) - 1:
211
+ action_seq["is_pad"][idx] = 1
212
+
213
+ action_seq["right_next_joint_positions"].append(obs.right.joint_positions)
214
+ action_seq["right_next_gripper_joint_positions"].append(
215
+ obs.right.gripper_joint_positions
216
+ )
217
+ action_seq["right_next_gripper_poses"].append(obs.right.gripper_pose)
218
+ action_seq["left_next_joint_positions"].append(obs.left.joint_positions)
219
+ action_seq["left_next_gripper_joint_positions"].append(
220
+ obs.left.gripper_joint_positions
221
+ )
222
+ action_seq["left_next_gripper_poses"].append(obs.left.gripper_pose)
223
+
224
+ # convert to numpy arrays
225
+ return {k: np.array(v) for k, v in action_seq.items()}
226
+
227
+
228
+ def _add_keypoints_to_replay(
229
+ step: int,
230
+ cfg: DictConfig,
231
+ task: str,
232
+ replay: ReplayBuffer,
233
+ inital_obs: Observation,
234
+ demo: Demo,
235
+ description: str = "",
236
+ clip_model=None,
237
+ device="cpu",
238
+ ):
239
+ cameras = cfg.rlbench.cameras
240
+ robot_name = cfg.method.robot_name
241
+
242
+ prev_action = None
243
+ obs = inital_obs
244
+ all_actions = []
245
+ k = step
246
+ k_tp1 = min(k + 1, len(demo) - 1)
247
+ obs_tp1 = demo[k_tp1]
248
+
249
+ if obs_tp1.is_bimanual and robot_name == "bimanual":
250
+ right_action = _get_action(obs_tp1.right)
251
+ left_action = _get_action(obs_tp1.left)
252
+ action = np.append(right_action, left_action)
253
+ elif robot_name == "unimanual":
254
+ action = _get_action(obs_tp1)
255
+ elif obs_tp1.is_bimanual and robot_name == "right":
256
+ action = _get_action(obs_tp1.right)
257
+ elif obs_tp1.is_bimanual and robot_name == "left":
258
+ action = _get_action(obs_tp1.left)
259
+ else:
260
+ logging.error("Invalid robot name %s", cfg.method.robot_name)
261
+ raise Exception("Invalid robot name.")
262
+
263
+ all_actions.append(action)
264
+
265
+ terminal = k == len(demo) - 1
266
+ reward = float(terminal) if terminal else 0
267
+
268
+ obs_dict = observation_utils.extract_obs(
269
+ obs,
270
+ t=k,
271
+ prev_action=prev_action,
272
+ cameras=cameras,
273
+ episode_length=cfg.rlbench.episode_length,
274
+ robot_name=robot_name,
275
+ )
276
+
277
+ if obs_tp1.is_bimanual and robot_name == "bimanual":
278
+ obs_dict["low_dim_state"] = np.concatenate(
279
+ [obs_dict["right_low_dim_state"], obs_dict["left_low_dim_state"]]
280
+ )
281
+ del obs_dict["right_low_dim_state"]
282
+ del obs_dict["left_low_dim_state"]
283
+ del obs_dict["right_ignore_collisions"]
284
+ del obs_dict["left_ignore_collisions"]
285
+ else:
286
+ del obs_dict["ignore_collisions"]
287
+
288
+ tokens = tokenize([description]).numpy()
289
+ token_tensor = torch.from_numpy(tokens).to(device)
290
+ lang_feats, lang_embs = clip_model.encode_text_with_embeddings(token_tensor)
291
+ obs_dict["lang_goal_emb"] = lang_feats[0].float().detach().cpu().numpy()
292
+
293
+ final_obs = {
294
+ "task": task,
295
+ "lang_goal": np.array([description], dtype=object),
296
+ }
297
+
298
+ action_seq = _get_action_seq(
299
+ demo,
300
+ step,
301
+ cfg.method.prev_action_horizon,
302
+ cfg.method.next_action_horizon,
303
+ robot_name,
304
+ )
305
+ obs_dict.update(action_seq)
306
+
307
+ prev_action = np.copy(action)
308
+ others = {"demo": True}
309
+ others.update(final_obs)
310
+ others.update(obs_dict)
311
+ timeout = False
312
+ replay.add(action, reward, terminal, timeout, **others)
313
+
314
+ return all_actions
315
+
316
+
317
+ def fill_replay(
318
+ cfg: DictConfig,
319
+ obs_config: ObservationConfig,
320
+ rank: int,
321
+ replay: ReplayBuffer,
322
+ task: str,
323
+ num_demos: int,
324
+ demo_augmentation: bool,
325
+ demo_augmentation_every_n: int,
326
+ cameras: List[str],
327
+ clip_model=None,
328
+ device="cpu",
329
+ ):
330
+ if clip_model is None:
331
+ model, _ = load_clip("RN50", jit=False, device=device)
332
+ clip_model = build_model(model.state_dict())
333
+ clip_model.to(device)
334
+ del model
335
+
336
+ logging.debug("Filling %s replay ..." % task)
337
+ all_actions = []
338
+ for d_idx in range(num_demos):
339
+ # load demo from disk
340
+ demo = rlbench_utils.get_stored_demos(
341
+ amount=1,
342
+ image_paths=False,
343
+ dataset_root=cfg.rlbench.demo_path,
344
+ variation_number=-1,
345
+ task_name=task,
346
+ obs_config=obs_config,
347
+ random_selection=False,
348
+ from_episode_number=d_idx,
349
+ )[0]
350
+
351
+ descs = demo._observations[0].misc["descriptions"]
352
+
353
+ if rank == 0:
354
+ logging.info(f"Loading Demo({d_idx})")
355
+
356
+ for i in range(len(demo) - 1):
357
+ obs = demo[i]
358
+ desc = descs[0]
359
+
360
+ # stopped = np.allclose(obs.joint_velocities, 0, atol=0.1)
361
+ # if stopped:
362
+ # continue
363
+
364
+ all_actions.extend(
365
+ _add_keypoints_to_replay(
366
+ i,
367
+ cfg,
368
+ task,
369
+ replay,
370
+ obs,
371
+ demo,
372
+ description=desc,
373
+ clip_model=clip_model,
374
+ device=device,
375
+ )
376
+ )
377
+ logging.debug("Replay filled with demos.")
378
+ return all_actions
379
+
380
+
381
+ def fill_multi_task_replay(
382
+ cfg: DictConfig,
383
+ obs_config: ObservationConfig,
384
+ rank: int,
385
+ replay: ReplayBuffer,
386
+ tasks: List[str],
387
+ num_demos: int,
388
+ demo_augmentation: bool,
389
+ demo_augmentation_every_n: int,
390
+ cameras: List[str],
391
+ clip_model=None,
392
+ ):
393
+ manager = Manager()
394
+ store = manager.dict()
395
+
396
+ # create a MP dict for storing indicies
397
+ # TODO(mohit): this shouldn't be initialized here
398
+ del replay._task_idxs
399
+ task_idxs = manager.dict()
400
+ replay._task_idxs = task_idxs
401
+ replay._create_storage(store)
402
+ replay.add_count = Value("i", 0)
403
+
404
+ # fill replay buffer in parallel across tasks
405
+ max_parallel_processes = cfg.replay.max_parallel_processes
406
+ processes = []
407
+ n = np.arange(len(tasks))
408
+ split_n = utils.split_list(n, max_parallel_processes)
409
+ for split in split_n:
410
+ for e_idx, task_idx in enumerate(split):
411
+ task = tasks[int(task_idx)]
412
+ model_device = torch.device(
413
+ "cuda:%s" % (e_idx % torch.cuda.device_count())
414
+ if torch.cuda.is_available()
415
+ else "cpu"
416
+ )
417
+ p = Process(
418
+ target=fill_replay,
419
+ args=(
420
+ cfg,
421
+ obs_config,
422
+ rank,
423
+ replay,
424
+ task,
425
+ num_demos,
426
+ demo_augmentation,
427
+ demo_augmentation_every_n,
428
+ cameras,
429
+ clip_model,
430
+ model_device,
431
+ ),
432
+ )
433
+ p.start()
434
+ processes.append(p)
435
+
436
+ for p in processes:
437
+ p.join()
438
+
439
+ logging.debug("Replay filled with multi demos.")
440
+
441
+
442
+ def create_agent(cfg: DictConfig):
443
+ actor_net = ACTPolicy(cfg.method)
444
+
445
+ bc_agent = ActBCLangAgent(
446
+ actor_network=actor_net,
447
+ camera_names=cfg.rlbench.cameras,
448
+ lr=cfg.method.lr,
449
+ weight_decay=cfg.method.weight_decay,
450
+ grad_clip=cfg.method.grad_clip,
451
+ episode_length=cfg.rlbench.episode_length,
452
+ train_demo_path=cfg.method.train_demo_path,
453
+ task_name=cfg.rlbench.tasks[0],
454
+ )
455
+
456
+ return PreprocessAgent(pose_agent=bc_agent, norm_type="imagenet")
external/peract_bimanual/agents/agent_factory.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+
4
+ from omegaconf import DictConfig
5
+
6
+
7
+ from yarr.agents.agent import BimanualAgent
8
+ from yarr.agents.agent import LeaderFollowerAgent
9
+ from yarr.agents.agent import Agent
10
+
11
+
12
+ supported_agents = {
13
+ "leader_follower": ("PERACT_BC", "RVT"),
14
+ "independent": ("PERACT_BC", "RVT"),
15
+ "bimanual": ("BIMANUAL_PERACT", "ACT_BC_LANG"),
16
+ "unimanual": (),
17
+ }
18
+
19
+
20
+ def create_agent(cfg: DictConfig) -> Agent:
21
+ method_name = cfg.method.name
22
+ agent_type = cfg.method.agent_type
23
+
24
+ logging.info("Using method %s with type %s", method_name, agent_type)
25
+
26
+ assert method_name in supported_agents[agent_type]
27
+
28
+ agent_fn = agent_fn_by_name(method_name)
29
+
30
+ if agent_type == "leader_follower":
31
+ checkpoint_name_prefix = cfg.framework.checkpoint_name_prefix
32
+ cfg.method.robot_name = "right"
33
+ cfg.framework.checkpoint_name_prefix = (
34
+ f"{checkpoint_name_prefix}_{method_name.lower()}_leader"
35
+ )
36
+ leader_agent = agent_fn(cfg)
37
+
38
+ cfg.method.robot_name = "left"
39
+ cfg.framework.checkpoint_name_prefix = (
40
+ f"{checkpoint_name_prefix}_{method_name.lower()}_follower"
41
+ )
42
+ cfg.method.low_dim_size = (
43
+ cfg.method.low_dim_size + 8
44
+ ) # also add the action size
45
+ follower_agent = agent_fn(cfg)
46
+
47
+ cfg.method.robot_name = "bimanual"
48
+
49
+ return LeaderFollowerAgent(leader_agent, follower_agent)
50
+
51
+ elif agent_type == "independent":
52
+ checkpoint_name_prefix = cfg.framework.checkpoint_name_prefix
53
+ cfg.method.robot_name = "right"
54
+ cfg.framework.checkpoint_name_prefix = (
55
+ f"{checkpoint_name_prefix}_{method_name.lower()}_right"
56
+ )
57
+ right_agent = agent_fn(cfg)
58
+
59
+ cfg.method.robot_name = "left"
60
+ cfg.framework.checkpoint_name_prefix = (
61
+ f"{checkpoint_name_prefix}_{method_name.lower()}_left"
62
+ )
63
+ left_agent = agent_fn(cfg)
64
+
65
+ cfg.method.robot_name = "bimanual"
66
+
67
+ return BimanualAgent(right_agent, left_agent)
68
+ elif agent_type == "bimanual" or agent_type == "unimanual":
69
+ return agent_fn(cfg)
70
+ else:
71
+ raise Exception("invalid agent type")
72
+
73
+
74
+ def agent_fn_by_name(method_name: str) -> Agent:
75
+ if method_name == "ARM":
76
+ from agents import arm
77
+
78
+ raise NotImplementedError("ARM not yet supported for eval.py")
79
+ elif method_name == "BC_LANG":
80
+ from agents.baselines import bc_lang
81
+
82
+ return bc_lang.launch_utils.create_agent
83
+ elif method_name == "VIT_BC_LANG":
84
+ from agents.baselines import vit_bc_lang
85
+
86
+ return vit_bc_lang.launch_utils.create_agent
87
+ elif method_name == "C2FARM_LINGUNET_BC":
88
+ from agents import c2farm_lingunet_bc
89
+
90
+ return c2farm_lingunet_bc.launch_utils.create_agent
91
+ elif method_name.startswith("PERACT_BC"):
92
+ from agents import peract_bc
93
+
94
+ return peract_bc.launch_utils.create_agent
95
+ elif method_name.startswith("BIMANUAL_PERACT"):
96
+ from agents import bimanual_peract
97
+
98
+ return bimanual_peract.launch_utils.create_agent
99
+ elif method_name.startswith("RVT"):
100
+ from agents import rvt
101
+
102
+ return rvt.launch_utils.create_agent
103
+ elif method_name.startswith("ACT_BC_LANG"):
104
+ from agents import act_bc_lang
105
+
106
+ return act_bc_lang.launch_utils.create_agent
107
+ elif method_name == "PERACT_RL":
108
+ raise NotImplementedError("PERACT_RL not yet supported for eval.py")
109
+
110
+ else:
111
+ raise ValueError("Method %s does not exists." % method_name)
external/peract_bimanual/agents/arm/launch_utils.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ from typing import List
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ from rlbench.backend.observation import Observation
9
+ from rlbench.demo import Demo
10
+ from yarr.replay_buffer.prioritized_replay_buffer import (
11
+ PrioritizedReplayBuffer,
12
+ ObservationElement,
13
+ )
14
+ from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer
15
+ from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
16
+
17
+ from helpers import demo_loading_utils, utils
18
+ from helpers.custom_rlbench_env import CustomRLBenchEnv
19
+ from helpers.network_utils import (
20
+ SiameseNet,
21
+ DenseBlock,
22
+ Conv2DBlock,
23
+ Conv2DUpsampleBlock,
24
+ )
25
+ from helpers.preprocess_agent import PreprocessAgent
26
+ from agents.arm.next_best_pose_agent import NextBestPoseAgent
27
+ from agents.arm.qattention_agent import QAttentionAgent
28
+
29
+ REWARD_SCALE = 100.0
30
+
31
+
32
+ def create_replay(
33
+ batch_size: int,
34
+ timesteps: int,
35
+ prioritisation: bool,
36
+ save_dir: str,
37
+ cameras: list,
38
+ env: CustomRLBenchEnv,
39
+ ):
40
+ observation_elements = env.observation_elements
41
+ for cname in cameras:
42
+ observation_elements.extend(
43
+ [
44
+ ObservationElement("%s_pixel_coord" % cname, (2,), np.int32),
45
+ ]
46
+ )
47
+
48
+ replay_class = UniformReplayBuffer
49
+ if prioritisation:
50
+ replay_class = PrioritizedReplayBuffer
51
+ replay_buffer = replay_class(
52
+ save_dir=save_dir,
53
+ batch_size=batch_size,
54
+ timesteps=timesteps,
55
+ replay_capacity=int(1e5),
56
+ action_shape=(8,),
57
+ action_dtype=np.float32,
58
+ reward_shape=(),
59
+ reward_dtype=np.float32,
60
+ update_horizon=1,
61
+ observation_elements=observation_elements,
62
+ extra_replay_elements=[ReplayElement("demo", (), np.bool)],
63
+ )
64
+ return replay_buffer
65
+
66
+
67
+ def _point_to_pixel_index(
68
+ point: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray
69
+ ):
70
+ point = np.array([point[0], point[1], point[2], 1])
71
+ world_to_cam = np.linalg.inv(extrinsics)
72
+ point_in_cam_frame = world_to_cam.dot(point)
73
+ px, py, pz = point_in_cam_frame[:3]
74
+ px = 2 * intrinsics[0, 2] - int(-intrinsics[0, 0] * (px / pz) + intrinsics[0, 2])
75
+ py = 2 * intrinsics[1, 2] - int(-intrinsics[1, 1] * (py / pz) + intrinsics[1, 2])
76
+ return px, py
77
+
78
+
79
+ def _get_action(obs_tp1: Observation):
80
+ quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
81
+ if quat[-1] < 0:
82
+ quat = -quat
83
+ return np.concatenate(
84
+ [obs_tp1.gripper_pose[:3], quat, [float(obs_tp1.gripper_open)]]
85
+ )
86
+
87
+
88
+ def _add_keypoints_to_replay(
89
+ replay: ReplayBuffer,
90
+ inital_obs: Observation,
91
+ demo: Demo,
92
+ env: CustomRLBenchEnv,
93
+ episode_keypoints: List[int],
94
+ cameras: List[str],
95
+ ):
96
+ prev_action = None
97
+ obs = inital_obs
98
+ all_actions = []
99
+ for k, keypoint in enumerate(episode_keypoints):
100
+ obs_tp1 = demo[keypoint]
101
+ action = _get_action(obs_tp1)
102
+ all_actions.append(action)
103
+ terminal = k == len(episode_keypoints) - 1
104
+ reward = float(terminal) * REWARD_SCALE if terminal else 0
105
+ obs_dict = env.extract_obs(obs, t=k, prev_action=prev_action)
106
+ prev_action = np.copy(action)
107
+ others = {"demo": True}
108
+ final_obs = {}
109
+ for name in cameras:
110
+ px, py = _point_to_pixel_index(
111
+ obs_tp1.gripper_pose[:3],
112
+ obs_tp1.misc["%s_camera_extrinsics" % name],
113
+ obs_tp1.misc["%s_camera_intrinsics" % name],
114
+ )
115
+ final_obs["%s_pixel_coord" % name] = [py, px]
116
+ others.update(final_obs)
117
+ others.update(obs_dict)
118
+ timeout = False
119
+ replay.add(action, reward, terminal, timeout, **others)
120
+ obs = obs_tp1 # Set the next obs
121
+ # Final step
122
+ obs_dict_tp1 = env.extract_obs(obs_tp1, t=k + 1, prev_action=prev_action)
123
+ obs_dict_tp1.update(final_obs)
124
+ replay.add_final(**obs_dict_tp1)
125
+ return all_actions
126
+
127
+
128
+ def fill_replay(
129
+ replay: ReplayBuffer,
130
+ task: str,
131
+ env: CustomRLBenchEnv,
132
+ num_demos: int,
133
+ demo_augmentation: bool,
134
+ demo_augmentation_every_n: int,
135
+ cameras: List[str],
136
+ ):
137
+ logging.info("Filling replay with demos...")
138
+ all_actions = []
139
+ for d_idx in range(num_demos):
140
+ demo = env.env.get_demos(
141
+ task,
142
+ 1,
143
+ variation_number=0,
144
+ random_selection=False,
145
+ from_episode_number=d_idx,
146
+ )[0]
147
+ episode_keypoints = demo_loading_utils.keypoint_discovery(demo)
148
+
149
+ for i in range(len(demo) - 1):
150
+ if not demo_augmentation and i > 0:
151
+ break
152
+ if i % demo_augmentation_every_n != 0:
153
+ continue
154
+ obs = demo[i]
155
+ # If our starting point is past one of the keypoints, then remove it
156
+ while len(episode_keypoints) > 0 and i >= episode_keypoints[0]:
157
+ episode_keypoints = episode_keypoints[1:]
158
+ if len(episode_keypoints) == 0:
159
+ break
160
+ all_actions.extend(
161
+ _add_keypoints_to_replay(
162
+ replay, obs, demo, env, episode_keypoints, cameras
163
+ )
164
+ )
165
+ logging.info("Replay filled with demos.")
166
+ return all_actions
167
+
168
+
169
+ class SharedNet(nn.Module):
170
+ def __init__(self, activation: str, norm: str = None):
171
+ super(SharedNet, self).__init__()
172
+ self._activation = activation
173
+ self._norm = norm
174
+
175
+ def build(self):
176
+ self._rgb_pre = nn.Sequential(
177
+ Conv2DBlock(3, 32, 3, 1, activation=self._activation, norm=self._norm),
178
+ )
179
+ self._pcd_pre = nn.Sequential(
180
+ Conv2DBlock(3, 32, 3, 1, activation=self._activation, norm=self._norm),
181
+ )
182
+
183
+ def forward(self, observations):
184
+ x_rgb, x_pcd = self._rgb_pre(observations[0]), self._pcd_pre(observations[1])
185
+ x = torch.cat([x_rgb, x_pcd], dim=1)
186
+ return x
187
+
188
+
189
+ class ActorNet(nn.Module):
190
+ def __init__(self, activation: str, low_dim_size: int, norm: str = None):
191
+ super(ActorNet, self).__init__()
192
+ self._activation = activation
193
+ self._low_dim_size = low_dim_size
194
+ self._norm = norm
195
+
196
+ def build(self):
197
+ self._convs = nn.Sequential(
198
+ Conv2DBlock(
199
+ 64 + self._low_dim_size,
200
+ 64,
201
+ 1,
202
+ 1,
203
+ activation=self._activation,
204
+ norm=self._norm,
205
+ ),
206
+ Conv2DBlock(64, 64, 3, 1, activation=self._activation, norm=self._norm),
207
+ )
208
+ self._fcs = nn.Sequential(
209
+ DenseBlock(64, 64, activation=self._activation),
210
+ DenseBlock(64, 64, activation=self._activation),
211
+ DenseBlock(64, 8 * 2),
212
+ )
213
+ self._maxp = nn.AdaptiveMaxPool2d(1)
214
+
215
+ def forward(self, observation_feats, low_dim_ins):
216
+ low_dim_feats = low_dim_ins
217
+ _, _, h, w = observation_feats.shape
218
+ low_dim_feats = low_dim_feats.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w)
219
+ x = torch.cat([observation_feats, low_dim_feats], dim=1)
220
+ x = self._convs(x)
221
+ x = self._maxp(x).squeeze(-1).squeeze(-1)
222
+ x = self._fcs(x)
223
+ return x
224
+
225
+
226
+ class CriticNet(nn.Module):
227
+ def __init__(
228
+ self, activation: str, low_dim_size: int, norm: str = None, q_conf: bool = True
229
+ ):
230
+ super(CriticNet, self).__init__()
231
+ self._activation = activation
232
+ self._low_dim_size = low_dim_size
233
+ self._norm = norm
234
+ self._q_conf = q_conf
235
+
236
+ def build(self):
237
+ self._convs = nn.Sequential(
238
+ Conv2DBlock(
239
+ 64 + self._low_dim_size, 128, 3, 1, self._norm, self._activation
240
+ ),
241
+ Conv2DBlock(128, 128, 3, 1, self._norm, self._activation),
242
+ Conv2DBlock(128, 128, 3, 1, self._norm, self._activation),
243
+ Conv2DBlock(128, 128, 3, 1, self._norm, self._activation),
244
+ )
245
+ if self._q_conf:
246
+ self._final_conv = Conv2DBlock(128, 2, 3, 1)
247
+ else:
248
+ self._maxp = nn.AdaptiveMaxPool2d(1)
249
+ self._fcs = nn.Sequential(
250
+ DenseBlock(128, 64, activation=self._activation),
251
+ DenseBlock(64, 1),
252
+ )
253
+
254
+ def forward(self, observation_feats, low_dim_ins):
255
+ low_dim_feats = low_dim_ins
256
+ _, _, h, w = observation_feats.shape
257
+ low_dim_feats = low_dim_feats.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w)
258
+ x = torch.cat([observation_feats, low_dim_feats], dim=1)
259
+ x = self._convs(x)
260
+ if self._q_conf:
261
+ x = self._final_conv(x)
262
+ x[:, 1] = torch.sigmoid(x[:, 1])
263
+ else:
264
+ x = self._maxp(x).squeeze(-1).squeeze(-1)
265
+ x = self._fcs(x)
266
+ return x
267
+
268
+
269
+ class Qattention2DNet(nn.Module):
270
+ def __init__(
271
+ self,
272
+ siamese_net: SiameseNet,
273
+ filters: List[int],
274
+ kernel_sizes: List[int],
275
+ strides: List[int],
276
+ low_dim_state_len: int,
277
+ norm: str = None,
278
+ activation: str = "relu",
279
+ output_channels: int = 1,
280
+ skip_connections: bool = True,
281
+ ):
282
+ super(Qattention2DNet, self).__init__()
283
+ self._siamese_net = copy.deepcopy(siamese_net)
284
+ self._input_channels = self._siamese_net.output_channels + low_dim_state_len
285
+ self._filters = filters
286
+ self._kernel_sizes = kernel_sizes
287
+ self._strides = strides
288
+ self._norm = norm
289
+ self._activation = activation
290
+ self._output_channels = output_channels
291
+ self._skip_connections = skip_connections
292
+ self._build_calls = 0
293
+
294
+ def build(self):
295
+ self._build_calls += 1
296
+ if self._build_calls != 1:
297
+ raise RuntimeError("Build needs to be called once.")
298
+ self._siamese_net.build()
299
+ self._down = []
300
+ ch = self._input_channels
301
+ for filt, ksize, stride in zip(
302
+ self._filters, self._kernel_sizes, self._strides
303
+ ):
304
+ conv_block = Conv2DBlock(
305
+ ch,
306
+ filt,
307
+ ksize,
308
+ stride,
309
+ self._norm,
310
+ self._activation,
311
+ padding_mode="replicate",
312
+ )
313
+ ch = filt
314
+ self._down.append(conv_block)
315
+ self._down = nn.ModuleList(self._down)
316
+
317
+ reverse_conv_data = list(zip(self._filters, self._kernel_sizes, self._strides))
318
+ reverse_conv_data.reverse()
319
+
320
+ self._up = []
321
+ for i, (filt, ksize, stride) in enumerate(reverse_conv_data):
322
+ if i > 0 and self._skip_connections:
323
+ ch += reverse_conv_data[-i - 1][0]
324
+ convt_block = Conv2DUpsampleBlock(
325
+ ch, filt, ksize, stride, self._norm, self._activation
326
+ )
327
+ ch = filt
328
+ self._up.append(convt_block)
329
+ self._up = nn.ModuleList(self._up)
330
+
331
+ self._final_conv = Conv2DBlock(
332
+ ch, self._output_channels, 3, 1, padding_mode="replicate"
333
+ )
334
+
335
+ def forward(self, observations, low_dim_ins):
336
+ x = self._siamese_net(observations)
337
+ _, _, h, w = x.shape
338
+ if low_dim_ins is not None:
339
+ low_dim_latents = low_dim_ins.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w)
340
+ x = torch.cat([x, low_dim_latents], dim=1)
341
+ self.ups = []
342
+ self.downs = []
343
+ layers_for_skip = []
344
+ for l in self._down:
345
+ x = l(x)
346
+ layers_for_skip.append(x)
347
+ self.downs.append(x)
348
+ self.latent = x
349
+ layers_for_skip.reverse()
350
+ for i, l in enumerate(self._up):
351
+ if i > 0 and self._skip_connections:
352
+ # Skip connections. Skip the first up layer.
353
+ x = torch.cat([layers_for_skip[i], x], 1)
354
+ x = l(x)
355
+ self.ups.append(x)
356
+ x = self._final_conv(x)
357
+ return x
358
+
359
+
360
+ def create_agent(
361
+ camera_name: str,
362
+ activation: str,
363
+ q_conf: bool,
364
+ action_min_max,
365
+ alpha,
366
+ alpha_lr,
367
+ alpha_auto_tune,
368
+ critic_lr,
369
+ actor_lr,
370
+ next_best_pose_critic_weight_decay,
371
+ next_best_pose_actor_weight_decay,
372
+ crop_shape,
373
+ next_best_pose_tau,
374
+ next_best_pose_critic_grad_clip,
375
+ next_best_pose_actor_grad_clip,
376
+ qattention_tau,
377
+ qattention_lr,
378
+ qattention_weight_decay,
379
+ qattention_lambda_qreg,
380
+ low_dim_state_len,
381
+ qattention_grad_clip,
382
+ ):
383
+ siamese_net = SiameseNet(
384
+ input_channels=[3, 3],
385
+ filters=[8],
386
+ kernel_sizes=[5],
387
+ strides=[1],
388
+ activation=activation,
389
+ norm=None,
390
+ )
391
+ qattention_net = Qattention2DNet(
392
+ siamese_net=siamese_net,
393
+ filters=[16, 16],
394
+ kernel_sizes=[5, 5],
395
+ strides=[2, 2],
396
+ output_channels=1,
397
+ norm=None,
398
+ activation=activation,
399
+ skip_connections=True,
400
+ low_dim_state_len=0,
401
+ )
402
+
403
+ qattention_agent = QAttentionAgent(
404
+ pixel_unet=qattention_net,
405
+ tau=qattention_tau,
406
+ camera_name=camera_name,
407
+ lr=qattention_lr,
408
+ weight_decay=qattention_weight_decay,
409
+ lambda_qreg=qattention_lambda_qreg,
410
+ include_low_dim_state=False,
411
+ grad_clip=qattention_grad_clip,
412
+ )
413
+
414
+ shared_net = SharedNet(activation, norm="layer")
415
+ critic_net = CriticNet(
416
+ activation, low_dim_state_len + 8, norm="layer", q_conf=q_conf
417
+ )
418
+ actor_net = ActorNet(activation, low_dim_state_len)
419
+
420
+ next_best_pose_agent = NextBestPoseAgent(
421
+ qattention_agent=qattention_agent,
422
+ shared_network=shared_net,
423
+ critic_network=critic_net,
424
+ actor_network=actor_net,
425
+ action_min_max=action_min_max,
426
+ camera_name=camera_name,
427
+ alpha=alpha,
428
+ alpha_lr=alpha_lr,
429
+ alpha_auto_tune=alpha_auto_tune,
430
+ critic_lr=critic_lr,
431
+ actor_lr=actor_lr,
432
+ critic_weight_decay=next_best_pose_critic_weight_decay,
433
+ actor_weight_decay=next_best_pose_actor_weight_decay,
434
+ crop_shape=crop_shape,
435
+ critic_tau=next_best_pose_tau,
436
+ critic_grad_clip=next_best_pose_critic_grad_clip,
437
+ actor_grad_clip=next_best_pose_actor_grad_clip,
438
+ q_conf=q_conf,
439
+ )
440
+
441
+ return PreprocessAgent(pose_agent=next_best_pose_agent)
external/peract_bimanual/agents/arm/next_best_pose_agent.py ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from yarr.agents.agent import (
11
+ Agent,
12
+ Summary,
13
+ ActResult,
14
+ ScalarSummary,
15
+ ImageSummary,
16
+ HistogramSummary,
17
+ )
18
+
19
+ from helpers import utils
20
+ from helpers.utils import stack_on_channel
21
+ from agents.arm.qattention_agent import QAttentionAgent
22
+
23
+ NAME = "NextBestPoseAgent"
24
+ LOG_STD_MAX = 4
25
+ LOG_STD_MIN = -40
26
+ REPLAY_ALPHA = 0.7
27
+ REPLAY_BETA = 0.5
28
+
29
+
30
+ class QFunction(nn.Module):
31
+ def __init__(self, critic: nn.Module, shared: nn.Module, q_conf: bool):
32
+ super(QFunction, self).__init__()
33
+ self._q_conf = q_conf
34
+ self._q1 = copy.deepcopy(critic)
35
+ self._q2 = copy.deepcopy(critic)
36
+ self.shared = copy.deepcopy(shared)
37
+ self._q1.build()
38
+ self._q2.build()
39
+ self.shared.build()
40
+
41
+ def forward(self, observations, robot_state, action):
42
+ obs_feats = self.shared(observations)
43
+ combined = torch.cat([robot_state, action.float()], dim=1)
44
+ q1 = self._q1(obs_feats, combined)
45
+ q2 = self._q2(obs_feats, combined)
46
+ if self._q_conf:
47
+ b = q1.shape[0]
48
+ q1 = q1.view(b, 2, -1)
49
+ q2 = q2.view(b, 2, -1)
50
+ q1v, q1c = q1[:, 0], q1[:, 1]
51
+ q1_best = q1v.gather(1, q1c.argmax(dim=1).unsqueeze(-1))
52
+ q2v, q2c = q2[:, 0], q2[:, 1]
53
+ q2_best = q2v.gather(1, q2c.argmax(dim=1).unsqueeze(-1))
54
+ return q1, q2, q1_best, q2_best
55
+ else:
56
+ q1, q2 = q1.unsqueeze(1), q2.unsqueeze(1)
57
+ return q1, q2, q1, q2
58
+
59
+
60
+ class Actor(nn.Module):
61
+ def __init__(self, actor_network: nn.Module, action_min_max: torch.tensor):
62
+ super(Actor, self).__init__()
63
+ self._action_min_max = action_min_max
64
+ self._actor_network = copy.deepcopy(actor_network)
65
+ self._actor_network.build()
66
+
67
+ def _rescale_actions(self, x):
68
+ return (
69
+ 0.5 * (x + 1.0) * (self._action_min_max[1] - self._action_min_max[0])
70
+ + self._action_min_max[0]
71
+ )
72
+
73
+ def _normalize(self, x):
74
+ return x / x.square().sum(dim=1).sqrt().unsqueeze(-1)
75
+
76
+ def _gaussian_logprob(self, noise, log_std):
77
+ residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True)
78
+ return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1)
79
+
80
+ def forward(self, observations, robot_state):
81
+ mu_and_logstd = self._actor_network(observations, robot_state)
82
+ mu, log_std = torch.split(mu_and_logstd, 8, dim=1)
83
+ log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
84
+
85
+ std = log_std.exp()
86
+ noise = torch.randn_like(mu)
87
+ pi = mu + noise * std
88
+ log_pi = self._gaussian_logprob(noise, log_std)
89
+ mu = torch.tanh(mu)
90
+ pi = torch.tanh(pi)
91
+ log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
92
+
93
+ pi = self._rescale_actions(pi)
94
+ mu = self._rescale_actions(mu)
95
+
96
+ pi = torch.cat([pi[:, :3], self._normalize(pi[:, 3:7]), pi[:, 7:]], dim=-1)
97
+ mu = torch.cat([mu[:, :3], self._normalize(mu[:, 3:7]), mu[:, 7:]], dim=-1)
98
+ return mu, pi, log_pi, log_std
99
+
100
+
101
+ class NextBestPoseAgent(Agent):
102
+ def __init__(
103
+ self,
104
+ qattention_agent: QAttentionAgent,
105
+ shared_network: nn.Module,
106
+ critic_network: nn.Module,
107
+ actor_network: nn.Module,
108
+ action_min_max: tuple,
109
+ camera_name: str,
110
+ alpha: float = 0.2,
111
+ alpha_auto_tune: bool = True,
112
+ alpha_lr: float = 0.001,
113
+ critic_lr: float = 0.01,
114
+ actor_lr: float = 0.01,
115
+ critic_weight_decay: float = 1e-5,
116
+ actor_weight_decay: float = 1e-5,
117
+ crop_shape: tuple = (16, 16),
118
+ critic_tau: float = 0.005,
119
+ critic_grad_clip: float = 20.0,
120
+ actor_grad_clip: float = 20.0,
121
+ gamma: float = 0.99,
122
+ nstep: int = 1,
123
+ q_conf: bool = True,
124
+ ):
125
+ self._qattention_agent = qattention_agent
126
+ self._alpha = alpha
127
+ self._alpha_auto_tune = alpha_auto_tune
128
+ self._crop_shape = crop_shape
129
+ self._critic_tau = critic_tau
130
+ self._critic_grad_clip = critic_grad_clip
131
+ self._actor_grad_clip = actor_grad_clip
132
+ self._camera_name = camera_name
133
+ self._gamma = gamma
134
+ self._nstep = nstep
135
+ self._target_entropy = -8
136
+ self._shared_network = shared_network
137
+ self._critic_network = critic_network
138
+ self._actor_network = actor_network
139
+ self._action_min_max = action_min_max
140
+ self._critic_lr = critic_lr
141
+ self._actor_lr = actor_lr
142
+ self._alpha_lr = alpha_lr
143
+ self._critic_weight_decay = critic_weight_decay
144
+ self._actor_weight_decay = actor_weight_decay
145
+ self._q_conf = q_conf
146
+ self._crop_augmentation = False
147
+
148
+ def build(self, training: bool, device: torch.device = None):
149
+ if device is None:
150
+ device = torch.device("cpu")
151
+ self._qattention_agent.build(training, device)
152
+ action_min_max = torch.tensor(self._action_min_max).to(device)
153
+ self._actor = (
154
+ Actor(self._actor_network, action_min_max).to(device).train(training)
155
+ )
156
+
157
+ self._action_min_max_t = torch.tensor(self._action_min_max).to(device)
158
+
159
+ grid_for_crop = (
160
+ torch.arange(0, self._crop_shape[0], device=device)
161
+ .unsqueeze(0)
162
+ .repeat(self._crop_shape[0], 1)
163
+ .unsqueeze(-1)
164
+ )
165
+ self._grid_for_crop = torch.cat(
166
+ [grid_for_crop.transpose(1, 0), grid_for_crop], dim=2
167
+ ).unsqueeze(0)
168
+ self._q = (
169
+ QFunction(self._critic_network, self._shared_network, self._q_conf)
170
+ .to(device)
171
+ .train(training)
172
+ )
173
+ if training:
174
+ self._q_target = (
175
+ QFunction(self._critic_network, self._shared_network, self._q_conf)
176
+ .to(device)
177
+ .train(False)
178
+ )
179
+ utils.soft_updates(self._q, self._q_target, 1.0)
180
+
181
+ self._crop_shape_t = torch.tensor(
182
+ [list(self._crop_shape)], dtype=torch.int32, device=device
183
+ )
184
+
185
+ # Freeze target critic.
186
+ for p in self._q_target.parameters():
187
+ p.requires_grad = False
188
+
189
+ self._log_alpha = 0
190
+ if self._alpha_auto_tune:
191
+ self._log_alpha = torch.tensor(
192
+ (np.log(self._alpha)),
193
+ dtype=torch.float,
194
+ requires_grad=True,
195
+ device=device,
196
+ )
197
+ if training:
198
+ self._alpha_optimizer = torch.optim.Adam(
199
+ [self._log_alpha], lr=self._alpha_lr
200
+ )
201
+
202
+ self._critic_optimizer = torch.optim.Adam(
203
+ self._q.parameters(),
204
+ lr=self._critic_lr,
205
+ weight_decay=self._critic_weight_decay,
206
+ )
207
+ self._actor_optimizer = torch.optim.Adam(
208
+ self._actor.parameters(),
209
+ lr=self._actor_lr,
210
+ weight_decay=self._actor_weight_decay,
211
+ )
212
+
213
+ logging.info(
214
+ "# NBP Critic Params: %d"
215
+ % sum(p.numel() for p in self._q.parameters() if p.requires_grad)
216
+ )
217
+ logging.info(
218
+ "# NBP Actor Params: %d"
219
+ % sum(p.numel() for p in self._actor.parameters() if p.requires_grad)
220
+ )
221
+ else:
222
+ for p in self._actor.parameters():
223
+ p.requires_grad = False
224
+
225
+ self._device = device
226
+
227
+ @property
228
+ def alpha(self):
229
+ return self._log_alpha.exp() if self._alpha_auto_tune else self._alpha
230
+
231
+ def _extract_crop(self, pixel_action, observation):
232
+ # Pixel action will now be (B, 2)
233
+ observation = stack_on_channel(observation)
234
+ h = observation.shape[-1]
235
+ top_left_corner = torch.clamp(
236
+ pixel_action - self._crop_shape[0] // 2, 0, h - self._crop_shape[1]
237
+ )
238
+ grid = self._grid_for_crop + top_left_corner.unsqueeze(1).unsqueeze(1)
239
+ grid = ((grid / float(h)) * 2.0) - 1.0
240
+ grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1)
241
+ crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True)
242
+ return crop
243
+
244
+ def _preprocess_inputs(self, replay_sample, pixel_action, pixel_action_tp1):
245
+ observations = [
246
+ self._extract_crop(
247
+ pixel_action, replay_sample["%s_rgb" % self._camera_name]
248
+ ),
249
+ self._extract_crop(
250
+ pixel_action, replay_sample["%s_point_cloud" % self._camera_name]
251
+ ),
252
+ ]
253
+ tp1_observations = [
254
+ self._extract_crop(
255
+ pixel_action_tp1, replay_sample["%s_rgb_tp1" % self._camera_name]
256
+ ),
257
+ self._extract_crop(
258
+ pixel_action_tp1,
259
+ replay_sample["%s_point_cloud_tp1" % self._camera_name],
260
+ ),
261
+ ]
262
+ return observations, tp1_observations
263
+
264
+ def _clip_action(self, a):
265
+ return torch.min(
266
+ torch.max(a, self._action_min_max_t[0:1]), self._action_min_max_t[1:2]
267
+ )
268
+
269
+ def _update_critic(self, replay_sample: dict) -> None:
270
+ action = replay_sample["action"]
271
+ reward = replay_sample["reward"]
272
+
273
+ robot_state = stack_on_channel(replay_sample["low_dim_state"][:, -1:])
274
+ robot_state_tp1 = stack_on_channel(replay_sample["low_dim_state_tp1"][:, -1:])
275
+
276
+ # Get last of time stack and first of plan stack
277
+ pixel_action = replay_sample["%s_pixel_coord" % self._camera_name][:, -1]
278
+ pixel_action_tp1 = replay_sample["%s_pixel_coord_tp1" % self._camera_name][
279
+ :, -1
280
+ ]
281
+
282
+ if self._crop_augmentation:
283
+ shifted = (
284
+ (torch.rand_like(pixel_action.float()) * self._crop_shape_t).int()
285
+ - self._crop_shape_t // 2
286
+ ) * replay_sample["demo"].int().unsqueeze(1)
287
+ pixel_action += shifted
288
+ pixel_action_tp1 += shifted
289
+
290
+ # Don't want timeouts to be classed as terminals
291
+ terminal = replay_sample["terminal"].float() - replay_sample["timeout"].float()
292
+
293
+ observations, tp1_observations = self._preprocess_inputs(
294
+ replay_sample, pixel_action, pixel_action_tp1
295
+ )
296
+
297
+ q1, q2, _, _ = self._q(observations, robot_state, action)
298
+
299
+ with torch.no_grad():
300
+ obs_feats = self._q.shared(tp1_observations)
301
+ _, pi_tp1, logp_pi_tp1, _ = self._actor(obs_feats, robot_state_tp1)
302
+
303
+ q1_pi_tp1_targ, q2_pi_tp1_targ, _, _ = self._q_target(
304
+ tp1_observations, robot_state_tp1, pi_tp1
305
+ )
306
+
307
+ min_q_pi_targ = torch.min(q1_pi_tp1_targ[:, 0], q2_pi_tp1_targ[:, 0])
308
+ next_value = min_q_pi_targ - self.alpha * logp_pi_tp1
309
+ q_backup = (
310
+ reward.unsqueeze(-1)
311
+ + (self._gamma**self._nstep)
312
+ * (1.0 - terminal.unsqueeze(-1))
313
+ * next_value
314
+ )
315
+
316
+ loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA)
317
+
318
+ self._critic_summaries = {}
319
+ if self._q_conf:
320
+ w = 1.0
321
+ q1_delta = (
322
+ F.smooth_l1_loss(q1[:, 0], q_backup, reduction="none") * q1[:, 1]
323
+ - w * q1[:, 1].log()
324
+ )
325
+ q2_delta = (
326
+ F.smooth_l1_loss(q2[:, 0], q_backup, reduction="none") * q2[:, 1]
327
+ - w * q2[:, 1].log()
328
+ )
329
+ self._critic_summaries = {
330
+ "q_conf_loss": -(w * q1[:, 1].log()).mean(),
331
+ "q_conf_mean": q1[:, 1].mean(),
332
+ }
333
+ else:
334
+ q1_delta = F.smooth_l1_loss(q1[:, 0], q_backup, reduction="none")
335
+ q2_delta = F.smooth_l1_loss(q2[:, 0], q_backup, reduction="none")
336
+
337
+ q1_delta, q2_delta = q1_delta.mean(1), q2_delta.mean(1)
338
+ q1_bellman_loss = (q1_delta * loss_weights).mean()
339
+ q2_bellman_loss = (q2_delta * loss_weights).mean()
340
+
341
+ critic_loss = q1_bellman_loss + q2_bellman_loss
342
+
343
+ self._critic_summaries.update(
344
+ {
345
+ "q1_bellman_loss": q1_bellman_loss,
346
+ "q2_bellman_loss": q2_bellman_loss,
347
+ "q1_mean": q1[:, 0].mean().item(),
348
+ "q2_mean": q2[:, 0].mean().item(),
349
+ "alpha": self.alpha,
350
+ }
351
+ )
352
+ self._crop_summary = observations
353
+ self._crop_summary_tp1 = tp1_observations
354
+
355
+ new_pri = torch.sqrt((q1_delta + q2_delta) / 2.0 + 1e-10)
356
+ self._new_priority = (new_pri / torch.max(new_pri)).detach()
357
+ self._grad_step(
358
+ critic_loss,
359
+ self._critic_optimizer,
360
+ self._q.parameters(),
361
+ self._critic_grad_clip,
362
+ )
363
+
364
+ def _update_actor(self, replay_sample: dict) -> None:
365
+ robot_state = stack_on_channel(replay_sample["low_dim_state"][:, -1:])
366
+ pixel_action = replay_sample["%s_pixel_coord" % self._camera_name][:, -1]
367
+
368
+ if self._crop_augmentation:
369
+ shifted = (
370
+ (torch.rand_like(pixel_action.float()) * self._crop_shape_t).int()
371
+ - self._crop_shape_t // 2
372
+ ) * replay_sample["demo"].int().unsqueeze(1)
373
+ pixel_action += shifted
374
+
375
+ # Crop the observations
376
+ observations = [
377
+ self._extract_crop(
378
+ pixel_action, replay_sample["%s_rgb" % self._camera_name]
379
+ ),
380
+ self._extract_crop(
381
+ pixel_action, replay_sample["%s_point_cloud" % self._camera_name]
382
+ ),
383
+ ]
384
+
385
+ with torch.no_grad():
386
+ obs_feats = self._q.shared(observations)
387
+
388
+ mu, pi, self._logp_pi, log_scale_diag = self._actor(obs_feats, robot_state)
389
+
390
+ _, _, q1_pi, q2_pi = self._q(observations, robot_state, pi)
391
+
392
+ min_q_pi = torch.min(q1_pi, q2_pi)[:, 0]
393
+ pi_loss = self.alpha * self._logp_pi - min_q_pi
394
+
395
+ loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA)
396
+ pi_loss = (pi_loss * loss_weights).mean()
397
+
398
+ self._actor_summaries = {
399
+ "pi/loss": pi_loss,
400
+ "pi/q1_pi_mean": q1_pi.mean(),
401
+ "pi/q2_pi_mean": q2_pi.mean(),
402
+ "pi/mu": mu.mean(),
403
+ "pi/pi": pi.mean(),
404
+ "pi/log_pi": self._logp_pi.mean(),
405
+ "pi/log_scale_diag": log_scale_diag.mean(),
406
+ }
407
+ self._grad_step(
408
+ pi_loss,
409
+ self._actor_optimizer,
410
+ self._actor.parameters(),
411
+ self._actor_grad_clip,
412
+ )
413
+
414
+ def _update_alpha(self):
415
+ alpha_loss = -(
416
+ self.alpha * (self._logp_pi + self._target_entropy).detach()
417
+ ).mean()
418
+ self._grad_step(alpha_loss, self._alpha_optimizer)
419
+
420
+ def _grad_step(self, loss, opt, model_params=None, clip=None):
421
+ opt.zero_grad()
422
+ loss.backward()
423
+ if clip is not None and model_params is not None:
424
+ nn.utils.clip_grad_value_(model_params, clip)
425
+ opt.step()
426
+
427
+ def update(self, step: int, replay_sample: dict) -> dict:
428
+ info = self._qattention_agent.update(step, replay_sample)
429
+
430
+ self._update_critic(replay_sample)
431
+
432
+ # Freeze critic so you don't waste computational effort
433
+ # computing gradients for them during the policy learning step.
434
+ for p in self._q.parameters():
435
+ p.requires_grad = False
436
+
437
+ self._update_actor(replay_sample)
438
+ if self._alpha_auto_tune:
439
+ self._update_alpha()
440
+
441
+ # UnFreeze critic.
442
+ for p in self._q.parameters():
443
+ p.requires_grad = True
444
+
445
+ utils.soft_updates(self._q, self._q_target, self._critic_tau)
446
+ pixel_agent_priority = info["priority"]
447
+ return {
448
+ "priority": ((self._new_priority + pixel_agent_priority) / 2.0)
449
+ ** REPLAY_ALPHA
450
+ }
451
+
452
+ def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
453
+ with torch.no_grad():
454
+ act_res = self._qattention_agent.act(step, observation, deterministic)
455
+ observations = [
456
+ self._extract_crop(
457
+ act_res.action.unsqueeze(0),
458
+ observation["%s_rgb" % self._camera_name],
459
+ ),
460
+ self._extract_crop(
461
+ act_res.action.unsqueeze(0),
462
+ observation["%s_point_cloud" % self._camera_name],
463
+ ),
464
+ ]
465
+ self._act_crop_summaries = observations
466
+ robot_state = stack_on_channel(observation["low_dim_state"][:, -1:])
467
+ obs_feats = self._q.shared(observations)
468
+ mu, pi, _, _ = self._actor(obs_feats, robot_state)
469
+ act_res.action = (mu if deterministic else pi)[0]
470
+ act_res.info.update({"rgb_crop": observations[0]})
471
+ return act_res
472
+
473
+ def update_summaries(self) -> List[Summary]:
474
+ summaries = [
475
+ ImageSummary("%s/crops/rgb" % NAME, (self._crop_summary[0] + 1.0) / 2.0),
476
+ ImageSummary("%s/crops/point_cloud" % NAME, self._crop_summary[1]),
477
+ ImageSummary(
478
+ "%s/crops_tp1/rgb" % NAME, (self._crop_summary_tp1[0] + 1.0) / 2.0
479
+ ),
480
+ ImageSummary("%s/crops_tp1/point_cloud" % NAME, self._crop_summary_tp1[1]),
481
+ ]
482
+
483
+ for n, v in list(self._critic_summaries.items()) + list(
484
+ self._actor_summaries.items()
485
+ ):
486
+ summaries.append(ScalarSummary("%s/%s" % (NAME, n), v))
487
+
488
+ for tag, param in list(self._q.named_parameters()) + list(
489
+ self._actor.named_parameters()
490
+ ):
491
+ summaries.append(
492
+ HistogramSummary("%s/gradient/%s" % (NAME, tag), param.grad)
493
+ )
494
+ summaries.append(HistogramSummary("%s/weight/%s" % (NAME, tag), param.data))
495
+
496
+ pixel_summaries = self._qattention_agent.update_summaries()
497
+ return pixel_summaries + summaries
498
+
499
+ def act_summaries(self) -> List[Summary]:
500
+ summaries = [
501
+ ImageSummary(
502
+ "%s/crops/act/rgb" % NAME, (self._act_crop_summaries[0] + 1.0) / 2.0
503
+ ),
504
+ ImageSummary(
505
+ "%s/crops/act/point_cloud" % NAME, self._act_crop_summaries[1]
506
+ ),
507
+ ]
508
+ return summaries + self._qattention_agent.act_summaries()
509
+
510
+ def load_weights(self, savedir: str):
511
+ self._qattention_agent.load_weights(savedir)
512
+ self._actor.load_state_dict(
513
+ torch.load(
514
+ os.path.join(savedir, "pose_actor.pt"), map_location=torch.device("cpu")
515
+ )
516
+ )
517
+ self._q.load_state_dict(
518
+ torch.load(
519
+ os.path.join(savedir, "pose_q.pt"), map_location=torch.device("cpu")
520
+ )
521
+ )
522
+
523
+ def save_weights(self, savedir: str):
524
+ self._qattention_agent.save_weights(savedir)
525
+ torch.save(self._actor.state_dict(), os.path.join(savedir, "pose_actor.pt"))
526
+ torch.save(self._q.state_dict(), os.path.join(savedir, "pose_q.pt"))
external/peract_bimanual/agents/arm/qattention_agent.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ from typing import List
5
+
6
+ import PIL
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torchvision import transforms
11
+
12
+ from yarr.agents.agent import (
13
+ Agent,
14
+ ActResult,
15
+ ScalarSummary,
16
+ HistogramSummary,
17
+ ImageSummary,
18
+ Summary,
19
+ )
20
+
21
+ from helpers import utils
22
+ from helpers.utils import stack_on_channel
23
+
24
+ NAME = "QAttentionAgent"
25
+ REPLAY_BETA = 1.0
26
+
27
+
28
+ class QFunction(nn.Module):
29
+ def __init__(self, unet: nn.Module):
30
+ super(QFunction, self).__init__()
31
+ self._qnet = copy.deepcopy(unet)
32
+ self._qnet2 = copy.deepcopy(unet)
33
+ self._qnet.build()
34
+ self._qnet2.build()
35
+
36
+ def _argmax_2d(self, tensor):
37
+ t_shape = tensor.shape
38
+ m = tensor.view(t_shape[0], -1).argmax(1).view(-1, 1)
39
+ indices = torch.cat((m // t_shape[-1], m % t_shape[-1]), dim=1)
40
+ return indices
41
+
42
+ def forward(self, x, robot_state):
43
+ q = self._qnet(x, robot_state)[:, 0]
44
+ q2 = self._qnet2(x, robot_state)[:, 0]
45
+ coords = self._argmax_2d(torch.min(q, q2))
46
+ return q, q2, coords
47
+
48
+
49
+ class QAttentionAgent(Agent):
50
+ def __init__(
51
+ self,
52
+ pixel_unet: nn.Module,
53
+ camera_name: str,
54
+ tau: float = 0.005,
55
+ gamma: float = 0.99,
56
+ nstep: int = 1,
57
+ lr: float = 0.0001,
58
+ weight_decay: float = 1e-5,
59
+ lambda_qreg: float = 1e-6,
60
+ grad_clip: float = 20.0,
61
+ include_low_dim_state: bool = False,
62
+ ):
63
+ self._pixel_unet = pixel_unet
64
+ self._camera_name = camera_name
65
+ self._tau = tau
66
+ self._gamma = gamma
67
+ self._nstep = nstep
68
+ self._lr = lr
69
+ self._weight_decay = weight_decay
70
+ self._lambda_qreg = lambda_qreg
71
+ self._grad_clip = grad_clip
72
+ self._include_low_dim_state = include_low_dim_state
73
+
74
+ def build(self, training: bool, device: torch.device = None):
75
+ if device is None:
76
+ device = torch.device("cpu")
77
+ self._q = QFunction(self._pixel_unet).to(device).train(training)
78
+ self._q_target = None
79
+ if training:
80
+ self._q_target = QFunction(self._pixel_unet).to(device).train(False)
81
+ for p in self._q_target.parameters():
82
+ p.requires_grad = False
83
+ utils.soft_updates(self._q, self._q_target, 1.0)
84
+ self._optimizer = torch.optim.Adam(
85
+ self._q.parameters(), lr=self._lr, weight_decay=self._weight_decay
86
+ )
87
+ logging.info(
88
+ "# Q-attention Params: %d"
89
+ % sum(p.numel() for p in self._q.parameters() if p.requires_grad)
90
+ )
91
+ else:
92
+ for p in self._q.parameters():
93
+ p.requires_grad = False
94
+ self._device = device
95
+
96
+ def _get_q_from_pixel_coord(self, q, coord):
97
+ b, h, w = q.shape
98
+ flat_indicies = (coord[:, 0] * w + coord[:, 1])[:, None].long()
99
+ return q.view(b, h * w).gather(1, flat_indicies)
100
+
101
+ def _preprocess_inputs(self, replay_sample):
102
+ observations = [
103
+ stack_on_channel(replay_sample["%s_rgb" % self._camera_name]),
104
+ stack_on_channel(replay_sample["%s_point_cloud" % self._camera_name]),
105
+ ]
106
+ tp1_observations = [
107
+ stack_on_channel(replay_sample["%s_rgb_tp1" % self._camera_name]),
108
+ stack_on_channel(replay_sample["%s_point_cloud_tp1" % self._camera_name]),
109
+ ]
110
+ return observations, tp1_observations
111
+
112
+ def update(self, step: int, replay_sample: dict) -> dict:
113
+ pixel_action = replay_sample["%s_pixel_coord" % self._camera_name][:, -1].int()
114
+ reward = replay_sample["reward"]
115
+ reward = torch.where(reward > 0, reward, torch.zeros_like(reward))
116
+
117
+ robot_state = robot_state_tp1 = None
118
+ if self._include_low_dim_state:
119
+ robot_state = stack_on_channel(replay_sample["low_dim_state"])
120
+ robot_state_tp1 = stack_on_channel(replay_sample["low_dim_state_tp1"])
121
+
122
+ # Don't want timeouts to be classed as terminals
123
+ terminal = replay_sample["terminal"].float() - replay_sample["timeout"].float()
124
+
125
+ obs, obs_tp1 = self._preprocess_inputs(replay_sample)
126
+ q, q2, coords = self._q(obs, robot_state)
127
+
128
+ with torch.no_grad():
129
+ # (B, h, w)
130
+ _, _, coords_tp1 = self._q(obs_tp1, robot_state_tp1)
131
+ q_tp1_targ, q2_tp1_targ, _ = self._q_target(obs_tp1, robot_state_tp1)
132
+ q_tp1_targ = torch.min(q_tp1_targ, q2_tp1_targ)
133
+ q_tp1_targ = self._get_q_from_pixel_coord(q_tp1_targ, coords_tp1)
134
+ target = (
135
+ reward.unsqueeze(1)
136
+ + (self._gamma**self._nstep)
137
+ * (1 - terminal.unsqueeze(1))
138
+ * q_tp1_targ
139
+ )
140
+ target = torch.clamp(target, 0.0, 100.0)
141
+
142
+ q_pred = self._get_q_from_pixel_coord(q, pixel_action)
143
+ delta = F.smooth_l1_loss(q_pred, target, reduction="none").mean(1)
144
+
145
+ delta += F.smooth_l1_loss(
146
+ self._get_q_from_pixel_coord(q2, pixel_action), target, reduction="none"
147
+ ).mean(1)
148
+ q_reg = (
149
+ (0.5 * torch.sum(q**2)) + (0.5 * torch.sum(q2**2))
150
+ ) * self._lambda_qreg
151
+
152
+ loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA)
153
+ total_loss = ((delta) * loss_weights).mean() + q_reg
154
+ new_priority = ((delta) + 1e-10).sqrt()
155
+ new_priority /= new_priority.max()
156
+
157
+ self._summaries = {
158
+ "losses/bellman": delta.mean(),
159
+ "losses/qreg": q_reg.mean(),
160
+ "q/mean": q.mean(),
161
+ "q/action_q": q_pred.mean(),
162
+ }
163
+ self._qvalues = q[:1]
164
+ self._rgb_observation = replay_sample["front_rgb"][0, -1]
165
+ self._optimizer.zero_grad()
166
+ total_loss.backward()
167
+ if self._grad_clip is not None:
168
+ nn.utils.clip_grad_value_(self._q.parameters(), self._grad_clip)
169
+ self._optimizer.step()
170
+ utils.soft_updates(self._q, self._q_target, self._tau)
171
+
172
+ return {
173
+ "priority": new_priority,
174
+ }
175
+
176
+ def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
177
+ with torch.no_grad():
178
+ observations = [
179
+ stack_on_channel(observation["%s_rgb" % self._camera_name]),
180
+ stack_on_channel(observation["%s_point_cloud" % self._camera_name]),
181
+ ]
182
+ robot_state = None
183
+ if self._include_low_dim_state:
184
+ robot_state = stack_on_channel(observation["low_dim_state"])
185
+ # Coords are stored as (y, x)
186
+ q, q2, coords = self._q(observations, robot_state)
187
+ self._act_qvalues = torch.min(q, q2)[:1]
188
+ self._rgb_observation = observation["front_rgb"][0, -1]
189
+ return ActResult(
190
+ coords[0],
191
+ observation_elements={
192
+ "%s_pixel_coord" % self._camera_name: coords[0],
193
+ },
194
+ info={"q_values": self._act_qvalues},
195
+ )
196
+
197
+ @staticmethod
198
+ def generate_heatmap(q_values, rgb_obs):
199
+ norm_q = torch.clamp(q_values / 100.0, 0, 1)
200
+ heatmap = torch.cat(
201
+ [norm_q, torch.zeros_like(norm_q), torch.zeros_like(norm_q)]
202
+ )
203
+ img = transforms.functional.to_pil_image(rgb_obs)
204
+ h_img = transforms.functional.to_pil_image(heatmap).convert("RGB")
205
+ ret = PIL.Image.blend(img, h_img, 0.75)
206
+ return transforms.ToTensor()(ret).unsqueeze_(0)
207
+
208
+ def update_summaries(self) -> List[Summary]:
209
+ summaries = [
210
+ ImageSummary(
211
+ "%s/Q" % NAME,
212
+ QAttentionAgent.generate_heatmap(
213
+ self._qvalues.cpu(), ((self._rgb_observation + 1) / 2.0).cpu()
214
+ ),
215
+ )
216
+ ]
217
+ for n, v in self._summaries.items():
218
+ summaries.append(ScalarSummary("%s/%s" % (NAME, n), v))
219
+
220
+ for tag, param in self._q.named_parameters():
221
+ assert not torch.isnan(param.grad.abs() <= 1.0).all()
222
+ summaries.append(
223
+ HistogramSummary("%s/gradient/%s" % (NAME, tag), param.grad)
224
+ )
225
+ summaries.append(HistogramSummary("%s/weight/%s" % (NAME, tag), param.data))
226
+ return summaries
227
+
228
+ def act_summaries(self) -> List[Summary]:
229
+ return [
230
+ ImageSummary(
231
+ "%s/Q_act" % NAME,
232
+ QAttentionAgent.generate_heatmap(
233
+ self._act_qvalues.cpu(), ((self._rgb_observation + 1) / 2.0).cpu()
234
+ ),
235
+ )
236
+ ]
237
+
238
+ def load_weights(self, savedir: str):
239
+ self._q.load_state_dict(
240
+ torch.load(
241
+ os.path.join(savedir, "pixel_agent_q.pt"),
242
+ map_location=torch.device("cpu"),
243
+ )
244
+ )
245
+
246
+ def save_weights(self, savedir: str):
247
+ torch.save(self._q.state_dict(), os.path.join(savedir, "pixel_agent_q.pt"))
external/peract_bimanual/agents/baselines/__init__.py ADDED
File without changes
external/peract_bimanual/agents/baselines/bc_lang/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ import agents.baselines.bc_lang.launch_utils
external/peract_bimanual/agents/baselines/bc_lang/bc_lang_agent.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ from typing import List
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from yarr.agents.agent import Agent, Summary, ActResult, ScalarSummary, HistogramSummary
10
+
11
+ from helpers import utils
12
+ from helpers.utils import stack_on_channel
13
+
14
+ from helpers.clip.core.clip import build_model, load_clip
15
+
16
+ NAME = "BCLangAgent"
17
+ REPLAY_ALPHA = 0.7
18
+ REPLAY_BETA = 1.0
19
+
20
+
21
+ class Actor(nn.Module):
22
+ def __init__(self, actor_network: nn.Module):
23
+ super(Actor, self).__init__()
24
+ self._actor_network = copy.deepcopy(actor_network)
25
+ self._actor_network.build()
26
+
27
+ def forward(self, observations, robot_state, lang_goal_emb):
28
+ mu = self._actor_network(observations, robot_state, lang_goal_emb)
29
+ return mu
30
+
31
+
32
+ class BCLangAgent(Agent):
33
+ def __init__(
34
+ self,
35
+ actor_network: nn.Module,
36
+ camera_name: str,
37
+ lr: float = 0.01,
38
+ weight_decay: float = 1e-5,
39
+ grad_clip: float = 20.0,
40
+ ):
41
+ self._camera_name = camera_name
42
+ self._actor_network = actor_network
43
+ self._lr = lr
44
+ self._weight_decay = weight_decay
45
+ self._grad_clip = grad_clip
46
+
47
+ def build(self, training: bool, device: torch.device = None):
48
+ if device is None:
49
+ device = torch.device("cpu")
50
+ self._actor = Actor(self._actor_network).to(device).train(training)
51
+ if training:
52
+ self._actor_optimizer = torch.optim.Adam(
53
+ self._actor.parameters(), lr=self._lr, weight_decay=self._weight_decay
54
+ )
55
+ logging.info(
56
+ "# Actor Params: %d"
57
+ % sum(p.numel() for p in self._actor.parameters() if p.requires_grad)
58
+ )
59
+ else:
60
+ for p in self._actor.parameters():
61
+ p.requires_grad = False
62
+
63
+ model, _ = load_clip("RN50", jit=False)
64
+ self._clip_rn50 = build_model(model.state_dict())
65
+ self._clip_rn50 = self._clip_rn50.float().to(device)
66
+ self._clip_rn50.eval()
67
+ del model
68
+
69
+ self._device = device
70
+
71
+ def _grad_step(self, loss, opt, model_params=None, clip=None):
72
+ opt.zero_grad()
73
+ loss.backward()
74
+ if clip is not None and model_params is not None:
75
+ nn.utils.clip_grad_value_(model_params, clip)
76
+ opt.step()
77
+
78
+ def update(self, step: int, replay_sample: dict) -> dict:
79
+ lang_goal_emb = replay_sample["lang_goal_emb"]
80
+ robot_state = replay_sample["low_dim_state"]
81
+ observations = [
82
+ replay_sample["%s_rgb" % self._camera_name],
83
+ replay_sample["%s_point_cloud" % self._camera_name],
84
+ ]
85
+ mu = self._actor(observations, robot_state, lang_goal_emb)
86
+ loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA)
87
+ delta = F.mse_loss(mu, replay_sample["action"], reduction="none").mean(1)
88
+ loss = (delta * loss_weights).mean()
89
+ self._grad_step(
90
+ loss, self._actor_optimizer, self._actor.parameters(), self._grad_clip
91
+ )
92
+ self._summaries = {
93
+ "pi/loss": loss,
94
+ "pi/mu": mu.mean(),
95
+ }
96
+ return {"total_losses": loss}
97
+
98
+ def _normalize_quat(self, x):
99
+ return x / x.square().sum(dim=1).sqrt().unsqueeze(-1)
100
+
101
+ def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
102
+ lang_goal_tokens = observation.get("lang_goal_tokens", None).long()
103
+
104
+ with torch.no_grad():
105
+ lang_goal_tokens = lang_goal_tokens.to(device=self._device)
106
+ lang_goal_emb, _ = self._clip_rn50.encode_text_with_embeddings(
107
+ lang_goal_tokens[0]
108
+ )
109
+ lang_goal_emb = lang_goal_emb.to(device=self._device)
110
+
111
+ observations = [
112
+ observation["%s_rgb" % self._camera_name][0].to(self._device),
113
+ observation["%s_point_cloud" % self._camera_name][0].to(self._device),
114
+ ]
115
+ robot_state = observation["low_dim_state"][0].to(self._device)
116
+
117
+ mu = self._actor(observations, robot_state, lang_goal_emb)
118
+ mu = torch.cat([mu[:, :3], self._normalize_quat(mu[:, 3:7]), mu[:, 7:]], dim=-1)
119
+ ignore_collisions = torch.Tensor([1.0]).to(mu.device)
120
+ mu0 = torch.cat([mu[0], ignore_collisions])
121
+ return ActResult(mu0.detach().cpu())
122
+
123
+ def update_summaries(self) -> List[Summary]:
124
+ summaries = []
125
+ for n, v in self._summaries.items():
126
+ summaries.append(ScalarSummary("%s/%s" % (NAME, n), v))
127
+
128
+ for tag, param in self._actor.named_parameters():
129
+ summaries.append(
130
+ HistogramSummary("%s/gradient/%s" % (NAME, tag), param.grad)
131
+ )
132
+ summaries.append(HistogramSummary("%s/weight/%s" % (NAME, tag), param.data))
133
+
134
+ return summaries
135
+
136
+ def act_summaries(self) -> List[Summary]:
137
+ return []
138
+
139
+ def load_weights(self, savedir: str):
140
+ self._actor.load_state_dict(
141
+ torch.load(
142
+ os.path.join(savedir, "bc_actor.pt"), map_location=torch.device("cpu")
143
+ )
144
+ )
145
+ print("Loaded weights from %s" % savedir)
146
+
147
+ def save_weights(self, savedir: str):
148
+ torch.save(self._actor.state_dict(), os.path.join(savedir, "bc_actor.pt"))
external/peract_bimanual/agents/baselines/bc_lang/launch_utils.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from ARM
2
+ # Source: https://github.com/stepjam/ARM
3
+ # License: https://github.com/stepjam/ARM/LICENSE
4
+
5
+ import logging
6
+ from typing import List
7
+
8
+ import numpy as np
9
+ from omegaconf import DictConfig
10
+ from rlbench.backend.observation import Observation
11
+ from rlbench.observation_config import ObservationConfig
12
+ import rlbench.utils as rlbench_utils
13
+ from rlbench.demo import Demo
14
+ from yarr.replay_buffer.prioritized_replay_buffer import (
15
+ PrioritizedReplayBuffer,
16
+ ObservationElement,
17
+ )
18
+ from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer
19
+ from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
20
+ from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer
21
+
22
+ from helpers import demo_loading_utils, utils
23
+ from helpers import observation_utils
24
+ from agents.baselines.bc_lang.bc_lang_agent import BCLangAgent
25
+ from helpers.custom_rlbench_env import CustomRLBenchEnv
26
+ from helpers.network_utils import SiameseNet, CNNLangAndFcsNet
27
+ from helpers.preprocess_agent import PreprocessAgent
28
+
29
+ import torch
30
+ from torch.multiprocessing import Process, Value, Manager
31
+ from helpers.clip.core.clip import build_model, load_clip, tokenize
32
+
33
+ LOW_DIM_SIZE = 4
34
+
35
+
36
+ def create_replay(
37
+ batch_size: int,
38
+ timesteps: int,
39
+ prioritisation: bool,
40
+ task_uniform: bool,
41
+ save_dir: str,
42
+ cameras: list,
43
+ image_size=[128, 128],
44
+ replay_size=3e5,
45
+ ):
46
+ lang_feat_dim = 1024
47
+
48
+ # low_dim_state
49
+ observation_elements = []
50
+ observation_elements.append(
51
+ ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
52
+ )
53
+
54
+ # rgb, depth, point cloud, intrinsics, extrinsics
55
+ for cname in cameras:
56
+ observation_elements.append(
57
+ ObservationElement(
58
+ "%s_rgb" % cname,
59
+ (
60
+ 3,
61
+ *image_size,
62
+ ),
63
+ np.float32,
64
+ )
65
+ )
66
+ observation_elements.append(
67
+ ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32)
68
+ ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
69
+ observation_elements.append(
70
+ ObservationElement(
71
+ "%s_camera_extrinsics" % cname,
72
+ (
73
+ 4,
74
+ 4,
75
+ ),
76
+ np.float32,
77
+ )
78
+ )
79
+ observation_elements.append(
80
+ ObservationElement(
81
+ "%s_camera_intrinsics" % cname,
82
+ (
83
+ 3,
84
+ 3,
85
+ ),
86
+ np.float32,
87
+ )
88
+ )
89
+
90
+ observation_elements.extend(
91
+ [
92
+ ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
93
+ ReplayElement("task", (), str),
94
+ ReplayElement(
95
+ "lang_goal", (1,), object
96
+ ), # language goal string for debugging and visualization
97
+ ]
98
+ )
99
+
100
+ extra_replay_elements = [
101
+ ReplayElement("demo", (), np.bool),
102
+ ]
103
+
104
+ replay_buffer = TaskUniformReplayBuffer(
105
+ save_dir=save_dir,
106
+ batch_size=batch_size,
107
+ timesteps=timesteps,
108
+ replay_capacity=int(replay_size),
109
+ action_shape=(8,),
110
+ action_dtype=np.float32,
111
+ reward_shape=(),
112
+ reward_dtype=np.float32,
113
+ update_horizon=1,
114
+ observation_elements=observation_elements,
115
+ extra_replay_elements=extra_replay_elements,
116
+ )
117
+ return replay_buffer
118
+
119
+
120
+ def _get_action(obs_tp1: Observation):
121
+ quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
122
+ if quat[-1] < 0:
123
+ quat = -quat
124
+ return np.concatenate(
125
+ [obs_tp1.gripper_pose[:3], quat, [float(obs_tp1.gripper_open)]]
126
+ )
127
+
128
+
129
+ def _add_keypoints_to_replay(
130
+ cfg: DictConfig,
131
+ task: str,
132
+ replay: ReplayBuffer,
133
+ inital_obs: Observation,
134
+ demo: Demo,
135
+ episode_keypoints: List[int],
136
+ cameras: List[str],
137
+ description: str = "",
138
+ clip_model=None,
139
+ device="cpu",
140
+ ):
141
+ prev_action = None
142
+ obs = inital_obs
143
+ all_actions = []
144
+ for k, keypoint in enumerate(episode_keypoints):
145
+ obs_tp1 = demo[keypoint]
146
+ action = _get_action(obs_tp1)
147
+ all_actions.append(action)
148
+ terminal = k == len(episode_keypoints) - 1
149
+ reward = float(terminal) if terminal else 0
150
+
151
+ obs_dict = observation_utils.extract_obs(
152
+ obs,
153
+ t=k,
154
+ prev_action=prev_action,
155
+ cameras=cameras,
156
+ episode_length=cfg.rlbench.episode_length,
157
+ robot_name=cfg.method.robot_name,
158
+ )
159
+ del obs_dict["ignore_collisions"]
160
+ tokens = tokenize([description]).numpy()
161
+ token_tensor = torch.from_numpy(tokens).to(device)
162
+ lang_feats, lang_embs = clip_model.encode_text_with_embeddings(token_tensor)
163
+ obs_dict["lang_goal_emb"] = lang_feats[0].float().detach().cpu().numpy()
164
+
165
+ final_obs = {
166
+ "task": task,
167
+ "lang_goal": np.array([description], dtype=object),
168
+ }
169
+
170
+ prev_action = np.copy(action)
171
+ others = {"demo": True}
172
+ others.update(final_obs)
173
+ others.update(obs_dict)
174
+ timeout = False
175
+ replay.add(action, reward, terminal, timeout, **others)
176
+ obs = obs_tp1 # Set the next obs
177
+ # Final step
178
+ obs_dict_tp1 = observation_utils.extract_obs(
179
+ obs_tp1,
180
+ t=k + 1,
181
+ prev_action=prev_action,
182
+ cameras=cameras,
183
+ episode_length=cfg.rlbench.episode_length,
184
+ robot_name=cfg.method.robot_name,
185
+ )
186
+ obs_dict_tp1["lang_goal_emb"] = lang_feats[0].float().detach().cpu().numpy()
187
+ # del obs_dict_tp1['lang_goal_tokens']
188
+ del obs_dict_tp1["ignore_collisions"]
189
+ # obs_dict_tp1['task'] = task
190
+ obs_dict_tp1.update(final_obs)
191
+ replay.add_final(**obs_dict_tp1)
192
+ return all_actions
193
+
194
+
195
+ def fill_replay(
196
+ cfg: DictConfig,
197
+ obs_config: ObservationConfig,
198
+ rank: int,
199
+ replay: ReplayBuffer,
200
+ task: str,
201
+ num_demos: int,
202
+ demo_augmentation: bool,
203
+ demo_augmentation_every_n: int,
204
+ cameras: List[str],
205
+ clip_model=None,
206
+ device="cpu",
207
+ ):
208
+ if clip_model is None:
209
+ model, _ = load_clip("RN50", jit=False, device=device)
210
+ clip_model = build_model(model.state_dict())
211
+ clip_model.to(device)
212
+ del model
213
+
214
+ logging.debug("Filling %s replay ..." % task)
215
+ all_actions = []
216
+ for d_idx in range(num_demos):
217
+ # load demo from disk
218
+ demo = rlbench_utils.get_stored_demos(
219
+ amount=1,
220
+ image_paths=False,
221
+ dataset_root=cfg.rlbench.demo_path,
222
+ variation_number=-1,
223
+ task_name=task,
224
+ obs_config=obs_config,
225
+ random_selection=False,
226
+ from_episode_number=d_idx,
227
+ )[0]
228
+
229
+ descs = demo._observations[0].misc["descriptions"]
230
+
231
+ # extract keypoints (a.k.a keyframes)
232
+ episode_keypoints = demo_loading_utils.keypoint_discovery(demo)
233
+
234
+ if rank == 0:
235
+ logging.info(
236
+ f"Loading Demo({d_idx}) - found {len(episode_keypoints)} keypoints - {task}"
237
+ )
238
+
239
+ for i in range(len(demo) - 1):
240
+ if not demo_augmentation and i > 0:
241
+ break
242
+ if i % demo_augmentation_every_n != 0:
243
+ continue
244
+
245
+ obs = demo[i]
246
+ desc = descs[0]
247
+ # if our starting point is past one of the keypoints, then remove it
248
+ while len(episode_keypoints) > 0 and i >= episode_keypoints[0]:
249
+ episode_keypoints = episode_keypoints[1:]
250
+ if len(episode_keypoints) == 0:
251
+ break
252
+ all_actions.extend(
253
+ _add_keypoints_to_replay(
254
+ cfg,
255
+ task,
256
+ replay,
257
+ obs,
258
+ demo,
259
+ episode_keypoints,
260
+ cameras,
261
+ description=desc,
262
+ clip_model=clip_model,
263
+ device=device,
264
+ )
265
+ )
266
+ logging.debug("Replay filled with demos.")
267
+ return all_actions
268
+
269
+
270
+ def fill_multi_task_replay(
271
+ cfg: DictConfig,
272
+ obs_config: ObservationConfig,
273
+ rank: int,
274
+ replay: ReplayBuffer,
275
+ tasks: List[str],
276
+ num_demos: int,
277
+ demo_augmentation: bool,
278
+ demo_augmentation_every_n: int,
279
+ cameras: List[str],
280
+ clip_model=None,
281
+ ):
282
+ manager = Manager()
283
+ store = manager.dict()
284
+
285
+ # create a MP dict for storing indicies
286
+ # TODO(mohit): this shouldn't be initialized here
287
+ del replay._task_idxs
288
+ task_idxs = manager.dict()
289
+ replay._task_idxs = task_idxs
290
+ replay._create_storage(store)
291
+ replay.add_count = Value("i", 0)
292
+
293
+ # fill replay buffer in parallel across tasks
294
+ max_parallel_processes = cfg.replay.max_parallel_processes
295
+ processes = []
296
+ n = np.arange(len(tasks))
297
+ split_n = utils.split_list(n, max_parallel_processes)
298
+ for split in split_n:
299
+ for e_idx, task_idx in enumerate(split):
300
+ task = tasks[int(task_idx)]
301
+ model_device = torch.device(
302
+ "cuda:%s" % (e_idx % torch.cuda.device_count())
303
+ if torch.cuda.is_available()
304
+ else "cpu"
305
+ )
306
+ p = Process(
307
+ target=fill_replay,
308
+ args=(
309
+ cfg,
310
+ obs_config,
311
+ rank,
312
+ replay,
313
+ task,
314
+ num_demos,
315
+ demo_augmentation,
316
+ demo_augmentation_every_n,
317
+ cameras,
318
+ clip_model,
319
+ model_device,
320
+ ),
321
+ )
322
+ p.start()
323
+ processes.append(p)
324
+
325
+ for p in processes:
326
+ p.join()
327
+
328
+ logging.debug("Replay filled with multi demos.")
329
+
330
+
331
+ def create_agent(cfg: DictConfig):
332
+ camera_name = cfg.rlbench.cameras
333
+ activation = cfg.method.activation
334
+ lr = cfg.method.lr
335
+ weight_decay = cfg.method.weight_decay
336
+ image_resolution = cfg.rlbench.camera_resolution
337
+ grad_clip = cfg.method.grad_clip
338
+
339
+ siamese_net = SiameseNet(
340
+ input_channels=[3, 3],
341
+ filters=[16],
342
+ kernel_sizes=[5],
343
+ strides=[1],
344
+ activation=activation,
345
+ norm=None,
346
+ )
347
+
348
+ actor_net = CNNLangAndFcsNet(
349
+ siamese_net=siamese_net,
350
+ input_resolution=image_resolution,
351
+ filters=[32, 64, 64],
352
+ kernel_sizes=[3, 3, 3],
353
+ strides=[2, 2, 2],
354
+ norm=None,
355
+ activation=activation,
356
+ fc_layers=[128, 64, 3 + 4 + 1],
357
+ low_dim_state_len=LOW_DIM_SIZE,
358
+ )
359
+
360
+ bc_agent = BCLangAgent(
361
+ actor_network=actor_net,
362
+ camera_name=camera_name,
363
+ lr=lr,
364
+ weight_decay=weight_decay,
365
+ grad_clip=grad_clip,
366
+ )
367
+
368
+ return PreprocessAgent(pose_agent=bc_agent)
external/peract_bimanual/agents/baselines/vit_bc_lang/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ import agents.baselines.vit_bc_lang.launch_utils
external/peract_bimanual/agents/baselines/vit_bc_lang/launch_utils.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from ARM
2
+ # Source: https://github.com/stepjam/ARM
3
+ # License: https://github.com/stepjam/ARM/LICENSE
4
+
5
+ import logging
6
+ from typing import List
7
+
8
+ import numpy as np
9
+ from omegaconf import DictConfig
10
+ from rlbench.backend.observation import Observation
11
+ from rlbench.observation_config import ObservationConfig
12
+ import rlbench.utils as rlbench_utils
13
+ from rlbench.demo import Demo
14
+ from yarr.replay_buffer.prioritized_replay_buffer import (
15
+ PrioritizedReplayBuffer,
16
+ ObservationElement,
17
+ )
18
+ from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer
19
+ from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
20
+ from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer
21
+
22
+ from helpers import demo_loading_utils, utils
23
+ from helpers import observation_utils
24
+ from agents.baselines.vit_bc_lang.vit_bc_lang_agent import ViTBCLangAgent
25
+ from helpers.custom_rlbench_env import CustomRLBenchEnv
26
+ from helpers.network_utils import ViTLangAndFcsNet, ViT
27
+ from helpers.preprocess_agent import PreprocessAgent
28
+
29
+ import torch
30
+ from torch.multiprocessing import Process, Value, Manager
31
+ from helpers.clip.core.clip import build_model, load_clip, tokenize
32
+
33
+ LOW_DIM_SIZE = 4
34
+
35
+
36
+ def create_replay(
37
+ batch_size: int,
38
+ timesteps: int,
39
+ prioritisation: bool,
40
+ task_uniform: bool,
41
+ save_dir: str,
42
+ cameras: list,
43
+ image_size=[128, 128],
44
+ replay_size=3e5,
45
+ ):
46
+ lang_feat_dim = 1024
47
+
48
+ # low_dim_state
49
+ observation_elements = []
50
+ observation_elements.append(
51
+ ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
52
+ )
53
+
54
+ # rgb, depth, point cloud, intrinsics, extrinsics
55
+ for cname in cameras:
56
+ observation_elements.append(
57
+ ObservationElement(
58
+ "%s_rgb" % cname,
59
+ (
60
+ 3,
61
+ *image_size,
62
+ ),
63
+ np.float32,
64
+ )
65
+ )
66
+ observation_elements.append(
67
+ ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32)
68
+ ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
69
+ observation_elements.append(
70
+ ObservationElement(
71
+ "%s_camera_extrinsics" % cname,
72
+ (
73
+ 4,
74
+ 4,
75
+ ),
76
+ np.float32,
77
+ )
78
+ )
79
+ observation_elements.append(
80
+ ObservationElement(
81
+ "%s_camera_intrinsics" % cname,
82
+ (
83
+ 3,
84
+ 3,
85
+ ),
86
+ np.float32,
87
+ )
88
+ )
89
+
90
+ observation_elements.extend(
91
+ [
92
+ ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
93
+ ReplayElement("task", (), str),
94
+ ReplayElement(
95
+ "lang_goal", (1,), object
96
+ ), # language goal string for debugging and visualization
97
+ ]
98
+ )
99
+
100
+ extra_replay_elements = [
101
+ ReplayElement("demo", (), np.bool),
102
+ ]
103
+
104
+ replay_buffer = TaskUniformReplayBuffer(
105
+ save_dir=save_dir,
106
+ batch_size=batch_size,
107
+ timesteps=timesteps,
108
+ replay_capacity=int(replay_size),
109
+ action_shape=(8,),
110
+ action_dtype=np.float32,
111
+ reward_shape=(),
112
+ reward_dtype=np.float32,
113
+ update_horizon=1,
114
+ observation_elements=observation_elements,
115
+ extra_replay_elements=extra_replay_elements,
116
+ )
117
+ return replay_buffer
118
+
119
+
120
+ def _get_action(obs_tp1: Observation):
121
+ quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
122
+ if quat[-1] < 0:
123
+ quat = -quat
124
+ return np.concatenate(
125
+ [obs_tp1.gripper_pose[:3], quat, [float(obs_tp1.gripper_open)]]
126
+ )
127
+
128
+
129
+ def _add_keypoints_to_replay(
130
+ cfg: DictConfig,
131
+ task: str,
132
+ replay: ReplayBuffer,
133
+ inital_obs: Observation,
134
+ demo: Demo,
135
+ episode_keypoints: List[int],
136
+ cameras: List[str],
137
+ description: str = "",
138
+ clip_model=None,
139
+ device="cpu",
140
+ ):
141
+ prev_action = None
142
+ obs = inital_obs
143
+ all_actions = []
144
+ for k, keypoint in enumerate(episode_keypoints):
145
+ obs_tp1 = demo[keypoint]
146
+ action = _get_action(obs_tp1)
147
+ all_actions.append(action)
148
+ terminal = k == len(episode_keypoints) - 1
149
+ reward = float(terminal) if terminal else 0
150
+
151
+ obs_dict = observation_utils.extract_obs(
152
+ obs,
153
+ t=k,
154
+ prev_action=prev_action,
155
+ cameras=cameras,
156
+ episode_length=cfg.rlbench.episode_length,
157
+ robot_name=cfg.method.robot_name,
158
+ )
159
+ del obs_dict["ignore_collisions"]
160
+ tokens = tokenize([description]).numpy()
161
+ token_tensor = torch.from_numpy(tokens).to(device)
162
+ lang_feats, lang_embs = clip_model.encode_text_with_embeddings(token_tensor)
163
+ obs_dict["lang_goal_emb"] = lang_feats[0].float().detach().cpu().numpy()
164
+
165
+ final_obs = {
166
+ "task": task,
167
+ "lang_goal": np.array([description], dtype=object),
168
+ }
169
+
170
+ prev_action = np.copy(action)
171
+ others = {"demo": True}
172
+ others.update(final_obs)
173
+ others.update(obs_dict)
174
+ timeout = False
175
+ replay.add(action, reward, terminal, timeout, **others)
176
+ obs = obs_tp1 # Set the next obs
177
+ # Final step
178
+ obs_dict_tp1 = observation_utils.extract_obs(
179
+ obs_tp1,
180
+ t=k + 1,
181
+ prev_action=prev_action,
182
+ cameras=cameras,
183
+ episode_length=cfg.rlbench.episode_length,
184
+ robot_name=cfg.method.robot_name,
185
+ )
186
+ obs_dict_tp1["lang_goal_emb"] = lang_feats[0].float().detach().cpu().numpy()
187
+ # del obs_dict_tp1['lang_goal_tokens']
188
+ del obs_dict_tp1["ignore_collisions"]
189
+ # obs_dict_tp1['task'] = task
190
+ obs_dict_tp1.update(final_obs)
191
+ replay.add_final(**obs_dict_tp1)
192
+ return all_actions
193
+
194
+
195
+ def fill_replay(
196
+ cfg: DictConfig,
197
+ obs_config: ObservationConfig,
198
+ rank: int,
199
+ replay: ReplayBuffer,
200
+ task: str,
201
+ num_demos: int,
202
+ demo_augmentation: bool,
203
+ demo_augmentation_every_n: int,
204
+ cameras: List[str],
205
+ clip_model=None,
206
+ device="cpu",
207
+ ):
208
+ if clip_model is None:
209
+ model, _ = load_clip("RN50", jit=False, device=device)
210
+ clip_model = build_model(model.state_dict())
211
+ clip_model.to(device)
212
+ del model
213
+
214
+ logging.debug("Filling %s replay ..." % task)
215
+ all_actions = []
216
+ for d_idx in range(num_demos):
217
+ # load demo from disk
218
+ demo = rlbench_utils.get_stored_demos(
219
+ amount=1,
220
+ image_paths=False,
221
+ dataset_root=cfg.rlbench.demo_path,
222
+ variation_number=-1,
223
+ task_name=task,
224
+ obs_config=obs_config,
225
+ random_selection=False,
226
+ from_episode_number=d_idx,
227
+ )[0]
228
+
229
+ descs = demo._observations[0].misc["descriptions"]
230
+
231
+ # extract keypoints (a.k.a keyframes)
232
+ episode_keypoints = demo_loading_utils.keypoint_discovery(demo)
233
+
234
+ if rank == 0:
235
+ logging.info(
236
+ f"Loading Demo({d_idx}) - found {len(episode_keypoints)} keypoints - {task}"
237
+ )
238
+
239
+ for i in range(len(demo) - 1):
240
+ if not demo_augmentation and i > 0:
241
+ break
242
+ if i % demo_augmentation_every_n != 0:
243
+ continue
244
+
245
+ obs = demo[i]
246
+ desc = descs[0]
247
+ # if our starting point is past one of the keypoints, then remove it
248
+ while len(episode_keypoints) > 0 and i >= episode_keypoints[0]:
249
+ episode_keypoints = episode_keypoints[1:]
250
+ if len(episode_keypoints) == 0:
251
+ break
252
+ all_actions.extend(
253
+ _add_keypoints_to_replay(
254
+ cfg,
255
+ task,
256
+ replay,
257
+ obs,
258
+ demo,
259
+ episode_keypoints,
260
+ cameras,
261
+ description=desc,
262
+ clip_model=clip_model,
263
+ device=device,
264
+ )
265
+ )
266
+ logging.debug("Replay filled with demos.")
267
+ return all_actions
268
+
269
+
270
+ def fill_multi_task_replay(
271
+ cfg: DictConfig,
272
+ obs_config: ObservationConfig,
273
+ rank: int,
274
+ replay: ReplayBuffer,
275
+ tasks: List[str],
276
+ num_demos: int,
277
+ demo_augmentation: bool,
278
+ demo_augmentation_every_n: int,
279
+ cameras: List[str],
280
+ clip_model=None,
281
+ ):
282
+ manager = Manager()
283
+ store = manager.dict()
284
+
285
+ # create a MP dict for storing indicies
286
+ # TODO(mohit): this shouldn't be initialized here
287
+ del replay._task_idxs
288
+ task_idxs = manager.dict()
289
+ replay._task_idxs = task_idxs
290
+ replay._create_storage(store)
291
+ replay.add_count = Value("i", 0)
292
+
293
+ # fill replay buffer in parallel across tasks
294
+ max_parallel_processes = cfg.replay.max_parallel_processes
295
+ processes = []
296
+ n = np.arange(len(tasks))
297
+ split_n = utils.split_list(n, max_parallel_processes)
298
+ for split in split_n:
299
+ for e_idx, task_idx in enumerate(split):
300
+ task = tasks[int(task_idx)]
301
+ model_device = torch.device(
302
+ "cuda:%s" % (e_idx % torch.cuda.device_count())
303
+ if torch.cuda.is_available()
304
+ else "cpu"
305
+ )
306
+ p = Process(
307
+ target=fill_replay,
308
+ args=(
309
+ cfg,
310
+ obs_config,
311
+ rank,
312
+ replay,
313
+ task,
314
+ num_demos,
315
+ demo_augmentation,
316
+ demo_augmentation_every_n,
317
+ cameras,
318
+ clip_model,
319
+ model_device,
320
+ ),
321
+ )
322
+ p.start()
323
+ processes.append(p)
324
+
325
+ for p in processes:
326
+ p.join()
327
+
328
+ logging.debug("Replay filled with multi demos.")
329
+
330
+
331
+ def create_agent(cfg: DictConfig):
332
+ camera_name = cfg.rlbench.cameras
333
+ activation = cfg.method.activation
334
+ lr = cfg.method.lr
335
+ weight_decay = cfg.method.weight_decay
336
+ image_resolution = cfg.rlbench.camera_resolution
337
+ grad_clip = cfg.method.grad_clip
338
+
339
+ vit = ViT(
340
+ image_size=128,
341
+ patch_size=8,
342
+ num_classes=16,
343
+ dim=64,
344
+ depth=6,
345
+ heads=8,
346
+ mlp_dim=64,
347
+ dropout=0.1,
348
+ emb_dropout=0.1,
349
+ channels=6,
350
+ )
351
+
352
+ actor_net = ViTLangAndFcsNet(
353
+ vit=vit,
354
+ input_resolution=image_resolution,
355
+ filters=[64, 96, 128],
356
+ kernel_sizes=[1, 1, 1],
357
+ strides=[1, 1, 1],
358
+ norm=None,
359
+ activation=activation,
360
+ fc_layers=[128, 64, 3 + 4 + 1],
361
+ low_dim_state_len=LOW_DIM_SIZE,
362
+ )
363
+
364
+ bc_agent = ViTBCLangAgent(
365
+ actor_network=actor_net,
366
+ camera_name=camera_name,
367
+ lr=lr,
368
+ weight_decay=weight_decay,
369
+ grad_clip=grad_clip,
370
+ )
371
+
372
+ return PreprocessAgent(pose_agent=bc_agent)
external/peract_bimanual/agents/baselines/vit_bc_lang/vit_bc_lang_agent.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ from typing import List
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from yarr.agents.agent import Agent, Summary, ActResult, ScalarSummary, HistogramSummary
10
+
11
+ from helpers import utils
12
+ from helpers.utils import stack_on_channel
13
+
14
+ from helpers.clip.core.clip import build_model, load_clip
15
+
16
+ NAME = "ViTBCLangAgent"
17
+ REPLAY_ALPHA = 0.7
18
+ REPLAY_BETA = 1.0
19
+
20
+
21
+ class Actor(nn.Module):
22
+ def __init__(self, actor_network: nn.Module):
23
+ super(Actor, self).__init__()
24
+ self._actor_network = copy.deepcopy(actor_network)
25
+ self._actor_network.build()
26
+
27
+ def forward(self, observations, robot_state, lang_goal_emb):
28
+ mu = self._actor_network(observations, robot_state, lang_goal_emb)
29
+ return mu
30
+
31
+
32
+ class ViTBCLangAgent(Agent):
33
+ def __init__(
34
+ self,
35
+ actor_network: nn.Module,
36
+ camera_name: str,
37
+ lr: float = 0.01,
38
+ weight_decay: float = 1e-5,
39
+ grad_clip: float = 20.0,
40
+ ):
41
+ self._camera_name = camera_name
42
+ self._actor_network = actor_network
43
+ self._lr = lr
44
+ self._weight_decay = weight_decay
45
+ self._grad_clip = grad_clip
46
+
47
+ def build(self, training: bool, device: torch.device = None):
48
+ if device is None:
49
+ device = torch.device("cpu")
50
+ self._actor = Actor(self._actor_network).to(device).train(training)
51
+ if training:
52
+ self._actor_optimizer = torch.optim.Adam(
53
+ self._actor.parameters(), lr=self._lr, weight_decay=self._weight_decay
54
+ )
55
+ logging.info(
56
+ "# Actor Params: %d"
57
+ % sum(p.numel() for p in self._actor.parameters() if p.requires_grad)
58
+ )
59
+ else:
60
+ for p in self._actor.parameters():
61
+ p.requires_grad = False
62
+
63
+ model, _ = load_clip("RN50", jit=False)
64
+ self._clip_rn50 = build_model(model.state_dict())
65
+ self._clip_rn50 = self._clip_rn50.float().to(device)
66
+ self._clip_rn50.eval()
67
+ del model
68
+
69
+ self._device = device
70
+
71
+ def _grad_step(self, loss, opt, model_params=None, clip=None):
72
+ opt.zero_grad()
73
+ loss.backward()
74
+ if clip is not None and model_params is not None:
75
+ nn.utils.clip_grad_value_(model_params, clip)
76
+ opt.step()
77
+
78
+ def update(self, step: int, replay_sample: dict) -> dict:
79
+ lang_goal_emb = replay_sample["lang_goal_emb"]
80
+ robot_state = replay_sample["low_dim_state"]
81
+ observations = [
82
+ replay_sample["%s_rgb" % self._camera_name],
83
+ replay_sample["%s_point_cloud" % self._camera_name],
84
+ ]
85
+ mu = self._actor(observations, robot_state, lang_goal_emb)
86
+ loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA)
87
+ delta = F.mse_loss(mu, replay_sample["action"], reduction="none").mean(1)
88
+ loss = (delta * loss_weights).mean()
89
+ self._grad_step(
90
+ loss, self._actor_optimizer, self._actor.parameters(), self._grad_clip
91
+ )
92
+ self._summaries = {
93
+ "pi/loss": loss,
94
+ "pi/mu": mu.mean(),
95
+ }
96
+ return {"total_losses": loss}
97
+
98
+ def _normalize_quat(self, x):
99
+ return x / x.square().sum(dim=1).sqrt().unsqueeze(-1)
100
+
101
+ def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
102
+ lang_goal_tokens = observation.get("lang_goal_tokens", None).long()
103
+
104
+ with torch.no_grad():
105
+ lang_goal_tokens = lang_goal_tokens.to(device=self._device)
106
+ lang_goal_emb, _ = self._clip_rn50.encode_text_with_embeddings(
107
+ lang_goal_tokens[0]
108
+ )
109
+ lang_goal_emb = lang_goal_emb.to(device=self._device)
110
+
111
+ observations = [
112
+ observation["%s_rgb" % self._camera_name][0].to(self._device),
113
+ observation["%s_point_cloud" % self._camera_name][0].to(self._device),
114
+ ]
115
+ robot_state = observation["low_dim_state"][0].to(self._device)
116
+
117
+ mu = self._actor(observations, robot_state, lang_goal_emb)
118
+ mu = torch.cat([mu[:, :3], self._normalize_quat(mu[:, 3:7]), mu[:, 7:]], dim=-1)
119
+ ignore_collisions = torch.Tensor([1.0]).to(mu.device)
120
+ mu0 = torch.cat([mu[0], ignore_collisions])
121
+ return ActResult(mu0.detach().cpu())
122
+
123
+ def update_summaries(self) -> List[Summary]:
124
+ summaries = []
125
+ for n, v in self._summaries.items():
126
+ summaries.append(ScalarSummary("%s/%s" % (NAME, n), v))
127
+
128
+ for tag, param in self._actor.named_parameters():
129
+ summaries.append(
130
+ HistogramSummary("%s/gradient/%s" % (NAME, tag), param.grad)
131
+ )
132
+ summaries.append(HistogramSummary("%s/weight/%s" % (NAME, tag), param.data))
133
+
134
+ return summaries
135
+
136
+ def act_summaries(self) -> List[Summary]:
137
+ return []
138
+
139
+ def load_weights(self, savedir: str):
140
+ self._actor.load_state_dict(
141
+ torch.load(
142
+ os.path.join(savedir, "bc_actor.pt"), map_location=torch.device("cpu")
143
+ )
144
+ )
145
+ print("Loaded weights from %s" % savedir)
146
+
147
+ def save_weights(self, savedir: str):
148
+ torch.save(self._actor.state_dict(), os.path.join(savedir, "bc_actor.pt"))
external/peract_bimanual/agents/bimanual_peract/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ import agents.bimanual_peract.launch_utils
external/peract_bimanual/agents/bimanual_peract/launch_utils.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from ARM
2
+ # Source: https://github.com/stepjam/ARM
3
+ # License: https://github.com/stepjam/ARM/LICENSE
4
+
5
+
6
+ from helpers.preprocess_agent import PreprocessAgent
7
+
8
+ from agents.bimanual_peract.perceiver_lang_io import PerceiverVoxelLangEncoder
9
+ from agents.bimanual_peract.qattention_peract_bc_agent import QAttentionPerActBCAgent
10
+ from agents.bimanual_peract.qattention_stack_agent import QAttentionStackAgent
11
+
12
+ from omegaconf import DictConfig
13
+
14
+
15
+ def create_agent(cfg: DictConfig):
16
+ depth_0bounds = cfg.rlbench.scene_bounds
17
+ cam_resolution = cfg.rlbench.camera_resolution
18
+
19
+ num_rotation_classes = int(360.0 // cfg.method.rotation_resolution)
20
+ qattention_agents = []
21
+ for depth, vox_size in enumerate(cfg.method.voxel_sizes):
22
+ last = depth == len(cfg.method.voxel_sizes) - 1
23
+ perceiver_encoder = PerceiverVoxelLangEncoder(
24
+ depth=cfg.method.transformer_depth,
25
+ iterations=cfg.method.transformer_iterations,
26
+ voxel_size=vox_size,
27
+ initial_dim=3 + 3 + 1 + 3,
28
+ low_dim_size=cfg.method.low_dim_size,
29
+ layer=depth,
30
+ num_rotation_classes=num_rotation_classes if last else 0,
31
+ num_grip_classes=2 if last else 0,
32
+ num_collision_classes=2 if last else 0,
33
+ input_axis=3,
34
+ num_latents=cfg.method.num_latents,
35
+ latent_dim=cfg.method.latent_dim,
36
+ cross_heads=cfg.method.cross_heads,
37
+ latent_heads=cfg.method.latent_heads,
38
+ cross_dim_head=cfg.method.cross_dim_head,
39
+ latent_dim_head=cfg.method.latent_dim_head,
40
+ weight_tie_layers=False,
41
+ activation=cfg.method.activation,
42
+ pos_encoding_with_lang=cfg.method.pos_encoding_with_lang,
43
+ input_dropout=cfg.method.input_dropout,
44
+ attn_dropout=cfg.method.attn_dropout,
45
+ decoder_dropout=cfg.method.decoder_dropout,
46
+ lang_fusion_type=cfg.method.lang_fusion_type,
47
+ voxel_patch_size=cfg.method.voxel_patch_size,
48
+ voxel_patch_stride=cfg.method.voxel_patch_stride,
49
+ no_skip_connection=cfg.method.no_skip_connection,
50
+ no_perceiver=cfg.method.no_perceiver,
51
+ no_language=cfg.method.no_language,
52
+ final_dim=cfg.method.final_dim,
53
+ )
54
+
55
+ qattention_agent = QAttentionPerActBCAgent(
56
+ layer=depth,
57
+ coordinate_bounds=depth_0bounds,
58
+ perceiver_encoder=perceiver_encoder,
59
+ camera_names=cfg.rlbench.cameras,
60
+ voxel_size=vox_size,
61
+ bounds_offset=cfg.method.bounds_offset[depth - 1] if depth > 0 else None,
62
+ image_crop_size=cfg.method.image_crop_size,
63
+ lr=cfg.method.lr,
64
+ training_iterations=cfg.framework.training_iterations,
65
+ lr_scheduler=cfg.method.lr_scheduler,
66
+ num_warmup_steps=cfg.method.num_warmup_steps,
67
+ trans_loss_weight=cfg.method.trans_loss_weight,
68
+ rot_loss_weight=cfg.method.rot_loss_weight,
69
+ grip_loss_weight=cfg.method.grip_loss_weight,
70
+ collision_loss_weight=cfg.method.collision_loss_weight,
71
+ include_low_dim_state=True,
72
+ image_resolution=cam_resolution,
73
+ batch_size=cfg.replay.batch_size,
74
+ voxel_feature_size=3,
75
+ lambda_weight_l2=cfg.method.lambda_weight_l2,
76
+ num_rotation_classes=num_rotation_classes,
77
+ rotation_resolution=cfg.method.rotation_resolution,
78
+ transform_augmentation=cfg.method.transform_augmentation.apply_se3,
79
+ transform_augmentation_xyz=cfg.method.transform_augmentation.aug_xyz,
80
+ transform_augmentation_rpy=cfg.method.transform_augmentation.aug_rpy,
81
+ transform_augmentation_rot_resolution=cfg.method.transform_augmentation.aug_rot_resolution,
82
+ optimizer_type=cfg.method.optimizer,
83
+ num_devices=cfg.ddp.num_devices,
84
+ )
85
+ qattention_agents.append(qattention_agent)
86
+
87
+ rotation_agent = QAttentionStackAgent(
88
+ qattention_agents=qattention_agents,
89
+ rotation_resolution=cfg.method.rotation_resolution,
90
+ camera_names=cfg.rlbench.cameras,
91
+ )
92
+ preprocess_agent = PreprocessAgent(pose_agent=rotation_agent)
93
+ return preprocess_agent
external/peract_bimanual/agents/bimanual_peract/perceiver_lang_io.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Perceiver IO implementation adpated for manipulation
2
+ # Source: https://github.com/lucidrains/perceiver-pytorch
3
+ # License: https://github.com/lucidrains/perceiver-pytorch/blob/main/LICENSE
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from einops import rearrange
9
+ from einops import repeat
10
+
11
+ from perceiver_pytorch.perceiver_pytorch import cache_fn
12
+ from perceiver_pytorch.perceiver_pytorch import PreNorm, FeedForward, Attention
13
+
14
+ from helpers.network_utils import (
15
+ DenseBlock,
16
+ SpatialSoftmax3D,
17
+ Conv3DBlock,
18
+ Conv3DUpsampleBlock,
19
+ )
20
+
21
+
22
+ # PerceiverIO adapted for 6-DoF manipulation
23
+ class PerceiverVoxelLangEncoder(nn.Module):
24
+ def __init__(
25
+ self,
26
+ depth, # number of self-attention layers
27
+ iterations, # number cross-attention iterations (PerceiverIO uses just 1)
28
+ voxel_size, # N voxels per side (size: N*N*N)
29
+ initial_dim, # 10 dimensions - dimension of the input sequence to be encoded
30
+ low_dim_size, # 4 dimensions - proprioception: {gripper_open, left_finger, right_finger, timestep}
31
+ layer=0,
32
+ num_rotation_classes=72, # 5 degree increments (5*72=360) for each of the 3-axis
33
+ num_grip_classes=2, # open or not open
34
+ num_collision_classes=2, # collisions allowed or not allowed
35
+ input_axis=3, # 3D tensors have 3 axes
36
+ num_latents=512, # number of latent vectors
37
+ im_channels=64, # intermediate channel size
38
+ latent_dim=512, # dimensions of latent vectors
39
+ cross_heads=1, # number of cross-attention heads
40
+ latent_heads=8, # number of latent heads
41
+ cross_dim_head=64,
42
+ latent_dim_head=64,
43
+ activation="relu",
44
+ weight_tie_layers=False,
45
+ pos_encoding_with_lang=True,
46
+ input_dropout=0.1,
47
+ attn_dropout=0.1,
48
+ decoder_dropout=0.0,
49
+ lang_fusion_type="seq",
50
+ voxel_patch_size=9,
51
+ voxel_patch_stride=8,
52
+ no_skip_connection=False,
53
+ no_perceiver=False,
54
+ no_language=False,
55
+ final_dim=64,
56
+ ):
57
+ super().__init__()
58
+ self.depth = depth
59
+ self.layer = layer
60
+ self.init_dim = int(initial_dim)
61
+ self.iterations = iterations
62
+ self.input_axis = input_axis
63
+ self.voxel_size = voxel_size
64
+ self.low_dim_size = low_dim_size
65
+ self.im_channels = im_channels
66
+ self.pos_encoding_with_lang = pos_encoding_with_lang
67
+ self.lang_fusion_type = lang_fusion_type
68
+ self.voxel_patch_size = voxel_patch_size
69
+ self.voxel_patch_stride = voxel_patch_stride
70
+ self.num_rotation_classes = num_rotation_classes
71
+ self.num_grip_classes = num_grip_classes
72
+ self.num_collision_classes = num_collision_classes
73
+ self.final_dim = final_dim
74
+ self.input_dropout = input_dropout
75
+ self.attn_dropout = attn_dropout
76
+ self.decoder_dropout = decoder_dropout
77
+ self.no_skip_connection = no_skip_connection
78
+ self.no_perceiver = no_perceiver
79
+ self.no_language = no_language
80
+
81
+ # patchified input dimensions
82
+ spatial_size = voxel_size // self.voxel_patch_stride # 100/5 = 20
83
+
84
+ # 64 voxel features + 64 proprio features (+ 64 lang goal features if concattenated)
85
+ self.input_dim_before_seq = (
86
+ self.im_channels * 3
87
+ if self.lang_fusion_type == "concat"
88
+ else self.im_channels * 2
89
+ )
90
+
91
+ # CLIP language feature dimensions
92
+ lang_feat_dim, lang_emb_dim, lang_max_seq_len = 1024, 512, 77
93
+
94
+ # learnable positional encoding
95
+ if self.pos_encoding_with_lang:
96
+ self.pos_encoding = nn.Parameter(
97
+ torch.randn(
98
+ 1, lang_max_seq_len + spatial_size**3, self.input_dim_before_seq
99
+ )
100
+ )
101
+ else:
102
+ # assert self.lang_fusion_type == 'concat', 'Only concat is supported for pos encoding without lang.'
103
+ self.pos_encoding = nn.Parameter(
104
+ torch.randn(
105
+ 1,
106
+ spatial_size,
107
+ spatial_size,
108
+ spatial_size,
109
+ self.input_dim_before_seq,
110
+ )
111
+ )
112
+
113
+ # voxel input preprocessing 1x1 conv encoder
114
+ self.input_preprocess = Conv3DBlock(
115
+ self.init_dim,
116
+ self.im_channels,
117
+ kernel_sizes=1,
118
+ strides=1,
119
+ norm=None,
120
+ activation=activation,
121
+ )
122
+
123
+ # patchify conv
124
+ self.patchify = Conv3DBlock(
125
+ self.input_preprocess.out_channels,
126
+ self.im_channels,
127
+ kernel_sizes=self.voxel_patch_size,
128
+ strides=self.voxel_patch_stride,
129
+ norm=None,
130
+ activation=activation,
131
+ )
132
+
133
+ # language preprocess
134
+ if self.lang_fusion_type == "concat":
135
+ self.lang_preprocess = nn.Linear(lang_feat_dim, self.im_channels)
136
+ elif self.lang_fusion_type == "seq":
137
+ self.lang_preprocess = nn.Linear(lang_emb_dim, self.im_channels * 2)
138
+
139
+ # proprioception
140
+ if self.low_dim_size > 0:
141
+ self.proprio_preprocess = DenseBlock(
142
+ self.low_dim_size,
143
+ self.im_channels,
144
+ norm=None,
145
+ activation=activation,
146
+ )
147
+
148
+ # pooling functions
149
+ self.local_maxp = nn.MaxPool3d(3, 2, padding=1)
150
+ self.global_maxp = nn.AdaptiveMaxPool3d(1)
151
+
152
+ # 1st 3D softmax
153
+ self.ss0 = SpatialSoftmax3D(
154
+ self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
155
+ )
156
+ flat_size = self.im_channels * 4
157
+
158
+ # latent vectors (that are randomly initialized)
159
+ self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
160
+
161
+ # encoder cross attention
162
+ self.cross_attend_blocks = nn.ModuleList(
163
+ [
164
+ PreNorm(
165
+ latent_dim,
166
+ Attention(
167
+ latent_dim,
168
+ self.input_dim_before_seq,
169
+ heads=cross_heads,
170
+ dim_head=cross_dim_head,
171
+ dropout=input_dropout,
172
+ ),
173
+ context_dim=self.input_dim_before_seq,
174
+ ),
175
+ PreNorm(latent_dim, FeedForward(latent_dim)),
176
+ PreNorm(latent_dim, FeedForward(latent_dim)),
177
+ ]
178
+ )
179
+
180
+ get_latent_attn = lambda: PreNorm(
181
+ latent_dim,
182
+ Attention(
183
+ latent_dim,
184
+ heads=latent_heads,
185
+ dim_head=latent_dim_head,
186
+ dropout=attn_dropout,
187
+ ),
188
+ )
189
+ get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim))
190
+ get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))
191
+
192
+ # self attention layers
193
+ self.layers = nn.ModuleList([])
194
+ cache_args = {"_cache": weight_tie_layers}
195
+
196
+ for i in range(depth):
197
+ self.layers.append(
198
+ nn.ModuleList(
199
+ [
200
+ get_latent_attn(**cache_args),
201
+ get_latent_ff(**cache_args),
202
+ get_latent_attn(**cache_args),
203
+ get_latent_ff(**cache_args),
204
+ ]
205
+ )
206
+ )
207
+
208
+ self.combined_latent_attn = get_latent_attn(**cache_args)
209
+ self.combined_latent_ff = get_latent_ff(**cache_args)
210
+
211
+ # decoder cross attention
212
+ self.decoder_cross_attn_right = PreNorm(
213
+ self.input_dim_before_seq,
214
+ Attention(
215
+ self.input_dim_before_seq,
216
+ latent_dim,
217
+ heads=cross_heads,
218
+ dim_head=cross_dim_head,
219
+ dropout=decoder_dropout,
220
+ ),
221
+ context_dim=latent_dim,
222
+ )
223
+
224
+ self.decoder_cross_attn_left = PreNorm(
225
+ self.input_dim_before_seq,
226
+ Attention(
227
+ self.input_dim_before_seq,
228
+ latent_dim,
229
+ heads=cross_heads,
230
+ dim_head=cross_dim_head,
231
+ dropout=decoder_dropout,
232
+ ),
233
+ context_dim=latent_dim,
234
+ )
235
+
236
+ # upsample conv
237
+ self.up0 = Conv3DUpsampleBlock(
238
+ self.input_dim_before_seq,
239
+ self.final_dim,
240
+ kernel_sizes=self.voxel_patch_size,
241
+ strides=self.voxel_patch_stride,
242
+ norm=None,
243
+ activation=activation,
244
+ )
245
+
246
+ # 2nd 3D softmax
247
+ self.ss1 = SpatialSoftmax3D(
248
+ spatial_size, spatial_size, spatial_size, self.input_dim_before_seq
249
+ )
250
+
251
+ flat_size += self.input_dim_before_seq * 4
252
+
253
+ # final 3D softmax
254
+ self.final = Conv3DBlock(
255
+ self.im_channels
256
+ if (self.no_perceiver or self.no_skip_connection)
257
+ else self.im_channels * 2,
258
+ self.im_channels,
259
+ kernel_sizes=3,
260
+ strides=1,
261
+ norm=None,
262
+ activation=activation,
263
+ )
264
+
265
+ self.right_trans_decoder = Conv3DBlock(
266
+ self.final_dim,
267
+ 1,
268
+ kernel_sizes=3,
269
+ strides=1,
270
+ norm=None,
271
+ activation=None,
272
+ )
273
+
274
+ self.left_trans_decoder = Conv3DBlock(
275
+ self.final_dim,
276
+ 1,
277
+ kernel_sizes=3,
278
+ strides=1,
279
+ norm=None,
280
+ activation=None,
281
+ )
282
+
283
+ # rotation, gripper, and collision MLP layers
284
+ if self.num_rotation_classes > 0:
285
+ self.ss_final = SpatialSoftmax3D(
286
+ self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
287
+ )
288
+
289
+ flat_size += self.im_channels * 4
290
+
291
+ self.right_dense0 = DenseBlock(flat_size, 256, None, activation)
292
+ self.right_dense1 = DenseBlock(256, self.final_dim, None, activation)
293
+
294
+ self.left_dense0 = DenseBlock(flat_size, 256, None, activation)
295
+ self.left_dense1 = DenseBlock(256, self.final_dim, None, activation)
296
+
297
+ self.right_rot_grip_collision_ff = DenseBlock(
298
+ self.final_dim,
299
+ self.num_rotation_classes * 3
300
+ + self.num_grip_classes
301
+ + self.num_collision_classes,
302
+ None,
303
+ None,
304
+ )
305
+
306
+ self.left_rot_grip_collision_ff = DenseBlock(
307
+ self.final_dim,
308
+ self.num_rotation_classes * 3
309
+ + self.num_grip_classes
310
+ + self.num_collision_classes,
311
+ None,
312
+ None,
313
+ )
314
+
315
+ def encode_text(self, x):
316
+ with torch.no_grad():
317
+ text_feat, text_emb = self._clip_rn50.encode_text_with_embeddings(x)
318
+
319
+ text_feat = text_feat.detach()
320
+ text_emb = text_emb.detach()
321
+ text_mask = torch.where(x == 0, x, 1) # [1, max_token_len]
322
+ return text_feat, text_emb
323
+
324
+ def forward(
325
+ self,
326
+ ins,
327
+ proprio,
328
+ lang_goal_emb,
329
+ lang_token_embs,
330
+ prev_layer_voxel_grid,
331
+ bounds,
332
+ prev_layer_bounds,
333
+ mask=None,
334
+ ):
335
+ # preprocess input
336
+ d0 = self.input_preprocess(ins) # [B,10,100,100,100] -> [B,64,100,100,100]
337
+
338
+ # aggregated features from 1st softmax and maxpool for MLP decoders
339
+ feats = [self.ss0(d0.contiguous()), self.global_maxp(d0).view(ins.shape[0], -1)]
340
+
341
+ # patchify input (5x5x5 patches)
342
+ ins = self.patchify(d0) # [B,64,100,100,100] -> [B,64,20,20,20]
343
+
344
+ b, c, d, h, w, device = *ins.shape, ins.device
345
+ axis = [d, h, w]
346
+ assert (
347
+ len(axis) == self.input_axis
348
+ ), "input must have the same number of axis as input_axis"
349
+
350
+ # concat proprio
351
+ if self.low_dim_size > 0:
352
+ p = self.proprio_preprocess(proprio) # [B,4] -> [B,64]
353
+ p = p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
354
+ ins = torch.cat([ins, p], dim=1) # [B,128,20,20,20]
355
+
356
+ # language ablation
357
+ if self.no_language:
358
+ lang_goal_emb = torch.zeros_like(lang_goal_emb)
359
+ lang_token_embs = torch.zeros_like(lang_token_embs)
360
+
361
+ # option 1: tile and concat lang goal to input
362
+ if self.lang_fusion_type == "concat":
363
+ lang_emb = lang_goal_emb
364
+ lang_emb = lang_emb.to(dtype=ins.dtype)
365
+ l = self.lang_preprocess(lang_emb)
366
+ l = l.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
367
+ ins = torch.cat([ins, l], dim=1)
368
+
369
+ # channel last
370
+ ins = rearrange(ins, "b d ... -> b ... d") # [B,20,20,20,128]
371
+
372
+ # add pos encoding to grid
373
+ if not self.pos_encoding_with_lang:
374
+ ins = ins + self.pos_encoding
375
+
376
+ ######################## NOTE #############################
377
+ # NOTE: If you add positional encodings ^here the lang embs
378
+ # won't have positional encodings. I accidently forgot
379
+ # to turn this off for all the experiments in the paper.
380
+ # So I guess those models were using language embs
381
+ # as a bag of words :( But it doesn't matter much for
382
+ # RLBench tasks since we don't test for novel instructions
383
+ # at test time anyway. The recommend way is to add
384
+ # positional encodings to the final input sequence
385
+ # fed into the Perceiver Transformer, as done below
386
+ # (and also in the Colab tutorial).
387
+ ###########################################################
388
+
389
+ # concat to channels of and flatten axis
390
+ queries_orig_shape = ins.shape
391
+
392
+ # rearrange input to be channel last
393
+ ins = rearrange(ins, "b ... d -> b (...) d") # [B,8000,128]
394
+ ins_wo_prev_layers = ins
395
+
396
+ # option 2: add lang token embs as a sequence
397
+ if self.lang_fusion_type == "seq":
398
+ l = self.lang_preprocess(lang_token_embs) # [B,77,1024] -> [B,77,128]
399
+ ins = torch.cat((l, ins), dim=1) # [B,8077,128]
400
+
401
+ # add pos encoding to language + flattened grid (the recommended way)
402
+ if self.pos_encoding_with_lang:
403
+ ins = ins + self.pos_encoding
404
+
405
+ # batchify latents
406
+ x = repeat(self.latents, "n d -> b n d", b=b)
407
+
408
+ cross_attn, cross_ff_right, cross_ff_left = self.cross_attend_blocks
409
+
410
+ for it in range(self.iterations):
411
+ # encoder cross attention
412
+ x = cross_attn(x, context=ins, mask=mask) + x
413
+
414
+ # x.size() = [1, num_latents, latent_dim]
415
+ x_right, x_left = x.chunk(2, dim=1)
416
+
417
+ x_right = cross_ff_right(x_right) + x_right
418
+ x_left = cross_ff_left(x_left) + x_left
419
+
420
+ # self-attention layers
421
+ for (
422
+ self_attn_right,
423
+ self_ff_right,
424
+ self_attn_left,
425
+ self_ff_left,
426
+ ) in self.layers:
427
+ x_right = self_attn_right(x_right) + x_right
428
+ x_right = self_ff_right(x_right) + x_right
429
+
430
+ x_left = self_attn_left(x_left) + x_left
431
+ x_left = self_ff_left(x_left) + x_left
432
+
433
+ x = torch.concat([x_right, x_left], dim=1)
434
+ x = self.combined_latent_attn(x) + x
435
+ x = self.combined_latent_ff(x) + x
436
+
437
+ x_right, x_left = x.chunk(2, dim=1)
438
+
439
+ # decoder cross attention
440
+ latents_right = self.decoder_cross_attn_right(ins, context=x_right)
441
+ latents_left = self.decoder_cross_attn_left(ins, context=x_left)
442
+
443
+ # crop out the language part of the output sequence
444
+ if self.lang_fusion_type == "seq":
445
+ latents_right = latents_right[:, l.shape[1] :]
446
+ latents_left = latents_left[:, l.shape[1] :]
447
+
448
+ # reshape back to voxel grid
449
+ latents_right = latents_right.view(
450
+ b, *queries_orig_shape[1:-1], latents_right.shape[-1]
451
+ ) # [B,20,20,20,64]
452
+ latents_right = rearrange(
453
+ latents_right, "b ... d -> b d ..."
454
+ ) # [B,64,20,20,20]
455
+
456
+ # reshape back to voxel grid
457
+ latents_left = latents_left.view(
458
+ b, *queries_orig_shape[1:-1], latents_left.shape[-1]
459
+ ) # [B,20,20,20,64]
460
+ latents_left = rearrange(latents_left, "b ... d -> b d ...") # [B,64,20,20,20]
461
+
462
+ # aggregated features from 2nd softmax and maxpool for MLP decoders
463
+
464
+ feats_right = feats.copy()
465
+ feats_left = feats
466
+
467
+ feats_right.extend(
468
+ [
469
+ self.ss1(latents_right.contiguous()),
470
+ self.global_maxp(latents_right).view(b, -1),
471
+ ]
472
+ )
473
+ feats_left.extend(
474
+ [
475
+ self.ss1(latents_left.contiguous()),
476
+ self.global_maxp(latents_left).view(b, -1),
477
+ ]
478
+ )
479
+
480
+ # upsample
481
+ u0_right = self.up0(latents_right)
482
+ u0_left = self.up0(latents_left)
483
+
484
+ # ablations
485
+ if self.no_skip_connection:
486
+ u_right = self.final(u0_right)
487
+ u_left = self.final(u0_left)
488
+ elif self.no_perceiver:
489
+ u_right = self.final(d0)
490
+ u_left = self.final(d0)
491
+ else:
492
+ u_right = self.final(torch.cat([d0, u0_right], dim=1))
493
+ u_left = self.final(torch.cat([d0, u0_left], dim=1))
494
+
495
+ # translation decoder
496
+ right_trans = self.right_trans_decoder(u_right)
497
+ left_trans = self.left_trans_decoder(u_left)
498
+
499
+ # rotation, gripper, and collision MLPs
500
+ rot_and_grip_out = None
501
+ if self.num_rotation_classes > 0:
502
+ feats_right.extend(
503
+ [
504
+ self.ss_final(u_right.contiguous()),
505
+ self.global_maxp(u_right).view(b, -1),
506
+ ]
507
+ )
508
+
509
+ right_dense0 = self.right_dense0(torch.cat(feats_right, dim=1))
510
+ right_dense1 = self.right_dense1(right_dense0) # [B,72*3+2+2]
511
+
512
+ right_rot_and_grip_collision_out = self.right_rot_grip_collision_ff(
513
+ right_dense1
514
+ )
515
+ right_rot_and_grip_out = right_rot_and_grip_collision_out[
516
+ :, : -self.num_collision_classes
517
+ ]
518
+ right_collision_out = right_rot_and_grip_collision_out[
519
+ :, -self.num_collision_classes :
520
+ ]
521
+
522
+ feats_left.extend(
523
+ [
524
+ self.ss_final(u_left.contiguous()),
525
+ self.global_maxp(u_left).view(b, -1),
526
+ ]
527
+ )
528
+
529
+ left_dense0 = self.left_dense0(torch.cat(feats_left, dim=1))
530
+ left_dense1 = self.left_dense1(left_dense0) # [B,72*3+2+2]
531
+
532
+ left_rot_and_grip_collision_out = self.left_rot_grip_collision_ff(
533
+ left_dense1
534
+ )
535
+ left_rot_and_grip_out = left_rot_and_grip_collision_out[
536
+ :, : -self.num_collision_classes
537
+ ]
538
+ left_collision_out = left_rot_and_grip_collision_out[
539
+ :, -self.num_collision_classes :
540
+ ]
541
+
542
+ return (
543
+ right_trans,
544
+ right_rot_and_grip_out,
545
+ right_collision_out,
546
+ left_trans,
547
+ left_rot_and_grip_out,
548
+ left_collision_out,
549
+ )
external/peract_bimanual/agents/bimanual_peract/qattention_peract_bc_agent.py ADDED
@@ -0,0 +1,1063 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torchvision import transforms
11
+ from pytorch3d import transforms as torch3d_tf
12
+ from yarr.agents.agent import (
13
+ Agent,
14
+ ActResult,
15
+ ScalarSummary,
16
+ HistogramSummary,
17
+ ImageSummary,
18
+ Summary,
19
+ )
20
+
21
+ from helpers import utils
22
+ from helpers.utils import visualise_voxel, stack_on_channel
23
+ from voxel.voxel_grid import VoxelGrid
24
+ from voxel.augmentation import apply_se3_augmentation
25
+ from einops import rearrange
26
+ from helpers.clip.core.clip import build_model, load_clip
27
+
28
+ import transformers
29
+ from helpers.optim.lamb import Lamb
30
+
31
+ from torch.nn.parallel import DistributedDataParallel as DDP
32
+
33
+ NAME = "QAttentionAgent"
34
+
35
+
36
+ class QFunction(nn.Module):
37
+ def __init__(
38
+ self,
39
+ perceiver_encoder: nn.Module,
40
+ voxelizer: VoxelGrid,
41
+ bounds_offset: float,
42
+ rotation_resolution: float,
43
+ device,
44
+ training,
45
+ ):
46
+ super(QFunction, self).__init__()
47
+ self._rotation_resolution = rotation_resolution
48
+ self._voxelizer = voxelizer
49
+ self._bounds_offset = bounds_offset
50
+ self._qnet = perceiver_encoder.to(device)
51
+
52
+ # distributed training
53
+ if training:
54
+ self._qnet = DDP(self._qnet, device_ids=[device])
55
+
56
+ def _argmax_3d(self, tensor_orig):
57
+ b, c, d, h, w = tensor_orig.shape # c will be one
58
+ idxs = tensor_orig.view(b, c, -1).argmax(-1)
59
+ indices = torch.cat([((idxs // h) // d), (idxs // h) % w, idxs % w], 1)
60
+ return indices
61
+
62
+ def choose_highest_action(self, q_trans, q_rot_grip, q_collision):
63
+ coords = self._argmax_3d(q_trans)
64
+ rot_and_grip_indicies = None
65
+ ignore_collision = None
66
+ if q_rot_grip is not None:
67
+ q_rot = torch.stack(
68
+ torch.split(
69
+ q_rot_grip[:, :-2], int(360 // self._rotation_resolution), dim=1
70
+ ),
71
+ dim=1,
72
+ )
73
+ rot_and_grip_indicies = torch.cat(
74
+ [
75
+ q_rot[:, 0:1].argmax(-1),
76
+ q_rot[:, 1:2].argmax(-1),
77
+ q_rot[:, 2:3].argmax(-1),
78
+ q_rot_grip[:, -2:].argmax(-1, keepdim=True),
79
+ ],
80
+ -1,
81
+ )
82
+ ignore_collision = q_collision[:, -2:].argmax(-1, keepdim=True)
83
+ return coords, rot_and_grip_indicies, ignore_collision
84
+
85
+ def forward(
86
+ self,
87
+ rgb_pcd,
88
+ proprio,
89
+ pcd,
90
+ lang_goal_emb,
91
+ lang_token_embs,
92
+ bounds=None,
93
+ prev_bounds=None,
94
+ prev_layer_voxel_grid=None,
95
+ ):
96
+ # rgb_pcd will be list of list (list of [rgb, pcd])
97
+ b = rgb_pcd[0][0].shape[0]
98
+ pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(b, -1, 3) for p in pcd], 1)
99
+
100
+ # flatten RGBs and Pointclouds
101
+ rgb = [rp[0] for rp in rgb_pcd]
102
+ feat_size = rgb[0].shape[1]
103
+ flat_imag_features = torch.cat(
104
+ [p.permute(0, 2, 3, 1).reshape(b, -1, feat_size) for p in rgb], 1
105
+ )
106
+
107
+ # construct voxel grid
108
+ voxel_grid = self._voxelizer.coords_to_bounding_voxel_grid(
109
+ pcd_flat, coord_features=flat_imag_features, coord_bounds=bounds
110
+ )
111
+
112
+ # swap to channels fist
113
+ voxel_grid = voxel_grid.permute(0, 4, 1, 2, 3).detach()
114
+
115
+ # batch bounds if necessary
116
+ if bounds.shape[0] != b:
117
+ bounds = bounds.repeat(b, 1)
118
+
119
+ # forward pass
120
+ split_pred = self._qnet(
121
+ voxel_grid,
122
+ proprio,
123
+ lang_goal_emb,
124
+ lang_token_embs,
125
+ prev_layer_voxel_grid,
126
+ bounds,
127
+ prev_bounds,
128
+ )
129
+
130
+ return split_pred, voxel_grid
131
+
132
+
133
+ class QAttentionPerActBCAgent(Agent):
134
+ def __init__(
135
+ self,
136
+ layer: int,
137
+ coordinate_bounds: list,
138
+ perceiver_encoder: nn.Module,
139
+ camera_names: list,
140
+ batch_size: int,
141
+ voxel_size: int,
142
+ bounds_offset: float,
143
+ voxel_feature_size: int,
144
+ image_crop_size: int,
145
+ num_rotation_classes: int,
146
+ rotation_resolution: float,
147
+ lr: float = 0.0001,
148
+ lr_scheduler: bool = False,
149
+ training_iterations: int = 100000,
150
+ num_warmup_steps: int = 20000,
151
+ trans_loss_weight: float = 1.0,
152
+ rot_loss_weight: float = 1.0,
153
+ grip_loss_weight: float = 1.0,
154
+ collision_loss_weight: float = 1.0,
155
+ include_low_dim_state: bool = False,
156
+ image_resolution: list = None,
157
+ lambda_weight_l2: float = 0.0,
158
+ transform_augmentation: bool = True,
159
+ transform_augmentation_xyz: list = [0.0, 0.0, 0.0],
160
+ transform_augmentation_rpy: list = [0.0, 0.0, 180.0],
161
+ transform_augmentation_rot_resolution: int = 5,
162
+ optimizer_type: str = "adam",
163
+ num_devices: int = 1,
164
+ ):
165
+ self._layer = layer
166
+ self._coordinate_bounds = coordinate_bounds
167
+ self._perceiver_encoder = perceiver_encoder
168
+ self._voxel_feature_size = voxel_feature_size
169
+ self._bounds_offset = bounds_offset
170
+ self._image_crop_size = image_crop_size
171
+ self._lr = lr
172
+ self._lr_scheduler = lr_scheduler
173
+ self._training_iterations = training_iterations
174
+ self._num_warmup_steps = num_warmup_steps
175
+ self._trans_loss_weight = trans_loss_weight
176
+ self._rot_loss_weight = rot_loss_weight
177
+ self._grip_loss_weight = grip_loss_weight
178
+ self._collision_loss_weight = collision_loss_weight
179
+ self._include_low_dim_state = include_low_dim_state
180
+ self._image_resolution = image_resolution or [128, 128]
181
+ self._voxel_size = voxel_size
182
+ self._camera_names = camera_names
183
+ self._num_cameras = len(camera_names)
184
+ self._batch_size = batch_size
185
+ self._lambda_weight_l2 = lambda_weight_l2
186
+ self._transform_augmentation = transform_augmentation
187
+ self._transform_augmentation_xyz = torch.from_numpy(
188
+ np.array(transform_augmentation_xyz)
189
+ )
190
+ self._transform_augmentation_rpy = transform_augmentation_rpy
191
+ self._transform_augmentation_rot_resolution = (
192
+ transform_augmentation_rot_resolution
193
+ )
194
+ self._optimizer_type = optimizer_type
195
+ self._num_devices = num_devices
196
+ self._num_rotation_classes = num_rotation_classes
197
+ self._rotation_resolution = rotation_resolution
198
+
199
+ self._cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
200
+ self._name = NAME + "_layer" + str(self._layer)
201
+
202
+ def build(self, training: bool, device: torch.device = None):
203
+ self._training = training
204
+
205
+ if device is None:
206
+ device = torch.device("cpu")
207
+
208
+ self._device = device
209
+
210
+ self._voxelizer = VoxelGrid(
211
+ coord_bounds=self._coordinate_bounds,
212
+ voxel_size=self._voxel_size,
213
+ device=device,
214
+ batch_size=self._batch_size if training else 1,
215
+ feature_size=self._voxel_feature_size,
216
+ max_num_coords=np.prod(self._image_resolution) * self._num_cameras,
217
+ )
218
+
219
+ self._q = (
220
+ QFunction(
221
+ self._perceiver_encoder,
222
+ self._voxelizer,
223
+ self._bounds_offset,
224
+ self._rotation_resolution,
225
+ device,
226
+ training,
227
+ )
228
+ .to(device)
229
+ .train(training)
230
+ )
231
+
232
+ grid_for_crop = (
233
+ torch.arange(0, self._image_crop_size, device=device)
234
+ .unsqueeze(0)
235
+ .repeat(self._image_crop_size, 1)
236
+ .unsqueeze(-1)
237
+ )
238
+ self._grid_for_crop = torch.cat(
239
+ [grid_for_crop.transpose(1, 0), grid_for_crop], dim=2
240
+ ).unsqueeze(0)
241
+
242
+ self._coordinate_bounds = torch.tensor(
243
+ self._coordinate_bounds, device=device
244
+ ).unsqueeze(0)
245
+
246
+ if self._training:
247
+ # optimizer
248
+ if self._optimizer_type == "lamb":
249
+ self._optimizer = Lamb(
250
+ self._q.parameters(),
251
+ lr=self._lr,
252
+ weight_decay=self._lambda_weight_l2,
253
+ betas=(0.9, 0.999),
254
+ adam=False,
255
+ )
256
+ elif self._optimizer_type == "adam":
257
+ self._optimizer = torch.optim.Adam(
258
+ self._q.parameters(),
259
+ lr=self._lr,
260
+ weight_decay=self._lambda_weight_l2,
261
+ )
262
+ else:
263
+ raise Exception("Unknown optimizer type")
264
+
265
+ # learning rate scheduler
266
+ if self._lr_scheduler:
267
+ self._scheduler = (
268
+ transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
269
+ self._optimizer,
270
+ num_warmup_steps=self._num_warmup_steps,
271
+ num_training_steps=self._training_iterations,
272
+ num_cycles=self._training_iterations // 10000,
273
+ )
274
+ )
275
+
276
+ # one-hot zero tensors
277
+ self._action_trans_one_hot_zeros = torch.zeros(
278
+ (
279
+ self._batch_size,
280
+ 1,
281
+ self._voxel_size,
282
+ self._voxel_size,
283
+ self._voxel_size,
284
+ ),
285
+ dtype=int,
286
+ device=device,
287
+ )
288
+ self._action_rot_x_one_hot_zeros = torch.zeros(
289
+ (self._batch_size, self._num_rotation_classes), dtype=int, device=device
290
+ )
291
+ self._action_rot_y_one_hot_zeros = torch.zeros(
292
+ (self._batch_size, self._num_rotation_classes), dtype=int, device=device
293
+ )
294
+ self._action_rot_z_one_hot_zeros = torch.zeros(
295
+ (self._batch_size, self._num_rotation_classes), dtype=int, device=device
296
+ )
297
+ self._action_grip_one_hot_zeros = torch.zeros(
298
+ (self._batch_size, 2), dtype=int, device=device
299
+ )
300
+ self._action_ignore_collisions_one_hot_zeros = torch.zeros(
301
+ (self._batch_size, 2), dtype=int, device=device
302
+ )
303
+
304
+ # print total params
305
+ logging.info(
306
+ "# Q Params: %d"
307
+ % sum(
308
+ p.numel()
309
+ for name, p in self._q.named_parameters()
310
+ if p.requires_grad and "clip" not in name
311
+ )
312
+ )
313
+ else:
314
+ for param in self._q.parameters():
315
+ param.requires_grad = False
316
+
317
+ # load CLIP for encoding language goals during evaluation
318
+ model, _ = load_clip("RN50", jit=False)
319
+ self._clip_rn50 = build_model(model.state_dict())
320
+ self._clip_rn50 = self._clip_rn50.float().to(device)
321
+ self._clip_rn50.eval()
322
+ del model
323
+
324
+ self._voxelizer.to(device)
325
+ self._q.to(device)
326
+
327
+ def _extract_crop(self, pixel_action, observation):
328
+ # Pixel action will now be (B, 2)
329
+ # observation = stack_on_channel(observation)
330
+ h = observation.shape[-1]
331
+ top_left_corner = torch.clamp(
332
+ pixel_action - self._image_crop_size // 2, 0, h - self._image_crop_size
333
+ )
334
+ grid = self._grid_for_crop + top_left_corner.unsqueeze(1)
335
+ grid = ((grid / float(h)) * 2.0) - 1.0 # between -1 and 1
336
+ # Used for cropping the images across a batch
337
+ # swap fro y x, to x, y
338
+ grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1)
339
+ crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True)
340
+ return crop
341
+
342
+ def _preprocess_inputs(self, replay_sample):
343
+ obs = []
344
+ pcds = []
345
+ self._crop_summary = []
346
+ for n in self._camera_names:
347
+ rgb = replay_sample["%s_rgb" % n]
348
+ pcd = replay_sample["%s_point_cloud" % n]
349
+
350
+ obs.append([rgb, pcd])
351
+ pcds.append(pcd)
352
+ return obs, pcds
353
+
354
+ def _act_preprocess_inputs(self, observation):
355
+ obs, pcds = [], []
356
+ for n in self._camera_names:
357
+ rgb = observation["%s_rgb" % n]
358
+ pcd = observation["%s_point_cloud" % n]
359
+
360
+ obs.append([rgb, pcd])
361
+ pcds.append(pcd)
362
+ return obs, pcds
363
+
364
+ def _get_value_from_voxel_index(self, q, voxel_idx):
365
+ b, c, d, h, w = q.shape
366
+ q_trans_flat = q.view(b, c, d * h * w)
367
+ flat_indicies = (
368
+ voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2]
369
+ )[:, None].int()
370
+ highest_idxs = flat_indicies.unsqueeze(-1).repeat(1, c, 1)
371
+ chosen_voxel_values = q_trans_flat.gather(2, highest_idxs)[
372
+ ..., 0
373
+ ] # (B, trans + rot + grip)
374
+ return chosen_voxel_values
375
+
376
+ def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx):
377
+ q_rot = torch.stack(
378
+ torch.split(
379
+ rot_grip_q[:, :-2], int(360 // self._rotation_resolution), dim=1
380
+ ),
381
+ dim=1,
382
+ ) # B, 3, 72
383
+ q_grip = rot_grip_q[:, -2:]
384
+ rot_and_grip_values = torch.cat(
385
+ [
386
+ q_rot[:, 0].gather(1, rot_and_grip_idx[:, 0:1]),
387
+ q_rot[:, 1].gather(1, rot_and_grip_idx[:, 1:2]),
388
+ q_rot[:, 2].gather(1, rot_and_grip_idx[:, 2:3]),
389
+ q_grip.gather(1, rot_and_grip_idx[:, 3:4]),
390
+ ],
391
+ -1,
392
+ )
393
+ return rot_and_grip_values
394
+
395
+ def _celoss(self, pred, labels):
396
+ return self._cross_entropy_loss(pred, labels.argmax(-1))
397
+
398
+ def _softmax_q_trans(self, q):
399
+ q_shape = q.shape
400
+ return F.softmax(q.reshape(q_shape[0], -1), dim=1).reshape(q_shape)
401
+
402
+ def _softmax_q_rot_grip(self, q_rot_grip):
403
+ q_rot_x_flat = q_rot_grip[
404
+ :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
405
+ ]
406
+ q_rot_y_flat = q_rot_grip[
407
+ :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
408
+ ]
409
+ q_rot_z_flat = q_rot_grip[
410
+ :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
411
+ ]
412
+ q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
413
+
414
+ q_rot_x_flat_softmax = F.softmax(q_rot_x_flat, dim=1)
415
+ q_rot_y_flat_softmax = F.softmax(q_rot_y_flat, dim=1)
416
+ q_rot_z_flat_softmax = F.softmax(q_rot_z_flat, dim=1)
417
+ q_grip_flat_softmax = F.softmax(q_grip_flat, dim=1)
418
+
419
+ return torch.cat(
420
+ [
421
+ q_rot_x_flat_softmax,
422
+ q_rot_y_flat_softmax,
423
+ q_rot_z_flat_softmax,
424
+ q_grip_flat_softmax,
425
+ ],
426
+ dim=1,
427
+ )
428
+
429
+ def _softmax_ignore_collision(self, q_collision):
430
+ q_collision_softmax = F.softmax(q_collision, dim=1)
431
+ return q_collision_softmax
432
+
433
+ def update(self, step: int, replay_sample: dict) -> dict:
434
+ right_action_trans = replay_sample["right_trans_action_indicies"][
435
+ :, self._layer * 3 : self._layer * 3 + 3
436
+ ].int()
437
+ right_action_rot_grip = replay_sample["right_rot_grip_action_indicies"].int()
438
+ right_action_gripper_pose = replay_sample["right_gripper_pose"]
439
+ right_action_ignore_collisions = replay_sample["right_ignore_collisions"].int()
440
+
441
+ left_action_trans = replay_sample["left_trans_action_indicies"][
442
+ :, self._layer * 3 : self._layer * 3 + 3
443
+ ].int()
444
+ left_action_rot_grip = replay_sample["left_rot_grip_action_indicies"].int()
445
+ left_action_gripper_pose = replay_sample["left_gripper_pose"]
446
+ left_action_ignore_collisions = replay_sample["left_ignore_collisions"].int()
447
+
448
+ lang_goal_emb = replay_sample["lang_goal_emb"].float()
449
+ lang_token_embs = replay_sample["lang_token_embs"].float()
450
+ prev_layer_voxel_grid = replay_sample.get("prev_layer_voxel_grid", None)
451
+ prev_layer_bounds = replay_sample.get("prev_layer_bounds", None)
452
+ device = self._device
453
+
454
+ bounds = self._coordinate_bounds.to(device)
455
+ if self._layer > 0:
456
+ right_cp = replay_sample[
457
+ "right_attention_coordinate_layer_%d" % (self._layer - 1)
458
+ ]
459
+
460
+ left_cp = replay_sample[
461
+ "left_attention_coordinate_layer_%d" % (self._layer - 1)
462
+ ]
463
+
464
+ right_bounds = torch.cat(
465
+ [right_cp - self._bounds_offset, right_cp + self._bounds_offset], dim=1
466
+ )
467
+ left_bounds = torch.cat(
468
+ [left_cp - self._bounds_offset, left_cp + self._bounds_offset], dim=1
469
+ )
470
+
471
+ else:
472
+ right_bounds = bounds
473
+ left_bounds = bounds
474
+
475
+ right_proprio = None
476
+ left_proprio = None
477
+ if self._include_low_dim_state:
478
+ right_proprio = replay_sample["right_low_dim_state"]
479
+ left_proprio = replay_sample["left_low_dim_state"]
480
+
481
+ # ..TODO::
482
+ # Can we add the coordinates of both robots?
483
+ #
484
+
485
+ obs, pcd = self._preprocess_inputs(replay_sample)
486
+
487
+ # batch size
488
+ bs = pcd[0].shape[0]
489
+
490
+ # We can move the point cloud w.r.t to the other robot's cooridinate system
491
+ # similar to apply_se3_augmentation
492
+ #
493
+
494
+ # SE(3) augmentation of point clouds and actions
495
+ if self._transform_augmentation:
496
+ from voxel import augmentation
497
+
498
+ (
499
+ right_action_trans,
500
+ right_action_rot_grip,
501
+ left_action_trans,
502
+ left_action_rot_grip,
503
+ pcd,
504
+ ) = augmentation.bimanual_apply_se3_augmentation(
505
+ pcd,
506
+ right_action_gripper_pose,
507
+ right_action_trans,
508
+ right_action_rot_grip,
509
+ left_action_gripper_pose,
510
+ left_action_trans,
511
+ left_action_rot_grip,
512
+ bounds,
513
+ self._layer,
514
+ self._transform_augmentation_xyz,
515
+ self._transform_augmentation_rpy,
516
+ self._transform_augmentation_rot_resolution,
517
+ self._voxel_size,
518
+ self._rotation_resolution,
519
+ self._device,
520
+ )
521
+ else:
522
+ right_action_trans = right_action_trans.int()
523
+ left_action_trans = left_action_trans.int()
524
+
525
+ proprio = torch.cat((right_proprio, left_proprio), dim=1)
526
+
527
+ right_action = (
528
+ right_action_trans,
529
+ right_action_rot_grip,
530
+ right_action_ignore_collisions,
531
+ )
532
+ left_action = (
533
+ left_action_trans,
534
+ left_action_rot_grip,
535
+ left_action_ignore_collisions,
536
+ )
537
+ # forward pass
538
+ q, voxel_grid = self._q(
539
+ obs,
540
+ proprio,
541
+ pcd,
542
+ lang_goal_emb,
543
+ lang_token_embs,
544
+ bounds,
545
+ prev_layer_bounds,
546
+ prev_layer_voxel_grid,
547
+ )
548
+
549
+ (
550
+ right_q_trans,
551
+ right_q_rot_grip,
552
+ right_q_collision,
553
+ left_q_trans,
554
+ left_q_rot_grip,
555
+ left_q_collision,
556
+ ) = q
557
+
558
+ # argmax to choose best action
559
+ (
560
+ right_coords,
561
+ right_rot_and_grip_indicies,
562
+ right_ignore_collision_indicies,
563
+ ) = self._q.choose_highest_action(
564
+ right_q_trans, right_q_rot_grip, right_q_collision
565
+ )
566
+
567
+ (
568
+ left_coords,
569
+ left_rot_and_grip_indicies,
570
+ left_ignore_collision_indicies,
571
+ ) = self._q.choose_highest_action(
572
+ left_q_trans, left_q_rot_grip, left_q_collision
573
+ )
574
+
575
+ (
576
+ right_q_trans_loss,
577
+ right_q_rot_loss,
578
+ right_q_grip_loss,
579
+ right_q_collision_loss,
580
+ ) = (0.0, 0.0, 0.0, 0.0)
581
+ left_q_trans_loss, left_q_rot_loss, left_q_grip_loss, left_q_collision_loss = (
582
+ 0.0,
583
+ 0.0,
584
+ 0.0,
585
+ 0.0,
586
+ )
587
+
588
+ # translation one-hot
589
+ right_action_trans_one_hot = self._action_trans_one_hot_zeros.clone().detach()
590
+ left_action_trans_one_hot = self._action_trans_one_hot_zeros.clone().detach()
591
+ for b in range(bs):
592
+ right_gt_coord = right_action_trans[b, :].int()
593
+ right_action_trans_one_hot[
594
+ b, :, right_gt_coord[0], right_gt_coord[1], right_gt_coord[2]
595
+ ] = 1
596
+ left_gt_coord = left_action_trans[b, :].int()
597
+ left_action_trans_one_hot[
598
+ b, :, left_gt_coord[0], left_gt_coord[1], left_gt_coord[2]
599
+ ] = 1
600
+
601
+ # translation loss
602
+ right_q_trans_flat = right_q_trans.view(bs, -1)
603
+ right_action_trans_one_hot_flat = right_action_trans_one_hot.view(bs, -1)
604
+ right_q_trans_loss = self._celoss(
605
+ right_q_trans_flat, right_action_trans_one_hot_flat
606
+ )
607
+ left_q_trans_flat = left_q_trans.view(bs, -1)
608
+ left_action_trans_one_hot_flat = left_action_trans_one_hot.view(bs, -1)
609
+ left_q_trans_loss = self._celoss(
610
+ left_q_trans_flat, left_action_trans_one_hot_flat
611
+ )
612
+
613
+ q_trans_loss = right_q_trans_loss + left_q_trans_loss
614
+
615
+ with_rot_and_grip = (
616
+ len(right_rot_and_grip_indicies) > 0 and len(left_rot_and_grip_indicies) > 0
617
+ )
618
+ if with_rot_and_grip:
619
+ # rotation, gripper, and collision one-hots
620
+ right_action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone()
621
+ right_action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone()
622
+ right_action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone()
623
+ right_action_grip_one_hot = self._action_grip_one_hot_zeros.clone()
624
+ right_action_ignore_collisions_one_hot = (
625
+ self._action_ignore_collisions_one_hot_zeros.clone()
626
+ )
627
+
628
+ left_action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone()
629
+ left_action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone()
630
+ left_action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone()
631
+ left_action_grip_one_hot = self._action_grip_one_hot_zeros.clone()
632
+ left_action_ignore_collisions_one_hot = (
633
+ self._action_ignore_collisions_one_hot_zeros.clone()
634
+ )
635
+
636
+ for b in range(bs):
637
+ right_gt_rot_grip = right_action_rot_grip[b, :].int()
638
+ right_action_rot_x_one_hot[b, right_gt_rot_grip[0]] = 1
639
+ right_action_rot_y_one_hot[b, right_gt_rot_grip[1]] = 1
640
+ right_action_rot_z_one_hot[b, right_gt_rot_grip[2]] = 1
641
+ right_action_grip_one_hot[b, right_gt_rot_grip[3]] = 1
642
+
643
+ right_gt_ignore_collisions = right_action_ignore_collisions[b, :].int()
644
+ right_action_ignore_collisions_one_hot[
645
+ b, right_gt_ignore_collisions[0]
646
+ ] = 1
647
+
648
+ left_gt_rot_grip = left_action_rot_grip[b, :].int()
649
+ left_action_rot_x_one_hot[b, left_gt_rot_grip[0]] = 1
650
+ left_action_rot_y_one_hot[b, left_gt_rot_grip[1]] = 1
651
+ left_action_rot_z_one_hot[b, left_gt_rot_grip[2]] = 1
652
+ left_action_grip_one_hot[b, left_gt_rot_grip[3]] = 1
653
+
654
+ left_gt_ignore_collisions = left_action_ignore_collisions[b, :].int()
655
+ left_action_ignore_collisions_one_hot[
656
+ b, left_gt_ignore_collisions[0]
657
+ ] = 1
658
+
659
+ # flatten predictions
660
+ right_q_rot_x_flat = right_q_rot_grip[
661
+ :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
662
+ ]
663
+ right_q_rot_y_flat = right_q_rot_grip[
664
+ :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
665
+ ]
666
+ right_q_rot_z_flat = right_q_rot_grip[
667
+ :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
668
+ ]
669
+ right_q_grip_flat = right_q_rot_grip[:, 3 * self._num_rotation_classes :]
670
+ right_q_ignore_collisions_flat = right_q_collision
671
+
672
+ left_q_rot_x_flat = left_q_rot_grip[
673
+ :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
674
+ ]
675
+ left_q_rot_y_flat = left_q_rot_grip[
676
+ :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
677
+ ]
678
+ left_q_rot_z_flat = left_q_rot_grip[
679
+ :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
680
+ ]
681
+ left_q_grip_flat = left_q_rot_grip[:, 3 * self._num_rotation_classes :]
682
+ left_q_ignore_collisions_flat = left_q_collision
683
+
684
+ # rotation loss
685
+ right_q_rot_loss += self._celoss(
686
+ right_q_rot_x_flat, right_action_rot_x_one_hot
687
+ )
688
+ right_q_rot_loss += self._celoss(
689
+ right_q_rot_y_flat, right_action_rot_y_one_hot
690
+ )
691
+ right_q_rot_loss += self._celoss(
692
+ right_q_rot_z_flat, right_action_rot_z_one_hot
693
+ )
694
+
695
+ left_q_rot_loss += self._celoss(
696
+ left_q_rot_x_flat, left_action_rot_x_one_hot
697
+ )
698
+ left_q_rot_loss += self._celoss(
699
+ left_q_rot_y_flat, left_action_rot_y_one_hot
700
+ )
701
+ left_q_rot_loss += self._celoss(
702
+ left_q_rot_z_flat, left_action_rot_z_one_hot
703
+ )
704
+
705
+ # gripper loss
706
+ right_q_grip_loss += self._celoss(
707
+ right_q_grip_flat, right_action_grip_one_hot
708
+ )
709
+ left_q_grip_loss += self._celoss(left_q_grip_flat, left_action_grip_one_hot)
710
+
711
+ # collision loss
712
+ right_q_collision_loss += self._celoss(
713
+ right_q_ignore_collisions_flat, right_action_ignore_collisions_one_hot
714
+ )
715
+ left_q_collision_loss += self._celoss(
716
+ left_q_ignore_collisions_flat, left_action_ignore_collisions_one_hot
717
+ )
718
+
719
+ q_trans_loss = right_q_trans_loss + left_q_trans_loss
720
+ q_rot_loss = right_q_rot_loss + left_q_rot_loss
721
+ q_grip_loss = right_q_grip_loss + left_q_grip_loss
722
+ q_collision_loss = right_q_collision_loss + left_q_collision_loss
723
+
724
+ combined_losses = (
725
+ (q_trans_loss * self._trans_loss_weight)
726
+ + (q_rot_loss * self._rot_loss_weight)
727
+ + (q_grip_loss * self._grip_loss_weight)
728
+ + (q_collision_loss * self._collision_loss_weight)
729
+ )
730
+ total_loss = combined_losses.mean()
731
+
732
+ self._optimizer.zero_grad()
733
+ total_loss.backward()
734
+ self._optimizer.step()
735
+
736
+ self._summaries = {
737
+ "losses/total_loss": total_loss,
738
+ "losses/trans_loss": q_trans_loss.mean(),
739
+ "losses/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
740
+ "losses/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
741
+ "losses/right/trans_loss": q_trans_loss.mean(),
742
+ "losses/right/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
743
+ "losses/right/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
744
+ "losses/right/collision_loss": q_collision_loss.mean()
745
+ if with_rot_and_grip
746
+ else 0.0,
747
+ "losses/left/trans_loss": q_trans_loss.mean(),
748
+ "losses/left/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
749
+ "losses/left/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
750
+ "losses/left/collision_loss": q_collision_loss.mean()
751
+ if with_rot_and_grip
752
+ else 0.0,
753
+ "losses/collision_loss": q_collision_loss.mean()
754
+ if with_rot_and_grip
755
+ else 0.0,
756
+ }
757
+
758
+ if self._lr_scheduler:
759
+ self._scheduler.step()
760
+ self._summaries["learning_rate"] = self._scheduler.get_last_lr()[0]
761
+
762
+ self._vis_voxel_grid = voxel_grid[0]
763
+ self._right_vis_translation_qvalue = self._softmax_q_trans(right_q_trans[0])
764
+ self._right_vis_max_coordinate = right_coords[0]
765
+ self._right_vis_gt_coordinate = right_action_trans[0]
766
+
767
+ self._left_vis_translation_qvalue = self._softmax_q_trans(left_q_trans[0])
768
+ self._left_vis_max_coordinate = left_coords[0]
769
+ self._left_vis_gt_coordinate = left_action_trans[0]
770
+
771
+ # Note: PerAct doesn't use multi-layer voxel grids like C2FARM
772
+ # stack prev_layer_voxel_grid(s) from previous layers into a list
773
+ if prev_layer_voxel_grid is None:
774
+ prev_layer_voxel_grid = [voxel_grid]
775
+ else:
776
+ prev_layer_voxel_grid = prev_layer_voxel_grid + [voxel_grid]
777
+
778
+ # stack prev_layer_bound(s) from previous layers into a list
779
+ if prev_layer_bounds is None:
780
+ prev_layer_bounds = [self._coordinate_bounds.repeat(bs, 1)]
781
+ else:
782
+ prev_layer_bounds = prev_layer_bounds + [bounds]
783
+
784
+ return {
785
+ "total_loss": total_loss,
786
+ "prev_layer_voxel_grid": prev_layer_voxel_grid,
787
+ "prev_layer_bounds": prev_layer_bounds,
788
+ }
789
+
790
+ def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
791
+ deterministic = True
792
+ bounds = self._coordinate_bounds
793
+ prev_layer_voxel_grid = observation.get("prev_layer_voxel_grid", None)
794
+ prev_layer_bounds = observation.get("prev_layer_bounds", None)
795
+ lang_goal_tokens = observation.get("lang_goal_tokens", None).long()
796
+
797
+ # extract CLIP language embs
798
+ with torch.no_grad():
799
+ lang_goal_tokens = lang_goal_tokens.to(device=self._device)
800
+ (
801
+ lang_goal_emb,
802
+ lang_token_embs,
803
+ ) = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0])
804
+
805
+ # voxelization resolution
806
+ res = (bounds[:, 3:] - bounds[:, :3]) / self._voxel_size
807
+ max_rot_index = int(360 // self._rotation_resolution)
808
+ right_proprio = None
809
+ left_proprio = None
810
+
811
+ if self._include_low_dim_state:
812
+ right_proprio = observation["right_low_dim_state"]
813
+ left_proprio = observation["left_low_dim_state"]
814
+ right_proprio = right_proprio[0].to(self._device)
815
+ left_proprio = left_proprio[0].to(self._device)
816
+
817
+ obs, pcd = self._act_preprocess_inputs(observation)
818
+
819
+ # correct batch size and device
820
+ obs = [[o[0][0].to(self._device), o[1][0].to(self._device)] for o in obs]
821
+
822
+ pcd = [p[0].to(self._device) for p in pcd]
823
+ lang_goal_emb = lang_goal_emb.to(self._device)
824
+ lang_token_embs = lang_token_embs.to(self._device)
825
+ bounds = torch.as_tensor(bounds, device=self._device)
826
+ prev_layer_voxel_grid = (
827
+ prev_layer_voxel_grid.to(self._device)
828
+ if prev_layer_voxel_grid is not None
829
+ else None
830
+ )
831
+ prev_layer_bounds = (
832
+ prev_layer_bounds.to(self._device)
833
+ if prev_layer_bounds is not None
834
+ else None
835
+ )
836
+
837
+ proprio = torch.cat((right_proprio, left_proprio), dim=1)
838
+
839
+ # inference
840
+ (
841
+ right_q_trans,
842
+ right_q_rot_grip,
843
+ right_q_ignore_collisions,
844
+ left_q_trans,
845
+ left_q_rot_grip,
846
+ left_q_ignore_collisions,
847
+ ), vox_grid = self._q(
848
+ obs,
849
+ proprio,
850
+ pcd,
851
+ lang_goal_emb,
852
+ lang_token_embs,
853
+ bounds,
854
+ prev_layer_bounds,
855
+ prev_layer_voxel_grid,
856
+ )
857
+
858
+ # softmax Q predictions
859
+ right_q_trans = self._softmax_q_trans(right_q_trans)
860
+ left_q_trans = self._softmax_q_trans(left_q_trans)
861
+
862
+ if right_q_rot_grip is not None:
863
+ right_q_rot_grip = self._softmax_q_rot_grip(right_q_rot_grip)
864
+
865
+ if left_q_rot_grip is not None:
866
+ left_q_rot_grip = self._softmax_q_rot_grip(left_q_rot_grip)
867
+
868
+ if right_q_ignore_collisions is not None:
869
+ right_q_ignore_collisions = self._softmax_ignore_collision(
870
+ right_q_ignore_collisions
871
+ )
872
+
873
+ if left_q_ignore_collisions is not None:
874
+ left_q_ignore_collisions = self._softmax_ignore_collision(
875
+ left_q_ignore_collisions
876
+ )
877
+
878
+ # argmax Q predictions
879
+ (
880
+ right_coords,
881
+ right_rot_and_grip_indicies,
882
+ right_ignore_collisions,
883
+ ) = self._q.choose_highest_action(
884
+ right_q_trans, right_q_rot_grip, right_q_ignore_collisions
885
+ )
886
+ (
887
+ left_coords,
888
+ left_rot_and_grip_indicies,
889
+ left_ignore_collisions,
890
+ ) = self._q.choose_highest_action(
891
+ left_q_trans, left_q_rot_grip, left_q_ignore_collisions
892
+ )
893
+
894
+ if right_q_rot_grip is not None:
895
+ right_rot_grip_action = right_rot_and_grip_indicies
896
+ if right_q_ignore_collisions is not None:
897
+ right_ignore_collisions_action = right_ignore_collisions.int()
898
+
899
+ if left_q_rot_grip is not None:
900
+ left_rot_grip_action = left_rot_and_grip_indicies
901
+ if left_q_ignore_collisions is not None:
902
+ left_ignore_collisions_action = left_ignore_collisions.int()
903
+
904
+ right_coords = right_coords.int()
905
+ left_coords = left_coords.int()
906
+
907
+ right_attention_coordinate = bounds[:, :3] + res * right_coords + res / 2
908
+ left_attention_coordinate = bounds[:, :3] + res * left_coords + res / 2
909
+
910
+ # stack prev_layer_voxel_grid(s) into a list
911
+ # NOTE: PerAct doesn't used multi-layer voxel grids like C2FARM
912
+ if prev_layer_voxel_grid is None:
913
+ prev_layer_voxel_grid = [vox_grid]
914
+ else:
915
+ prev_layer_voxel_grid = prev_layer_voxel_grid + [vox_grid]
916
+
917
+ if prev_layer_bounds is None:
918
+ prev_layer_bounds = [bounds]
919
+ else:
920
+ prev_layer_bounds = prev_layer_bounds + [bounds]
921
+
922
+ observation_elements = {
923
+ "right_attention_coordinate": right_attention_coordinate,
924
+ "left_attention_coordinate": left_attention_coordinate,
925
+ "prev_layer_voxel_grid": prev_layer_voxel_grid,
926
+ "prev_layer_bounds": prev_layer_bounds,
927
+ }
928
+ info = {
929
+ "voxel_grid_depth%d" % self._layer: vox_grid,
930
+ "right_q_depth%d" % self._layer: right_q_trans,
931
+ "right_voxel_idx_depth%d" % self._layer: right_coords,
932
+ "left_q_depth%d" % self._layer: left_q_trans,
933
+ "left_voxel_idx_depth%d" % self._layer: left_coords,
934
+ }
935
+ self._act_voxel_grid = vox_grid[0]
936
+ self._right_act_max_coordinate = right_coords[0]
937
+ self._right_act_qvalues = right_q_trans[0].detach()
938
+ self._left_act_max_coordinate = left_coords[0]
939
+ self._left_act_qvalues = left_q_trans[0].detach()
940
+
941
+ action = (
942
+ right_coords,
943
+ right_rot_grip_action,
944
+ right_ignore_collisions,
945
+ left_coords,
946
+ left_rot_grip_action,
947
+ left_ignore_collisions,
948
+ )
949
+
950
+ return ActResult(action, observation_elements=observation_elements, info=info)
951
+
952
+ def update_summaries(self) -> List[Summary]:
953
+ voxel_grid = self._vis_voxel_grid.detach().cpu().numpy()
954
+ summaries = []
955
+ summaries.append(
956
+ ImageSummary(
957
+ "%s/right_update_qattention" % self._name,
958
+ transforms.ToTensor()(
959
+ visualise_voxel(
960
+ voxel_grid,
961
+ self._right_vis_translation_qvalue.detach().cpu().numpy(),
962
+ self._right_vis_max_coordinate.detach().cpu().numpy(),
963
+ self._right_vis_gt_coordinate.detach().cpu().numpy(),
964
+ )
965
+ ),
966
+ )
967
+ )
968
+ summaries.append(
969
+ ImageSummary(
970
+ "%s/left_update_qattention" % self._name,
971
+ transforms.ToTensor()(
972
+ visualise_voxel(
973
+ voxel_grid,
974
+ self._left_vis_translation_qvalue.detach().cpu().numpy(),
975
+ self._left_vis_max_coordinate.detach().cpu().numpy(),
976
+ self._left_vis_gt_coordinate.detach().cpu().numpy(),
977
+ )
978
+ ),
979
+ )
980
+ )
981
+ for n, v in self._summaries.items():
982
+ summaries.append(ScalarSummary("%s/%s" % (self._name, n), v))
983
+
984
+ for name, crop in self._crop_summary:
985
+ crops = (torch.cat(torch.split(crop, 3, dim=1), dim=3) + 1.0) / 2.0
986
+ summaries.extend([ImageSummary("%s/crops/%s" % (self._name, name), crops)])
987
+
988
+ for tag, param in self._q.named_parameters():
989
+ # assert not torch.isnan(param.grad.abs() <= 1.0).all()
990
+ summaries.append(
991
+ HistogramSummary("%s/gradient/%s" % (self._name, tag), param.grad)
992
+ )
993
+ summaries.append(
994
+ HistogramSummary("%s/weight/%s" % (self._name, tag), param.data)
995
+ )
996
+
997
+ return summaries
998
+
999
+ def act_summaries(self) -> List[Summary]:
1000
+ voxel_grid = self._act_voxel_grid.cpu().numpy()
1001
+ right_q_attention = self._right_act_qvalues.cpu().numpy()
1002
+ right_highlight_coordinate = self._right_act_max_coordinate.cpu().numpy()
1003
+ right_visualization = visualise_voxel(
1004
+ voxel_grid, right_q_attention, right_highlight_coordinate
1005
+ )
1006
+
1007
+ left_q_attention = self._left_act_qvalues.cpu().numpy()
1008
+ left_highlight_coordinate = self._left_act_max_coordinate.cpu().numpy()
1009
+ left_visualization = visualise_voxel(
1010
+ voxel_grid, left_q_attention, left_highlight_coordinate
1011
+ )
1012
+
1013
+ return [
1014
+ ImageSummary(
1015
+ f"{self._name}/right_act_Qattention",
1016
+ transforms.ToTensor()(right_visualization),
1017
+ ),
1018
+ ImageSummary(
1019
+ f"{self._name}/left_act_Qattention",
1020
+ transforms.ToTensor()(left_visualization),
1021
+ ),
1022
+ ]
1023
+
1024
+ def load_weights(self, savedir: str):
1025
+ device = (
1026
+ self._device
1027
+ if not self._training
1028
+ else torch.device("cuda:%d" % self._device)
1029
+ )
1030
+ weight_file = os.path.join(savedir, "%s.pt" % self._name)
1031
+ state_dict = torch.load(weight_file, map_location=device)
1032
+
1033
+ # load only keys that are in the current model
1034
+ merged_state_dict = self._q.state_dict()
1035
+ for k, v in state_dict.items():
1036
+ if not self._training:
1037
+ k = k.replace("_qnet.module", "_qnet")
1038
+ if k in merged_state_dict:
1039
+ merged_state_dict[k] = v
1040
+ else:
1041
+ if "_voxelizer" not in k:
1042
+ logging.warning("key %s not found in checkpoint" % k)
1043
+ if not self._training:
1044
+ # reshape voxelizer weights
1045
+ b = merged_state_dict["_voxelizer._ones_max_coords"].shape[0]
1046
+ merged_state_dict["_voxelizer._ones_max_coords"] = merged_state_dict[
1047
+ "_voxelizer._ones_max_coords"
1048
+ ][0:1]
1049
+ flat_shape = merged_state_dict["_voxelizer._flat_output"].shape[0]
1050
+ merged_state_dict["_voxelizer._flat_output"] = merged_state_dict[
1051
+ "_voxelizer._flat_output"
1052
+ ][0 : flat_shape // b]
1053
+ merged_state_dict["_voxelizer._tiled_batch_indices"] = merged_state_dict[
1054
+ "_voxelizer._tiled_batch_indices"
1055
+ ][0:1]
1056
+ merged_state_dict["_voxelizer._index_grid"] = merged_state_dict[
1057
+ "_voxelizer._index_grid"
1058
+ ][0:1]
1059
+ self._q.load_state_dict(merged_state_dict)
1060
+ print("loaded weights from %s" % weight_file)
1061
+
1062
+ def save_weights(self, savedir: str):
1063
+ torch.save(self._q.state_dict(), os.path.join(savedir, "%s.pt" % self._name))
external/peract_bimanual/agents/bimanual_peract/qattention_stack_agent.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import torch
4
+ from yarr.agents.agent import Agent, ActResult, Summary
5
+
6
+ import numpy as np
7
+
8
+ from helpers import utils
9
+ from agents.bimanual_peract.qattention_peract_bc_agent import QAttentionPerActBCAgent
10
+
11
+ NAME = "QAttentionStackAgent"
12
+
13
+
14
+ class QAttentionStackAgent(Agent):
15
+ def __init__(
16
+ self,
17
+ qattention_agents: List[QAttentionPerActBCAgent],
18
+ rotation_resolution: float,
19
+ camera_names: List[str],
20
+ rotation_prediction_depth: int = 0,
21
+ ):
22
+ super(QAttentionStackAgent, self).__init__()
23
+ self._qattention_agents = qattention_agents
24
+ self._rotation_resolution = rotation_resolution
25
+ self._camera_names = camera_names
26
+ self._rotation_prediction_depth = rotation_prediction_depth
27
+
28
+ def build(self, training: bool, device=None) -> None:
29
+ self._device = device
30
+ if self._device is None:
31
+ self._device = torch.device("cpu")
32
+ for qa in self._qattention_agents:
33
+ qa.build(training, device)
34
+
35
+ def update(self, step: int, replay_sample: dict) -> dict:
36
+ priorities = 0
37
+ total_losses = 0.0
38
+ for qa in self._qattention_agents:
39
+ update_dict = qa.update(step, replay_sample)
40
+ replay_sample.update(update_dict)
41
+ total_losses += update_dict["total_loss"]
42
+ return {
43
+ "total_losses": total_losses,
44
+ }
45
+
46
+ def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
47
+ observation_elements = {}
48
+ (
49
+ right_translation_results,
50
+ right_rot_grip_results,
51
+ right_ignore_collisions_results,
52
+ ) = ([], [], [])
53
+ (
54
+ left_translation_results,
55
+ left_rot_grip_results,
56
+ left_ignore_collisions_results,
57
+ ) = ([], [], [])
58
+
59
+ infos = {}
60
+ for depth, qagent in enumerate(self._qattention_agents):
61
+ act_results = qagent.act(step, observation, deterministic)
62
+ right_attention_coordinate = (
63
+ act_results.observation_elements["right_attention_coordinate"]
64
+ .cpu()
65
+ .numpy()
66
+ )
67
+ left_attention_coordinate = (
68
+ act_results.observation_elements["left_attention_coordinate"]
69
+ .cpu()
70
+ .numpy()
71
+ )
72
+ observation_elements[
73
+ "right_attention_coordinate_layer_%d" % depth
74
+ ] = right_attention_coordinate[0]
75
+ observation_elements[
76
+ "left_attention_coordinate_layer_%d" % depth
77
+ ] = left_attention_coordinate[0]
78
+
79
+ (
80
+ right_translation_idxs,
81
+ right_rot_grip_idxs,
82
+ right_ignore_collisions_idxs,
83
+ left_translation_idxs,
84
+ left_rot_grip_idxs,
85
+ left_ignore_collisions_idxs,
86
+ ) = act_results.action
87
+
88
+ right_translation_results.append(right_translation_idxs)
89
+ if right_rot_grip_idxs is not None:
90
+ right_rot_grip_results.append(right_rot_grip_idxs)
91
+ if right_ignore_collisions_idxs is not None:
92
+ right_ignore_collisions_results.append(right_ignore_collisions_idxs)
93
+
94
+ left_translation_results.append(left_translation_idxs)
95
+ if left_rot_grip_idxs is not None:
96
+ left_rot_grip_results.append(left_rot_grip_idxs)
97
+ if left_ignore_collisions_idxs is not None:
98
+ left_ignore_collisions_results.append(left_ignore_collisions_idxs)
99
+
100
+ observation[
101
+ "right_attention_coordinate"
102
+ ] = act_results.observation_elements["right_attention_coordinate"]
103
+ observation["left_attention_coordinate"] = act_results.observation_elements[
104
+ "left_attention_coordinate"
105
+ ]
106
+
107
+ observation["prev_layer_voxel_grid"] = act_results.observation_elements[
108
+ "prev_layer_voxel_grid"
109
+ ]
110
+ observation["prev_layer_bounds"] = act_results.observation_elements[
111
+ "prev_layer_bounds"
112
+ ]
113
+
114
+ for n in self._camera_names:
115
+ extrinsics = observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy()
116
+ intrinsics = observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy()
117
+ px, py = utils.point_to_pixel_index(
118
+ right_attention_coordinate[0], extrinsics, intrinsics
119
+ )
120
+ pc_t = torch.tensor(
121
+ [[[py, px]]], dtype=torch.float32, device=self._device
122
+ )
123
+ observation[f"right_{n}_pixel_coord"] = pc_t
124
+ observation_elements[f"right_{n}_pixel_coord"] = [py, px]
125
+
126
+ px, py = utils.point_to_pixel_index(
127
+ left_attention_coordinate[0], extrinsics, intrinsics
128
+ )
129
+ pc_t = torch.tensor(
130
+ [[[py, px]]], dtype=torch.float32, device=self._device
131
+ )
132
+ observation[f"left_{n}_pixel_coord"] = pc_t
133
+ observation_elements[f"left_{n}_pixel_coord"] = [py, px]
134
+ infos.update(act_results.info)
135
+
136
+ right_rgai = torch.cat(right_rot_grip_results, 1)[0].cpu().numpy()
137
+ # ..todo:: utils.correct_rotation_instability does nothing so we can ignore it
138
+ # right_rgai = utils.correct_rotation_instability(right_rgai, self._rotation_resolution)
139
+ right_ignore_collisions = (
140
+ torch.cat(right_ignore_collisions_results, 1)[0].cpu().numpy()
141
+ )
142
+ right_trans_action_indicies = (
143
+ torch.cat(right_translation_results, 1)[0].cpu().numpy()
144
+ )
145
+
146
+ observation_elements[
147
+ "right_trans_action_indicies"
148
+ ] = right_trans_action_indicies[:3]
149
+ observation_elements["right_rot_grip_action_indicies"] = right_rgai[:4]
150
+
151
+ left_rgai = torch.cat(left_rot_grip_results, 1)[0].cpu().numpy()
152
+ left_ignore_collisions = (
153
+ torch.cat(left_ignore_collisions_results, 1)[0].cpu().numpy()
154
+ )
155
+ left_trans_action_indicies = (
156
+ torch.cat(left_translation_results, 1)[0].cpu().numpy()
157
+ )
158
+
159
+ observation_elements["left_trans_action_indicies"] = left_trans_action_indicies[
160
+ 3:
161
+ ]
162
+ observation_elements["left_rot_grip_action_indicies"] = left_rgai[4:]
163
+
164
+ continuous_action = np.concatenate(
165
+ [
166
+ right_attention_coordinate[0],
167
+ utils.discrete_euler_to_quaternion(
168
+ right_rgai[-4:-1], self._rotation_resolution
169
+ ),
170
+ right_rgai[-1:],
171
+ right_ignore_collisions,
172
+ left_attention_coordinate[0],
173
+ utils.discrete_euler_to_quaternion(
174
+ left_rgai[-4:-1], self._rotation_resolution
175
+ ),
176
+ left_rgai[-1:],
177
+ left_ignore_collisions,
178
+ ]
179
+ )
180
+ return ActResult(
181
+ continuous_action, observation_elements=observation_elements, info=infos
182
+ )
183
+
184
+ def update_summaries(self) -> List[Summary]:
185
+ summaries = []
186
+ for qa in self._qattention_agents:
187
+ summaries.extend(qa.update_summaries())
188
+ return summaries
189
+
190
+ def act_summaries(self) -> List[Summary]:
191
+ s = []
192
+ for qa in self._qattention_agents:
193
+ s.extend(qa.act_summaries())
194
+ return s
195
+
196
+ def load_weights(self, savedir: str):
197
+ for qa in self._qattention_agents:
198
+ qa.load_weights(savedir)
199
+
200
+ def save_weights(self, savedir: str):
201
+ for qa in self._qattention_agents:
202
+ qa.save_weights(savedir)
external/peract_bimanual/agents/c2farm_lingunet_bc/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ import agents.c2farm_lingunet_bc.launch_utils
external/peract_bimanual/agents/c2farm_lingunet_bc/launch_utils.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from ARM
2
+ # Source: https://github.com/stepjam/ARM
3
+ # License: https://github.com/stepjam/ARM/LICENSE
4
+
5
+ import logging
6
+ from typing import List
7
+
8
+ import numpy as np
9
+ from omegaconf import DictConfig
10
+ from rlbench.backend.observation import Observation
11
+ from rlbench.observation_config import ObservationConfig
12
+ import rlbench.utils as rlbench_utils
13
+ from rlbench.demo import Demo
14
+ from yarr.replay_buffer.prioritized_replay_buffer import ObservationElement
15
+ from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer
16
+ from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
17
+ from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer
18
+
19
+ from helpers import demo_loading_utils, utils
20
+ from helpers import observation_utils
21
+ from helpers.preprocess_agent import PreprocessAgent
22
+ from helpers.clip.core.clip import tokenize
23
+ from agents.c2farm_lingunet_bc.networks import QattentionLingU3DNet
24
+ from agents.c2farm_lingunet_bc.qattention_lingunet_bc_agent import (
25
+ QAttentionLingUNetBCAgent,
26
+ )
27
+ from agents.c2farm_lingunet_bc.qattention_stack_agent import QAttentionStackAgent
28
+
29
+ import torch
30
+ from torch.multiprocessing import Process, Value, Manager
31
+ from helpers.clip.core.clip import build_model, load_clip, tokenize
32
+ from omegaconf import DictConfig
33
+
34
+ REWARD_SCALE = 100.0
35
+ LOW_DIM_SIZE = 4
36
+
37
+
38
+ def create_replay(
39
+ batch_size: int,
40
+ timesteps: int,
41
+ prioritisation: bool,
42
+ task_uniform: bool,
43
+ save_dir: str,
44
+ cameras: list,
45
+ voxel_sizes,
46
+ image_size=[128, 128],
47
+ replay_size=3e5,
48
+ ):
49
+ trans_indicies_size = 3 * len(voxel_sizes)
50
+ rot_and_grip_indicies_size = 3 + 1
51
+ gripper_pose_size = 7
52
+ ignore_collisions_size = 1
53
+ max_token_seq_len = 77
54
+ lang_feat_dim = 1024
55
+ lang_emb_dim = 512
56
+
57
+ # low_dim_state
58
+ observation_elements = []
59
+ observation_elements.append(
60
+ ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
61
+ )
62
+
63
+ # rgb, depth, point cloud, intrinsics, extrinsics
64
+ for cname in cameras:
65
+ observation_elements.append(
66
+ ObservationElement(
67
+ "%s_rgb" % cname,
68
+ (
69
+ 3,
70
+ *image_size,
71
+ ),
72
+ np.float32,
73
+ )
74
+ )
75
+ observation_elements.append(
76
+ ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32)
77
+ ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
78
+ observation_elements.append(
79
+ ObservationElement(
80
+ "%s_camera_extrinsics" % cname,
81
+ (
82
+ 4,
83
+ 4,
84
+ ),
85
+ np.float32,
86
+ )
87
+ )
88
+ observation_elements.append(
89
+ ObservationElement(
90
+ "%s_camera_intrinsics" % cname,
91
+ (
92
+ 3,
93
+ 3,
94
+ ),
95
+ np.float32,
96
+ )
97
+ )
98
+ observation_elements.append(
99
+ ObservationElement("%s_pixel_coord" % cname, (2,), np.int32)
100
+ )
101
+
102
+ # discretized translation, discretized rotation, discrete ignore collision, 6-DoF gripper pose, and pre-trained language embeddings
103
+ observation_elements.extend(
104
+ [
105
+ ReplayElement("trans_action_indicies", (trans_indicies_size,), np.int32),
106
+ ReplayElement(
107
+ "rot_grip_action_indicies", (rot_and_grip_indicies_size,), np.int32
108
+ ),
109
+ ReplayElement("ignore_collisions", (ignore_collisions_size,), np.int32),
110
+ ReplayElement("gripper_pose", (gripper_pose_size,), np.float32),
111
+ ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
112
+ ReplayElement(
113
+ "lang_token_embs",
114
+ (
115
+ max_token_seq_len,
116
+ lang_emb_dim,
117
+ ),
118
+ np.float32,
119
+ ), # extracted from CLIP's language encoder
120
+ ReplayElement("task", (), str),
121
+ ReplayElement(
122
+ "lang_goal", (1,), object
123
+ ), # language goal string for debugging and visualization
124
+ ]
125
+ )
126
+
127
+ for depth in range(len(voxel_sizes)):
128
+ observation_elements.append(
129
+ ReplayElement("attention_coordinate_layer_%d" % depth, (3,), np.float32)
130
+ )
131
+
132
+ extra_replay_elements = [
133
+ ReplayElement("demo", (), np.bool),
134
+ ]
135
+
136
+ replay_buffer = TaskUniformReplayBuffer(
137
+ save_dir=save_dir,
138
+ batch_size=batch_size,
139
+ timesteps=timesteps,
140
+ replay_capacity=int(replay_size),
141
+ action_shape=(8,),
142
+ action_dtype=np.float32,
143
+ reward_shape=(),
144
+ reward_dtype=np.float32,
145
+ update_horizon=1,
146
+ observation_elements=observation_elements,
147
+ extra_replay_elements=extra_replay_elements,
148
+ )
149
+ return replay_buffer
150
+
151
+
152
+ def _get_action(
153
+ obs_tp1: Observation,
154
+ obs_tm1: Observation,
155
+ rlbench_scene_bounds: List[float], # metric 3D bounds of the scene
156
+ voxel_sizes: List[int],
157
+ bounds_offset: List[float],
158
+ rotation_resolution: int,
159
+ crop_augmentation: bool,
160
+ ):
161
+ quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
162
+ if quat[-1] < 0:
163
+ quat = -quat
164
+ disc_rot = utils.quaternion_to_discrete_euler(quat, rotation_resolution)
165
+ disc_rot = utils.correct_rotation_instability(disc_rot, rotation_resolution)
166
+
167
+ attention_coordinate = obs_tp1.gripper_pose[:3]
168
+ trans_indicies, attention_coordinates = [], []
169
+ bounds = np.array(rlbench_scene_bounds)
170
+ ignore_collisions = int(obs_tm1.ignore_collisions)
171
+ for depth, vox_size in enumerate(
172
+ voxel_sizes
173
+ ): # only single voxelization-level is used in PerAct
174
+ if depth > 0:
175
+ if crop_augmentation:
176
+ shift = bounds_offset[depth - 1] * 0.75
177
+ attention_coordinate += np.random.uniform(-shift, shift, size=(3,))
178
+ bounds = np.concatenate(
179
+ [
180
+ attention_coordinate - bounds_offset[depth - 1],
181
+ attention_coordinate + bounds_offset[depth - 1],
182
+ ]
183
+ )
184
+ index = utils.point_to_voxel_index(obs_tp1.gripper_pose[:3], vox_size, bounds)
185
+ trans_indicies.extend(index.tolist())
186
+ res = (bounds[3:] - bounds[:3]) / vox_size
187
+ attention_coordinate = bounds[:3] + res * index
188
+ attention_coordinates.append(attention_coordinate)
189
+
190
+ rot_and_grip_indicies = disc_rot.tolist()
191
+ grip = float(obs_tp1.gripper_open)
192
+ rot_and_grip_indicies.extend([int(obs_tp1.gripper_open)])
193
+ return (
194
+ trans_indicies,
195
+ rot_and_grip_indicies,
196
+ ignore_collisions,
197
+ np.concatenate([obs_tp1.gripper_pose, np.array([grip])]),
198
+ attention_coordinates,
199
+ )
200
+
201
+
202
+ def _add_keypoints_to_replay(
203
+ cfg: DictConfig,
204
+ task: str,
205
+ replay: ReplayBuffer,
206
+ inital_obs: Observation,
207
+ demo: Demo,
208
+ episode_keypoints: List[int],
209
+ cameras: List[str],
210
+ rlbench_scene_bounds: List[float],
211
+ voxel_sizes: List[int],
212
+ bounds_offset: List[float],
213
+ rotation_resolution: int,
214
+ crop_augmentation: bool,
215
+ description: str = "",
216
+ clip_model=None,
217
+ device="cpu",
218
+ ):
219
+ prev_action = None
220
+ obs = inital_obs
221
+ for k, keypoint in enumerate(episode_keypoints):
222
+ obs_tp1 = demo[keypoint]
223
+ obs_tm1 = demo[max(0, keypoint - 1)]
224
+ (
225
+ trans_indicies,
226
+ rot_grip_indicies,
227
+ ignore_collisions,
228
+ action,
229
+ attention_coordinates,
230
+ ) = _get_action(
231
+ obs_tp1,
232
+ obs_tm1,
233
+ rlbench_scene_bounds,
234
+ voxel_sizes,
235
+ bounds_offset,
236
+ rotation_resolution,
237
+ crop_augmentation,
238
+ )
239
+
240
+ terminal = k == len(episode_keypoints) - 1
241
+ reward = float(terminal) * REWARD_SCALE if terminal else 0
242
+
243
+ obs_dict = observation_utils.extract_obs(
244
+ obs,
245
+ t=k,
246
+ prev_action=prev_action,
247
+ cameras=cameras,
248
+ episode_length=cfg.rlbench.episode_length,
249
+ robot_name=cfg.method.robot_name,
250
+ )
251
+ tokens = tokenize([description]).numpy()
252
+ token_tensor = torch.from_numpy(tokens).to(device)
253
+ sentence_emb, token_embs = clip_model.encode_text_with_embeddings(token_tensor)
254
+ obs_dict["lang_goal_emb"] = sentence_emb[0].float().detach().cpu().numpy()
255
+ obs_dict["lang_token_embs"] = token_embs[0].float().detach().cpu().numpy()
256
+
257
+ prev_action = np.copy(action)
258
+
259
+ others = {"demo": True}
260
+ final_obs = {
261
+ "trans_action_indicies": trans_indicies,
262
+ "rot_grip_action_indicies": rot_grip_indicies,
263
+ "gripper_pose": obs_tp1.gripper_pose,
264
+ "task": task,
265
+ "lang_goal": np.array([description], dtype=object),
266
+ }
267
+
268
+ for depth in range(len(voxel_sizes)):
269
+ final_obs["attention_coordinate_layer_%d" % depth] = attention_coordinates[
270
+ depth
271
+ ]
272
+ for name in cameras:
273
+ px, py = utils.point_to_pixel_index(
274
+ obs_tp1.gripper_pose[:3],
275
+ obs_tp1.misc["%s_camera_extrinsics" % name],
276
+ obs_tp1.misc["%s_camera_intrinsics" % name],
277
+ )
278
+ final_obs["%s_pixel_coord" % name] = [py, px]
279
+
280
+ others.update(final_obs)
281
+ others.update(obs_dict)
282
+
283
+ timeout = False
284
+ replay.add(action, reward, terminal, timeout, **others)
285
+ obs = obs_tp1
286
+
287
+ # final step
288
+ obs_dict_tp1 = observation_utils.extract_obs(
289
+ obs_tp1,
290
+ t=k + 1,
291
+ prev_action=prev_action,
292
+ cameras=cameras,
293
+ episode_length=cfg.rlbench.episode_length,
294
+ robot_name=cfg.method.robot_name,
295
+ )
296
+ obs_dict_tp1["lang_goal_emb"] = sentence_emb[0].float().detach().cpu().numpy()
297
+ obs_dict_tp1["lang_token_embs"] = token_embs[0].float().detach().cpu().numpy()
298
+
299
+ obs_dict_tp1.pop("wrist_world_to_cam", None)
300
+ obs_dict_tp1.update(final_obs)
301
+ replay.add_final(**obs_dict_tp1)
302
+
303
+
304
+ def fill_replay(
305
+ cfg: DictConfig,
306
+ obs_config: ObservationConfig,
307
+ rank: int,
308
+ replay: ReplayBuffer,
309
+ task: str,
310
+ num_demos: int,
311
+ demo_augmentation: bool,
312
+ demo_augmentation_every_n: int,
313
+ cameras: List[str],
314
+ rlbench_scene_bounds: List[float], # AKA: DEPTH0_BOUNDS
315
+ voxel_sizes: List[int],
316
+ bounds_offset: List[float],
317
+ rotation_resolution: int,
318
+ crop_augmentation: bool,
319
+ clip_model=None,
320
+ device="cpu",
321
+ keypoint_method="heuristic",
322
+ ):
323
+ if clip_model is None:
324
+ model, _ = load_clip("RN50", jit=False, device=device)
325
+ clip_model = build_model(model.state_dict())
326
+ clip_model.to(device)
327
+ del model
328
+
329
+ logging.debug("Filling %s replay ..." % task)
330
+ for d_idx in range(num_demos):
331
+ # load demo from disk
332
+ demo = rlbench_utils.get_stored_demos(
333
+ amount=1,
334
+ image_paths=False,
335
+ dataset_root=cfg.rlbench.demo_path,
336
+ variation_number=-1,
337
+ task_name=task,
338
+ obs_config=obs_config,
339
+ random_selection=False,
340
+ from_episode_number=d_idx,
341
+ )[0]
342
+
343
+ descs = demo._observations[0].misc["descriptions"]
344
+
345
+ # extract keypoints (a.k.a keyframes)
346
+ episode_keypoints = demo_loading_utils.keypoint_discovery(
347
+ demo, method=keypoint_method
348
+ )
349
+
350
+ if rank == 0:
351
+ logging.info(
352
+ f"Loading Demo({d_idx}) - found {len(episode_keypoints)} keypoints - {task}"
353
+ )
354
+
355
+ for i in range(len(demo) - 1):
356
+ if not demo_augmentation and i > 0:
357
+ break
358
+ if i % demo_augmentation_every_n != 0:
359
+ continue
360
+
361
+ obs = demo[i]
362
+ desc = descs[0]
363
+ # if our starting point is past one of the keypoints, then remove it
364
+ while len(episode_keypoints) > 0 and i >= episode_keypoints[0]:
365
+ episode_keypoints = episode_keypoints[1:]
366
+ if len(episode_keypoints) == 0:
367
+ break
368
+ _add_keypoints_to_replay(
369
+ cfg,
370
+ task,
371
+ replay,
372
+ obs,
373
+ demo,
374
+ episode_keypoints,
375
+ cameras,
376
+ rlbench_scene_bounds,
377
+ voxel_sizes,
378
+ bounds_offset,
379
+ rotation_resolution,
380
+ crop_augmentation,
381
+ description=desc,
382
+ clip_model=clip_model,
383
+ device=device,
384
+ )
385
+ logging.debug("Replay %s filled with demos." % task)
386
+
387
+
388
+ def fill_multi_task_replay(
389
+ cfg: DictConfig,
390
+ obs_config: ObservationConfig,
391
+ rank: int,
392
+ replay: ReplayBuffer,
393
+ tasks: List[str],
394
+ num_demos: int,
395
+ demo_augmentation: bool,
396
+ demo_augmentation_every_n: int,
397
+ cameras: List[str],
398
+ rlbench_scene_bounds: List[float],
399
+ voxel_sizes: List[int],
400
+ bounds_offset: List[float],
401
+ rotation_resolution: int,
402
+ crop_augmentation: bool,
403
+ clip_model=None,
404
+ keypoint_method="heuristic",
405
+ ):
406
+ manager = Manager()
407
+ store = manager.dict()
408
+
409
+ # create a MP dict for storing indicies
410
+ # TODO(mohit): this shouldn't be initialized here
411
+ del replay._task_idxs
412
+ task_idxs = manager.dict()
413
+ replay._task_idxs = task_idxs
414
+ replay._create_storage(store)
415
+ replay.add_count = Value("i", 0)
416
+
417
+ # fill replay buffer in parallel across tasks
418
+ max_parallel_processes = cfg.replay.max_parallel_processes
419
+ processes = []
420
+ n = np.arange(len(tasks))
421
+ split_n = utils.split_list(n, max_parallel_processes)
422
+ for split in split_n:
423
+ for e_idx, task_idx in enumerate(split):
424
+ task = tasks[int(task_idx)]
425
+ model_device = torch.device(
426
+ "cuda:%s" % (e_idx % torch.cuda.device_count())
427
+ if torch.cuda.is_available()
428
+ else "cpu"
429
+ )
430
+ p = Process(
431
+ target=fill_replay,
432
+ args=(
433
+ cfg,
434
+ obs_config,
435
+ rank,
436
+ replay,
437
+ task,
438
+ num_demos,
439
+ demo_augmentation,
440
+ demo_augmentation_every_n,
441
+ cameras,
442
+ rlbench_scene_bounds,
443
+ voxel_sizes,
444
+ bounds_offset,
445
+ rotation_resolution,
446
+ crop_augmentation,
447
+ clip_model,
448
+ model_device,
449
+ keypoint_method,
450
+ ),
451
+ )
452
+ p.start()
453
+ processes.append(p)
454
+
455
+ for p in processes:
456
+ p.join()
457
+
458
+
459
+ def create_agent(cfg: DictConfig):
460
+ LATENT_SIZE = 64
461
+ depth_0bounds = cfg.rlbench.scene_bounds
462
+ cam_resolution = cfg.rlbench.camera_resolution
463
+
464
+ num_rotation_classes = int(360.0 // cfg.method.rotation_resolution)
465
+ qattention_agents = []
466
+ for depth, vox_size in enumerate(cfg.method.voxel_sizes):
467
+ last = depth == len(cfg.method.voxel_sizes) - 1
468
+ unet3d = QattentionLingU3DNet(
469
+ in_channels=3 + 3 + 1 + 3,
470
+ out_channels=1,
471
+ voxel_size=vox_size,
472
+ out_dense=((num_rotation_classes * 3) + 4) if last else 0,
473
+ kernels=LATENT_SIZE,
474
+ norm=None if "None" in cfg.method.norm else cfg.method.norm,
475
+ dense_feats=128,
476
+ activation=cfg.method.activation,
477
+ low_dim_size=4,
478
+ include_prev_layer=cfg.method.include_prev_layer and depth > 0,
479
+ depth=depth,
480
+ )
481
+
482
+ qattention_agent = QAttentionLingUNetBCAgent(
483
+ layer=depth,
484
+ coordinate_bounds=depth_0bounds,
485
+ unet3d=unet3d,
486
+ camera_names=cfg.rlbench.cameras,
487
+ batch_size=cfg.replay.batch_size,
488
+ voxel_size=vox_size,
489
+ bounds_offset=cfg.method.bounds_offset[depth - 1] if depth > 0 else None,
490
+ voxel_feature_size=3,
491
+ image_crop_size=cfg.method.image_crop_size,
492
+ lr=cfg.method.lr,
493
+ training_iterations=cfg.framework.training_iterations,
494
+ lr_scheduler=cfg.method.lr_scheduler,
495
+ num_warmup_steps=cfg.method.num_warmup_steps,
496
+ trans_loss_weight=cfg.method.trans_loss_weight,
497
+ rot_loss_weight=cfg.method.rot_loss_weight,
498
+ grip_loss_weight=cfg.method.grip_loss_weight,
499
+ collision_loss_weight=cfg.method.collision_loss_weight,
500
+ include_low_dim_state=True,
501
+ image_resolution=cam_resolution,
502
+ lambda_weight_l2=cfg.method.lambda_weight_l2,
503
+ num_rotation_classes=num_rotation_classes,
504
+ rotation_resolution=cfg.method.rotation_resolution,
505
+ transform_augmentation=cfg.method.transform_augmentation.apply_se3,
506
+ transform_augmentation_xyz=cfg.method.transform_augmentation.aug_xyz,
507
+ transform_augmentation_rpy=cfg.method.transform_augmentation.aug_rpy,
508
+ transform_augmentation_rot_resolution=cfg.method.transform_augmentation.aug_rot_resolution,
509
+ num_devices=cfg.ddp.num_devices,
510
+ )
511
+ qattention_agents.append(qattention_agent)
512
+
513
+ rotation_agent = QAttentionStackAgent(
514
+ qattention_agents=qattention_agents,
515
+ rotation_resolution=cfg.method.rotation_resolution,
516
+ camera_names=cfg.rlbench.cameras,
517
+ )
518
+ preprocess_agent = PreprocessAgent(pose_agent=rotation_agent)
519
+ return preprocess_agent
external/peract_bimanual/agents/c2farm_lingunet_bc/networks.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from helpers.network_utils import (
5
+ Conv3DInceptionBlock,
6
+ DenseBlock,
7
+ SpatialSoftmax3D,
8
+ Conv3DInceptionBlockUpsampleBlock,
9
+ Conv3DBlock,
10
+ )
11
+
12
+
13
+ class QattentionLingU3DNet(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_channels: int,
17
+ out_channels: int,
18
+ out_dense: int,
19
+ voxel_size: int,
20
+ low_dim_size: int,
21
+ kernels: int,
22
+ norm: str = None,
23
+ activation: str = "relu",
24
+ dense_feats: int = 32,
25
+ include_prev_layer=False,
26
+ depth=0,
27
+ lingunet_dropout=0.0,
28
+ ):
29
+ super(QattentionLingU3DNet, self).__init__()
30
+ self._in_channels = in_channels
31
+ self._out_channels = out_channels
32
+ self._norm = norm
33
+ self._activation = activation
34
+ self._kernels = kernels
35
+ self._low_dim_size = low_dim_size
36
+ self._build_calls = 0
37
+ self._voxel_size = voxel_size
38
+ self._dense_feats = dense_feats
39
+ self._out_dense = out_dense
40
+ self._include_prev_layer = include_prev_layer
41
+ self._depth = depth
42
+
43
+ self._lingunet_dropout = lingunet_dropout
44
+ self._clip_lang_feat_dim = 1024
45
+
46
+ if self._voxel_size < 16:
47
+ raise Exception(
48
+ "Voxel size for C2FARM_LINGUNET_BC should be at least 16 or higher"
49
+ )
50
+
51
+ def build(self):
52
+ use_residual = False
53
+ self._build_calls += 1
54
+ if self._build_calls != 1:
55
+ raise RuntimeError("Build needs to be called once.")
56
+
57
+ spatial_size = self._voxel_size
58
+ self._input_preprocess = Conv3DInceptionBlock(
59
+ self._in_channels,
60
+ self._kernels,
61
+ norm=self._norm,
62
+ activation=self._activation,
63
+ )
64
+
65
+ d0_ins = self._input_preprocess.out_channels
66
+ if self._include_prev_layer:
67
+ PREV_VOXEL_CHANNELS = 0
68
+ d0_ins += self._input_preprocess.out_channels * self._depth
69
+
70
+ if self._low_dim_size > 0:
71
+ self._proprio_preprocess = DenseBlock(
72
+ self._low_dim_size, self._kernels, None, self._activation
73
+ )
74
+ d0_ins += self._kernels
75
+
76
+ self._down0 = Conv3DInceptionBlock(
77
+ d0_ins,
78
+ self._kernels,
79
+ norm=self._norm,
80
+ activation=self._activation,
81
+ residual=use_residual,
82
+ )
83
+ self._ss0 = SpatialSoftmax3D(
84
+ spatial_size, spatial_size, spatial_size, self._down0.out_channels
85
+ )
86
+ spatial_size //= 2
87
+ self._down1 = Conv3DInceptionBlock(
88
+ self._down0.out_channels,
89
+ self._kernels * 2,
90
+ norm=self._norm,
91
+ activation=self._activation,
92
+ residual=use_residual,
93
+ )
94
+ self._ss1 = SpatialSoftmax3D(
95
+ spatial_size, spatial_size, spatial_size, self._down1.out_channels
96
+ )
97
+ spatial_size //= 2
98
+
99
+ flat_size = self._down0.out_channels * 4 + self._down1.out_channels * 4
100
+
101
+ k1 = self._down1.out_channels
102
+ if self._voxel_size > 8:
103
+ k1 += self._kernels
104
+ self._down2 = Conv3DInceptionBlock(
105
+ self._down1.out_channels,
106
+ self._kernels * 4,
107
+ norm=self._norm,
108
+ activation=self._activation,
109
+ residual=use_residual,
110
+ )
111
+ self._lang_proj2 = DenseBlock(
112
+ self._clip_lang_feat_dim, self._down2.out_channels, None, None
113
+ )
114
+ self._dropout2 = nn.Dropout(self._lingunet_dropout)
115
+ flat_size += self._down2.out_channels * 4
116
+ self._ss2 = SpatialSoftmax3D(
117
+ spatial_size, spatial_size, spatial_size, self._down2.out_channels
118
+ )
119
+ spatial_size //= 2
120
+ k2 = self._down2.out_channels
121
+ if self._voxel_size > 16:
122
+ k2 *= 2
123
+ self._down3 = Conv3DInceptionBlock(
124
+ self._down2.out_channels,
125
+ self._kernels,
126
+ norm=self._norm,
127
+ activation=self._activation,
128
+ residual=use_residual,
129
+ )
130
+ self._lang_proj3 = DenseBlock(
131
+ self._clip_lang_feat_dim, self._down3.out_channels, None, None
132
+ )
133
+ self._dropout3 = nn.Dropout(self._lingunet_dropout)
134
+ flat_size += self._down3.out_channels * 4
135
+ self._ss3 = SpatialSoftmax3D(
136
+ spatial_size, spatial_size, spatial_size, self._down3.out_channels
137
+ )
138
+ self._up3 = Conv3DInceptionBlockUpsampleBlock(
139
+ self._kernels,
140
+ self._kernels * 4,
141
+ 2,
142
+ norm=self._norm,
143
+ activation=self._activation,
144
+ residual=use_residual,
145
+ )
146
+ self._up2 = Conv3DInceptionBlockUpsampleBlock(
147
+ k2,
148
+ self._kernels,
149
+ 2,
150
+ norm=self._norm,
151
+ activation=self._activation,
152
+ residual=use_residual,
153
+ )
154
+
155
+ self._up1 = Conv3DInceptionBlockUpsampleBlock(
156
+ k1,
157
+ self._kernels,
158
+ 2,
159
+ norm=self._norm,
160
+ activation=self._activation,
161
+ residual=use_residual,
162
+ )
163
+
164
+ self._global_maxp = nn.AdaptiveMaxPool3d(1)
165
+ self._local_maxp = nn.MaxPool3d(3, 2, padding=1)
166
+ self._final = Conv3DBlock(
167
+ self._kernels * 2,
168
+ self._kernels,
169
+ kernel_sizes=3,
170
+ strides=1,
171
+ norm=self._norm,
172
+ activation=self._activation,
173
+ )
174
+ self._final2 = Conv3DBlock(
175
+ self._kernels,
176
+ self._out_channels,
177
+ kernel_sizes=3,
178
+ strides=1,
179
+ norm=None,
180
+ activation=None,
181
+ )
182
+
183
+ self._ss_final = SpatialSoftmax3D(
184
+ self._voxel_size, self._voxel_size, self._voxel_size, self._kernels
185
+ )
186
+ flat_size += self._kernels * 4
187
+
188
+ if self._out_dense > 0:
189
+ self._dense0 = DenseBlock(
190
+ flat_size, self._dense_feats, None, self._activation
191
+ )
192
+ self._dense1 = DenseBlock(
193
+ self._dense_feats, self._dense_feats, None, self._activation
194
+ )
195
+ self._dense2 = DenseBlock(self._dense_feats, self._out_dense, None, None)
196
+
197
+ def _proj_feature(self, x, spatial_size, proj_fn):
198
+ x = proj_fn(x)
199
+ x = x.unsqueeze(2).unsqueeze(3).unsqueeze(4)
200
+ x = x.repeat(1, 1, spatial_size, spatial_size, spatial_size)
201
+ return x
202
+
203
+ def forward(
204
+ self,
205
+ ins,
206
+ proprio,
207
+ lang_goal_embs,
208
+ lang_token_embs,
209
+ bounds,
210
+ prev_bounds,
211
+ prev_layer_voxel_grid,
212
+ ):
213
+ b, _, d, h, w = ins.shape
214
+ x = self._input_preprocess(ins)
215
+
216
+ if self._include_prev_layer:
217
+ for voxel_grid in prev_layer_voxel_grid:
218
+ y = self._input_preprocess(voxel_grid)
219
+ x = torch.cat([x, y], dim=1)
220
+
221
+ if self._low_dim_size > 0:
222
+ p = self._proprio_preprocess(proprio)
223
+ p = p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
224
+ x = torch.cat([x, p], dim=1)
225
+
226
+ l_feat = lang_goal_embs
227
+ l_feat = l_feat.to(dtype=x.dtype)
228
+
229
+ d0 = self._down0(x)
230
+ # l0 = self._proj_feature(l_feat, d0.shape[-1], self._lang_proj0)
231
+ # d0 = self._dropout0(d0 * l0)
232
+ ss0 = self._ss0(d0)
233
+ maxp0 = self._global_maxp(d0).view(b, -1)
234
+
235
+ d1 = u = self._down1(self._local_maxp(d0))
236
+ # l1 = self._proj_feature(l_feat, d1.shape[-1], self._lang_proj1)
237
+ # d1 = self._dropout1(d1 * l1)
238
+ ss1 = self._ss1(d1)
239
+ maxp1 = self._global_maxp(d1).view(b, -1)
240
+
241
+ feats = [ss0, maxp0, ss1, maxp1]
242
+
243
+ if self._voxel_size > 8:
244
+ d2 = u = self._down2(self._local_maxp(d1))
245
+ l2 = self._proj_feature(l_feat, d2.shape[-1], self._lang_proj2)
246
+ d2 = self._dropout2(d2 * l2)
247
+ feats.extend([self._ss2(d2), self._global_maxp(d2).view(b, -1)])
248
+ if self._voxel_size > 16:
249
+ d3 = self._down3(self._local_maxp(d2))
250
+ l3 = self._proj_feature(l_feat, d3.shape[-1], self._lang_proj3)
251
+ d3 = self._dropout3(d3 * l3)
252
+ feats.extend([self._ss3(d3), self._global_maxp(d3).view(b, -1)])
253
+ u3 = self._up3(d3)
254
+ u = torch.cat([d2, u3], dim=1)
255
+ u2 = self._up2(u)
256
+ u = torch.cat([d1, u2], dim=1)
257
+
258
+ u1 = self._up1(u)
259
+ f1 = self._final(torch.cat([d0, u1], dim=1))
260
+ trans = self._final2(f1)
261
+
262
+ feats.extend([self._ss_final(f1), self._global_maxp(f1).view(b, -1)])
263
+
264
+ self.latent_dict = {
265
+ "d0": d0.mean(-1).mean(-1).mean(-1),
266
+ "d1": d1.mean(-1).mean(-1).mean(-1),
267
+ "u1": u1.mean(-1).mean(-1).mean(-1),
268
+ "trans_out": trans,
269
+ }
270
+
271
+ rot_and_grip_out, collision_out = None, None
272
+ if self._out_dense > 0:
273
+ dense0 = self._dense0(torch.cat(feats, 1))
274
+ dense1 = self._dense1(dense0)
275
+ rot_and_grip_collision_out = self._dense2(dense1)
276
+ rot_and_grip_out = rot_and_grip_collision_out[:, :-2]
277
+ collision_out = rot_and_grip_collision_out[:, -2:]
278
+ self.latent_dict.update(
279
+ {
280
+ "dense0": dense0,
281
+ "dense1": dense1,
282
+ "dense2": rot_and_grip_collision_out,
283
+ }
284
+ )
285
+
286
+ if self._voxel_size > 8:
287
+ self.latent_dict.update(
288
+ {
289
+ "d2": d2.mean(-1).mean(-1).mean(-1),
290
+ "u2": u2.mean(-1).mean(-1).mean(-1),
291
+ }
292
+ )
293
+ if self._voxel_size > 16:
294
+ self.latent_dict.update(
295
+ {
296
+ "d3": d3.mean(-1).mean(-1).mean(-1),
297
+ "u3": u3.mean(-1).mean(-1).mean(-1),
298
+ }
299
+ )
300
+
301
+ return trans, rot_and_grip_out, collision_out
external/peract_bimanual/agents/c2farm_lingunet_bc/qattention_lingunet_bc_agent.py ADDED
@@ -0,0 +1,790 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torchvision import transforms
11
+ from pytorch3d import transforms as torch3d_tf
12
+ from yarr.agents.agent import (
13
+ Agent,
14
+ ActResult,
15
+ ScalarSummary,
16
+ HistogramSummary,
17
+ ImageSummary,
18
+ Summary,
19
+ )
20
+
21
+ from helpers import utils
22
+ from helpers.utils import visualise_voxel, stack_on_channel
23
+ from voxel.voxel_grid import VoxelGrid
24
+ from voxel.augmentation import apply_se3_augmentation
25
+ from einops import rearrange
26
+ from helpers.clip.core.clip import build_model, load_clip
27
+
28
+ import transformers
29
+ from torch.nn.parallel import DistributedDataParallel as DDP
30
+
31
+ NAME = "QAttentionAgent"
32
+
33
+
34
+ class QFunction(nn.Module):
35
+ def __init__(
36
+ self,
37
+ unet_3d: nn.Module,
38
+ voxelizer: VoxelGrid,
39
+ bounds_offset: float,
40
+ rotation_resolution: float,
41
+ device,
42
+ training,
43
+ ):
44
+ super(QFunction, self).__init__()
45
+ self._rotation_resolution = rotation_resolution
46
+ self._voxelizer = voxelizer
47
+ self._bounds_offset = bounds_offset
48
+ self._qnet = unet_3d.to(device)
49
+
50
+ # distributed training
51
+ if training:
52
+ self._qnet = DDP(self._qnet, device_ids=[device])
53
+
54
+ def _argmax_3d(self, tensor_orig):
55
+ b, c, d, h, w = tensor_orig.shape # c will be one
56
+ idxs = tensor_orig.view(b, c, -1).argmax(-1)
57
+ indices = torch.cat([((idxs // h) // d), (idxs // h) % w, idxs % w], 1)
58
+ return indices
59
+
60
+ def choose_highest_action(self, q_trans, q_rot_grip, q_collision):
61
+ coords = self._argmax_3d(q_trans)
62
+ rot_and_grip_indicies = None
63
+ ignore_collision = None
64
+ if q_rot_grip is not None:
65
+ q_rot = torch.stack(
66
+ torch.split(
67
+ q_rot_grip[:, :-2], int(360 // self._rotation_resolution), dim=1
68
+ ),
69
+ dim=1,
70
+ )
71
+ rot_and_grip_indicies = torch.cat(
72
+ [
73
+ q_rot[:, 0:1].argmax(-1),
74
+ q_rot[:, 1:2].argmax(-1),
75
+ q_rot[:, 2:3].argmax(-1),
76
+ q_rot_grip[:, -2:].argmax(-1, keepdim=True),
77
+ ],
78
+ -1,
79
+ )
80
+ ignore_collision = q_collision[:, -2:].argmax(-1, keepdim=True)
81
+ return coords, rot_and_grip_indicies, ignore_collision
82
+
83
+ def forward(
84
+ self,
85
+ rgb_pcd,
86
+ proprio,
87
+ pcd,
88
+ lang_goal_emb,
89
+ lang_token_embs,
90
+ bounds=None,
91
+ prev_bounds=None,
92
+ prev_layer_voxel_grid=None,
93
+ ):
94
+ # rgb_pcd will be list of list (list of [rgb, pcd])
95
+ b = rgb_pcd[0][0].shape[0]
96
+ pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(b, -1, 3) for p in pcd], 1)
97
+
98
+ # flatten RGBs and Pointclouds
99
+ rgb = [rp[0] for rp in rgb_pcd]
100
+ feat_size = rgb[0].shape[1]
101
+ flat_imag_features = torch.cat(
102
+ [p.permute(0, 2, 3, 1).reshape(b, -1, feat_size) for p in rgb], 1
103
+ )
104
+
105
+ # construct voxel grid
106
+ voxel_grid = self._voxelizer.coords_to_bounding_voxel_grid(
107
+ pcd_flat, coord_features=flat_imag_features, coord_bounds=bounds
108
+ )
109
+
110
+ # swap to channels fist
111
+ voxel_grid = voxel_grid.permute(0, 4, 1, 2, 3).detach()
112
+
113
+ # batch bounds if necessary
114
+ if bounds.shape[0] != b:
115
+ bounds = bounds.repeat(b, 1)
116
+
117
+ # forward pass
118
+ q_trans, q_rot_and_grip, q_ignore_collisions = self._qnet(
119
+ voxel_grid,
120
+ proprio,
121
+ lang_goal_emb,
122
+ lang_token_embs,
123
+ prev_layer_voxel_grid,
124
+ bounds,
125
+ prev_bounds,
126
+ )
127
+
128
+ return q_trans, q_rot_and_grip, q_ignore_collisions, voxel_grid
129
+
130
+
131
+ class QAttentionLingUNetBCAgent(Agent):
132
+ def __init__(
133
+ self,
134
+ layer: int,
135
+ coordinate_bounds: list,
136
+ unet3d: nn.Module,
137
+ camera_names: list,
138
+ batch_size: int,
139
+ voxel_size: int,
140
+ bounds_offset: float,
141
+ voxel_feature_size: int,
142
+ image_crop_size: int,
143
+ num_rotation_classes: int,
144
+ rotation_resolution: float,
145
+ lr: float = 0.0001,
146
+ lr_scheduler: bool = False,
147
+ training_iterations: int = 100000,
148
+ num_warmup_steps: int = 20000,
149
+ trans_loss_weight: float = 1.0,
150
+ rot_loss_weight: float = 1.0,
151
+ grip_loss_weight: float = 1.0,
152
+ collision_loss_weight: float = 1.0,
153
+ include_low_dim_state: bool = False,
154
+ image_resolution: list = None,
155
+ lambda_weight_l2: float = 0.0,
156
+ transform_augmentation: bool = True,
157
+ transform_augmentation_xyz: list = [0.0, 0.0, 0.0],
158
+ transform_augmentation_rpy: list = [0.0, 0.0, 180.0],
159
+ transform_augmentation_rot_resolution: int = 5,
160
+ num_devices: int = 1,
161
+ ):
162
+ self._layer = layer
163
+ self._coordinate_bounds = coordinate_bounds
164
+ self._unet3d = unet3d
165
+ self._voxel_feature_size = voxel_feature_size
166
+ self._bounds_offset = bounds_offset
167
+ self._image_crop_size = image_crop_size
168
+ self._lr = lr
169
+ self._lr_scheduler = lr_scheduler
170
+ self._training_iterations = training_iterations
171
+ self._num_warmup_steps = num_warmup_steps
172
+ self._trans_loss_weight = trans_loss_weight
173
+ self._rot_loss_weight = rot_loss_weight
174
+ self._grip_loss_weight = grip_loss_weight
175
+ self._collision_loss_weight = collision_loss_weight
176
+ self._include_low_dim_state = include_low_dim_state
177
+ self._image_resolution = image_resolution or [128, 128]
178
+ self._voxel_size = voxel_size
179
+ self._camera_names = camera_names
180
+ self._num_cameras = len(camera_names)
181
+ self._batch_size = batch_size
182
+ self._lambda_weight_l2 = lambda_weight_l2
183
+ self._transform_augmentation = transform_augmentation
184
+ self._transform_augmentation_xyz = torch.from_numpy(
185
+ np.array(transform_augmentation_xyz)
186
+ )
187
+ self._transform_augmentation_rpy = transform_augmentation_rpy
188
+ self._transform_augmentation_rot_resolution = (
189
+ transform_augmentation_rot_resolution
190
+ )
191
+ self._num_devices = num_devices
192
+ self._num_rotation_classes = num_rotation_classes
193
+ self._rotation_resolution = rotation_resolution
194
+
195
+ self._cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
196
+ self._name = NAME + "_layer" + str(self._layer)
197
+
198
+ def build(self, training: bool, device: torch.device = None):
199
+ self._training = training
200
+ self._device = device
201
+
202
+ if device is None:
203
+ device = torch.device("cpu")
204
+
205
+ self._voxelizer = VoxelGrid(
206
+ coord_bounds=self._coordinate_bounds,
207
+ voxel_size=self._voxel_size,
208
+ device=device,
209
+ batch_size=self._batch_size if training else 1,
210
+ feature_size=self._voxel_feature_size,
211
+ max_num_coords=np.prod(self._image_resolution) * self._num_cameras,
212
+ )
213
+
214
+ self._unet3d.build()
215
+
216
+ self._q = (
217
+ QFunction(
218
+ self._unet3d,
219
+ self._voxelizer,
220
+ self._bounds_offset,
221
+ self._rotation_resolution,
222
+ device,
223
+ training,
224
+ )
225
+ .to(device)
226
+ .train(training)
227
+ )
228
+
229
+ grid_for_crop = (
230
+ torch.arange(0, self._image_crop_size, device=device)
231
+ .unsqueeze(0)
232
+ .repeat(self._image_crop_size, 1)
233
+ .unsqueeze(-1)
234
+ )
235
+ self._grid_for_crop = torch.cat(
236
+ [grid_for_crop.transpose(1, 0), grid_for_crop], dim=2
237
+ ).unsqueeze(0)
238
+
239
+ self._coordinate_bounds = torch.tensor(
240
+ self._coordinate_bounds, device=device
241
+ ).unsqueeze(0)
242
+
243
+ if self._training:
244
+ # optimizer
245
+ self._optimizer = torch.optim.Adam(
246
+ self._q.parameters(),
247
+ lr=self._lr,
248
+ weight_decay=self._lambda_weight_l2,
249
+ )
250
+
251
+ # learning rate scheduler
252
+ if self._lr_scheduler:
253
+ self._scheduler = (
254
+ transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
255
+ self._optimizer,
256
+ num_warmup_steps=self._num_warmup_steps,
257
+ num_training_steps=self._training_iterations,
258
+ num_cycles=self._training_iterations // 10000,
259
+ )
260
+ )
261
+
262
+ # one-hot zero tensors
263
+ self._action_trans_one_hot_zeros = torch.zeros(
264
+ (
265
+ self._batch_size,
266
+ 1,
267
+ self._voxel_size,
268
+ self._voxel_size,
269
+ self._voxel_size,
270
+ ),
271
+ dtype=int,
272
+ device=device,
273
+ )
274
+ self._action_rot_x_one_hot_zeros = torch.zeros(
275
+ (self._batch_size, self._num_rotation_classes), dtype=int, device=device
276
+ )
277
+ self._action_rot_y_one_hot_zeros = torch.zeros(
278
+ (self._batch_size, self._num_rotation_classes), dtype=int, device=device
279
+ )
280
+ self._action_rot_z_one_hot_zeros = torch.zeros(
281
+ (self._batch_size, self._num_rotation_classes), dtype=int, device=device
282
+ )
283
+ self._action_grip_one_hot_zeros = torch.zeros(
284
+ (self._batch_size, 2), dtype=int, device=device
285
+ )
286
+ self._action_ignore_collisions_one_hot_zeros = torch.zeros(
287
+ (self._batch_size, 2), dtype=int, device=device
288
+ )
289
+
290
+ # print total params
291
+ logging.info(
292
+ "# Q Params: %d"
293
+ % sum(
294
+ p.numel()
295
+ for name, p in self._q.named_parameters()
296
+ if p.requires_grad and "clip" not in name
297
+ )
298
+ )
299
+ else:
300
+ for param in self._q.parameters():
301
+ param.requires_grad = False
302
+
303
+ # load CLIP for encoding language goals during evaluation
304
+ model, _ = load_clip("RN50", jit=False)
305
+ self._clip_rn50 = build_model(model.state_dict())
306
+ self._clip_rn50 = self._clip_rn50.float().to(device)
307
+ self._clip_rn50.eval()
308
+ del model
309
+
310
+ self._voxelizer.to(device)
311
+ self._q.to(device)
312
+
313
+ def _extract_crop(self, pixel_action, observation):
314
+ # Pixel action will now be (B, 2)
315
+ # observation = stack_on_channel(observation)
316
+ h = observation.shape[-1]
317
+ top_left_corner = torch.clamp(
318
+ pixel_action - self._image_crop_size // 2, 0, h - self._image_crop_size
319
+ )
320
+ grid = self._grid_for_crop + top_left_corner.unsqueeze(1).unsqueeze(1)
321
+ grid = ((grid / float(h)) * 2.0) - 1.0 # between -1 and 1
322
+ # Used for cropping the images across a batch
323
+ # swap fro y x, to x, y
324
+ grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1)
325
+ crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True)
326
+ return crop
327
+
328
+ def _preprocess_inputs(self, replay_sample):
329
+ obs, pcds = [], []
330
+ self._crop_summary = []
331
+ for n in self._camera_names:
332
+ if self._layer > 0:
333
+ pc_t = replay_sample["%s_pixel_coord" % n]
334
+ rgb = self._extract_crop(pc_t, replay_sample["%s_rgb" % n])
335
+ pcd = self._extract_crop(pc_t, replay_sample["%s_point_cloud" % n])
336
+ self._crop_summary.append((n, rgb))
337
+ else:
338
+ rgb = replay_sample["%s_rgb" % n]
339
+ pcd = replay_sample["%s_point_cloud" % n]
340
+
341
+ obs.append([rgb, pcd])
342
+ pcds.append(pcd)
343
+ return obs, pcds
344
+
345
+ def _act_preprocess_inputs(self, observation):
346
+ obs, pcds = [], []
347
+ for n in self._camera_names:
348
+ if self._layer > 0:
349
+ pc_t = observation["%s_pixel_coord" % n][0]
350
+ rgb = self._extract_crop(pc_t, observation["%s_rgb" % n][0])
351
+ pcd = self._extract_crop(pc_t, observation["%s_point_cloud" % n][0])
352
+ else:
353
+ rgb = observation["%s_rgb" % n][0]
354
+ pcd = observation["%s_point_cloud" % n][0]
355
+
356
+ obs.append([rgb, pcd])
357
+ pcds.append(pcd)
358
+ return obs, pcds
359
+
360
+ def _get_value_from_voxel_index(self, q, voxel_idx):
361
+ b, c, d, h, w = q.shape
362
+ q_trans_flat = q.view(b, c, d * h * w)
363
+ flat_indicies = (
364
+ voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2]
365
+ )[:, None].int()
366
+ highest_idxs = flat_indicies.unsqueeze(-1).repeat(1, c, 1)
367
+ chosen_voxel_values = q_trans_flat.gather(2, highest_idxs)[
368
+ ..., 0
369
+ ] # (B, trans + rot + grip)
370
+ return chosen_voxel_values
371
+
372
+ def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx):
373
+ q_rot = torch.stack(
374
+ torch.split(
375
+ rot_grip_q[:, :-2], int(360 // self._rotation_resolution), dim=1
376
+ ),
377
+ dim=1,
378
+ ) # B, 3, 72
379
+ q_grip = rot_grip_q[:, -2:]
380
+ rot_and_grip_values = torch.cat(
381
+ [
382
+ q_rot[:, 0].gather(1, rot_and_grip_idx[:, 0:1]),
383
+ q_rot[:, 1].gather(1, rot_and_grip_idx[:, 1:2]),
384
+ q_rot[:, 2].gather(1, rot_and_grip_idx[:, 2:3]),
385
+ q_grip.gather(1, rot_and_grip_idx[:, 3:4]),
386
+ ],
387
+ -1,
388
+ )
389
+ return rot_and_grip_values
390
+
391
+ def _celoss(self, pred, labels):
392
+ return self._cross_entropy_loss(pred, labels.argmax(-1))
393
+
394
+ def _softmax_q_trans(self, q):
395
+ q_shape = q.shape
396
+ return F.softmax(q.reshape(q_shape[0], -1), dim=1).reshape(q_shape)
397
+
398
+ def _softmax_q_rot_grip(self, q_rot_grip):
399
+ q_rot_x_flat = q_rot_grip[
400
+ :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
401
+ ]
402
+ q_rot_y_flat = q_rot_grip[
403
+ :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
404
+ ]
405
+ q_rot_z_flat = q_rot_grip[
406
+ :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
407
+ ]
408
+ q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
409
+
410
+ q_rot_x_flat_softmax = F.softmax(q_rot_x_flat, dim=1)
411
+ q_rot_y_flat_softmax = F.softmax(q_rot_y_flat, dim=1)
412
+ q_rot_z_flat_softmax = F.softmax(q_rot_z_flat, dim=1)
413
+ q_grip_flat_softmax = F.softmax(q_grip_flat, dim=1)
414
+
415
+ return torch.cat(
416
+ [
417
+ q_rot_x_flat_softmax,
418
+ q_rot_y_flat_softmax,
419
+ q_rot_z_flat_softmax,
420
+ q_grip_flat_softmax,
421
+ ],
422
+ dim=1,
423
+ )
424
+
425
+ def _softmax_ignore_collision(self, q_collision):
426
+ q_collision_softmax = F.softmax(q_collision, dim=1)
427
+ return q_collision_softmax
428
+
429
+ def update(self, step: int, replay_sample: dict) -> dict:
430
+ action_trans = replay_sample["trans_action_indicies"][
431
+ :, self._layer * 3 : self._layer * 3 + 3
432
+ ].int()
433
+ action_rot_grip = replay_sample["rot_grip_action_indicies"].int()
434
+ action_gripper_pose = replay_sample["gripper_pose"]
435
+ action_ignore_collisions = replay_sample["ignore_collisions"].int()
436
+ lang_goal_emb = replay_sample["lang_goal_emb"].float()
437
+ lang_token_embs = replay_sample["lang_token_embs"].float()
438
+ prev_layer_voxel_grid = replay_sample.get("prev_layer_voxel_grid", None)
439
+ prev_layer_bounds = replay_sample.get("prev_layer_bounds", None)
440
+ device = self._device
441
+
442
+ bounds = bounds_tp1 = self._coordinate_bounds
443
+ if self._layer > 0:
444
+ cp = replay_sample["attention_coordinate_layer_%d" % (self._layer - 1)]
445
+ bounds = torch.cat(
446
+ [cp - self._bounds_offset, cp + self._bounds_offset], dim=1
447
+ )
448
+
449
+ proprio = None
450
+ if self._include_low_dim_state:
451
+ proprio = replay_sample["low_dim_state"]
452
+
453
+ obs, pcd = self._preprocess_inputs(replay_sample)
454
+
455
+ # batch size
456
+ bs = pcd[0].shape[0]
457
+
458
+ # SE(3) augmentation of point clouds and actions
459
+ if self._transform_augmentation:
460
+ action_trans, action_rot_grip, pcd = apply_se3_augmentation(
461
+ pcd,
462
+ action_gripper_pose,
463
+ action_trans,
464
+ action_rot_grip,
465
+ bounds,
466
+ self._layer,
467
+ self._transform_augmentation_xyz,
468
+ self._transform_augmentation_rpy,
469
+ self._transform_augmentation_rot_resolution,
470
+ self._voxel_size,
471
+ self._rotation_resolution,
472
+ self._device,
473
+ )
474
+
475
+ # forward pass
476
+ q_trans, q_rot_grip, q_collision, voxel_grid = self._q(
477
+ obs,
478
+ proprio,
479
+ pcd,
480
+ lang_goal_emb,
481
+ lang_token_embs,
482
+ bounds,
483
+ prev_layer_bounds,
484
+ prev_layer_voxel_grid,
485
+ )
486
+
487
+ # argmax to choose best action
488
+ (
489
+ coords,
490
+ rot_and_grip_indicies,
491
+ ignore_collision_indicies,
492
+ ) = self._q.choose_highest_action(q_trans, q_rot_grip, q_collision)
493
+
494
+ q_trans_loss, q_rot_loss, q_grip_loss, q_collision_loss = 0.0, 0.0, 0.0, 0.0
495
+
496
+ # translation one-hot
497
+ action_trans_one_hot = self._action_trans_one_hot_zeros.clone()
498
+ for b in range(bs):
499
+ gt_coord = action_trans[b, :].int()
500
+ action_trans_one_hot[b, :, gt_coord[0], gt_coord[1], gt_coord[2]] = 1
501
+
502
+ # translation loss
503
+ q_trans_flat = q_trans.view(bs, -1)
504
+ action_trans_one_hot_flat = action_trans_one_hot.view(bs, -1)
505
+ q_trans_loss = self._celoss(q_trans_flat, action_trans_one_hot_flat)
506
+
507
+ with_rot_and_grip = rot_and_grip_indicies is not None
508
+ if with_rot_and_grip:
509
+ # rotation, gripper, and collision one-hots
510
+ action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone()
511
+ action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone()
512
+ action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone()
513
+ action_grip_one_hot = self._action_grip_one_hot_zeros.clone()
514
+ action_ignore_collisions_one_hot = (
515
+ self._action_ignore_collisions_one_hot_zeros.clone()
516
+ )
517
+
518
+ for b in range(bs):
519
+ gt_rot_grip = action_rot_grip[b, :].int()
520
+ action_rot_x_one_hot[b, gt_rot_grip[0]] = 1
521
+ action_rot_y_one_hot[b, gt_rot_grip[1]] = 1
522
+ action_rot_z_one_hot[b, gt_rot_grip[2]] = 1
523
+ action_grip_one_hot[b, gt_rot_grip[3]] = 1
524
+
525
+ gt_ignore_collisions = action_ignore_collisions[b, :].int()
526
+ action_ignore_collisions_one_hot[b, gt_ignore_collisions[0]] = 1
527
+
528
+ # flatten predictions
529
+ q_rot_x_flat = q_rot_grip[
530
+ :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
531
+ ]
532
+ q_rot_y_flat = q_rot_grip[
533
+ :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
534
+ ]
535
+ q_rot_z_flat = q_rot_grip[
536
+ :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
537
+ ]
538
+ q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
539
+ q_ignore_collisions_flat = q_collision
540
+
541
+ # rotation loss
542
+ q_rot_loss += self._celoss(q_rot_x_flat, action_rot_x_one_hot)
543
+ q_rot_loss += self._celoss(q_rot_y_flat, action_rot_y_one_hot)
544
+ q_rot_loss += self._celoss(q_rot_z_flat, action_rot_z_one_hot)
545
+
546
+ # gripper loss
547
+ q_grip_loss += self._celoss(q_grip_flat, action_grip_one_hot)
548
+
549
+ # collision loss
550
+ q_collision_loss += self._celoss(
551
+ q_ignore_collisions_flat, action_ignore_collisions_one_hot
552
+ )
553
+
554
+ combined_losses = (
555
+ (q_trans_loss * self._trans_loss_weight)
556
+ + (q_rot_loss * self._rot_loss_weight)
557
+ + (q_grip_loss * self._grip_loss_weight)
558
+ + (q_collision_loss * self._collision_loss_weight)
559
+ )
560
+ total_loss = combined_losses.mean()
561
+
562
+ self._optimizer.zero_grad()
563
+ total_loss.backward()
564
+ self._optimizer.step()
565
+
566
+ self._summaries = {
567
+ "losses/total_loss": total_loss,
568
+ "losses/trans_loss": q_trans_loss.mean(),
569
+ "losses/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
570
+ "losses/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
571
+ "losses/collision_loss": q_collision_loss.mean()
572
+ if with_rot_and_grip
573
+ else 0.0,
574
+ }
575
+
576
+ if self._lr_scheduler:
577
+ self._scheduler.step()
578
+ self._summaries["learning_rate"] = self._scheduler.get_last_lr()[0]
579
+
580
+ self._vis_voxel_grid = voxel_grid[0]
581
+ self._vis_translation_qvalue = self._softmax_q_trans(q_trans[0])
582
+ self._vis_max_coordinate = coords[0]
583
+ self._vis_gt_coordinate = action_trans[0]
584
+
585
+ # Note: PerAct doesn't use multi-layer voxel grids like C2FARM
586
+ # stack prev_layer_voxel_grid(s) from previous layers into a list
587
+ if prev_layer_voxel_grid is None:
588
+ prev_layer_voxel_grid = [voxel_grid]
589
+ else:
590
+ prev_layer_voxel_grid = prev_layer_voxel_grid + [voxel_grid]
591
+
592
+ # stack prev_layer_bound(s) from previous layers into a list
593
+ if prev_layer_bounds is None:
594
+ prev_layer_bounds = [self._coordinate_bounds.repeat(bs, 1)]
595
+ else:
596
+ prev_layer_bounds = prev_layer_bounds + [bounds]
597
+
598
+ return {
599
+ "total_loss": total_loss,
600
+ "prev_layer_voxel_grid": prev_layer_voxel_grid,
601
+ "prev_layer_bounds": prev_layer_bounds,
602
+ }
603
+
604
+ def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
605
+ deterministic = True
606
+ bounds = self._coordinate_bounds
607
+ prev_layer_voxel_grid = observation.get("prev_layer_voxel_grid", None)
608
+ prev_layer_bounds = observation.get("prev_layer_bounds", None)
609
+ lang_goal_tokens = observation.get("lang_goal_tokens", None).long()
610
+
611
+ # extract CLIP language embs
612
+ with torch.no_grad():
613
+ lang_goal_tokens = lang_goal_tokens.to(device=self._device)
614
+ (
615
+ lang_goal_emb,
616
+ lang_token_embs,
617
+ ) = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0])
618
+
619
+ if self._layer > 0:
620
+ cp = observation["attention_coordinate"]
621
+ bounds = torch.cat(
622
+ [cp - self._bounds_offset, cp + self._bounds_offset], dim=1
623
+ )
624
+
625
+ # voxelization resolution
626
+ res = (bounds[:, 3:] - bounds[:, :3]) / self._voxel_size
627
+ max_rot_index = int(360 // self._rotation_resolution)
628
+ proprio = None
629
+
630
+ if self._include_low_dim_state:
631
+ proprio = observation["low_dim_state"]
632
+
633
+ obs, pcd = self._act_preprocess_inputs(observation)
634
+
635
+ # correct batch size and device
636
+ obs = [[o[0].to(self._device), o[1].to(self._device)] for o in obs]
637
+ proprio = proprio[0].to(self._device)
638
+ pcd = [p.to(self._device) for p in pcd]
639
+ lang_goal_emb = lang_goal_emb.to(self._device)
640
+ lang_token_embs = lang_token_embs.to(self._device)
641
+ bounds = torch.as_tensor(bounds, device=self._device)
642
+ if prev_layer_voxel_grid is not None:
643
+ prev_layer_voxel_grid = [
644
+ pvg.to(self._device) for pvg in prev_layer_voxel_grid
645
+ ]
646
+ if prev_layer_bounds is not None:
647
+ prev_layer_bounds = [pb.to(self._device) for pb in prev_layer_bounds]
648
+
649
+ # inference
650
+ q_trans, q_rot_grip, q_ignore_collisions, vox_grid = self._q(
651
+ obs,
652
+ proprio,
653
+ pcd,
654
+ lang_goal_emb,
655
+ lang_token_embs,
656
+ bounds,
657
+ prev_layer_bounds,
658
+ prev_layer_voxel_grid,
659
+ )
660
+
661
+ # softmax Q predictions
662
+ q_trans = self._softmax_q_trans(q_trans)
663
+ q_rot_grip = (
664
+ self._softmax_q_rot_grip(q_rot_grip) if q_rot_grip is not None else None
665
+ )
666
+ q_ignore_collisions = (
667
+ self._softmax_ignore_collision(q_ignore_collisions)
668
+ if q_ignore_collisions is not None
669
+ else None
670
+ )
671
+
672
+ # argmax Q predictions
673
+ (
674
+ coords,
675
+ rot_and_grip_indicies,
676
+ ignore_collisions,
677
+ ) = self._q.choose_highest_action(q_trans, q_rot_grip, q_ignore_collisions)
678
+
679
+ rot_grip_action = rot_and_grip_indicies if q_rot_grip is not None else None
680
+ ignore_collisions_action = (
681
+ ignore_collisions.int() if ignore_collisions is not None else None
682
+ )
683
+
684
+ coords = coords.int()
685
+ attention_coordinate = bounds[:, :3] + res * coords + res / 2
686
+
687
+ # stack prev_layer_voxel_grid(s) into a list
688
+ # NOTE: PerAct doesn't used multi-layer voxel grids like C2FARM
689
+ if prev_layer_voxel_grid is None:
690
+ prev_layer_voxel_grid = [vox_grid]
691
+ else:
692
+ prev_layer_voxel_grid = prev_layer_voxel_grid + [vox_grid]
693
+
694
+ if prev_layer_bounds is None:
695
+ prev_layer_bounds = [bounds]
696
+ else:
697
+ prev_layer_bounds = prev_layer_bounds + [bounds]
698
+
699
+ observation_elements = {
700
+ "attention_coordinate": attention_coordinate,
701
+ "prev_layer_voxel_grid": prev_layer_voxel_grid,
702
+ "prev_layer_bounds": prev_layer_bounds,
703
+ }
704
+ info = {
705
+ "voxel_grid_depth%d" % self._layer: vox_grid,
706
+ "q_depth%d" % self._layer: q_trans,
707
+ "voxel_idx_depth%d" % self._layer: coords,
708
+ }
709
+ self._act_voxel_grid = vox_grid[0]
710
+ self._act_max_coordinate = coords[0]
711
+ self._act_qvalues = q_trans[0].detach()
712
+ return ActResult(
713
+ (coords, rot_grip_action, ignore_collisions_action),
714
+ observation_elements=observation_elements,
715
+ info=info,
716
+ )
717
+
718
+ def update_summaries(self) -> List[Summary]:
719
+ summaries = [
720
+ ImageSummary(
721
+ "%s/update_qattention" % self._name,
722
+ transforms.ToTensor()(
723
+ visualise_voxel(
724
+ self._vis_voxel_grid.detach().cpu().numpy(),
725
+ self._vis_translation_qvalue.detach().cpu().numpy(),
726
+ self._vis_max_coordinate.detach().cpu().numpy(),
727
+ self._vis_gt_coordinate.detach().cpu().numpy(),
728
+ )
729
+ ),
730
+ )
731
+ ]
732
+
733
+ for n, v in self._summaries.items():
734
+ summaries.append(ScalarSummary("%s/%s" % (self._name, n), v))
735
+
736
+ for name, crop in self._crop_summary:
737
+ crops = (torch.cat(torch.split(crop, 3, dim=1), dim=3) + 1.0) / 2.0
738
+ summaries.extend([ImageSummary("%s/crops/%s" % (self._name, name), crops)])
739
+
740
+ for tag, param in self._q.named_parameters():
741
+ # assert not torch.isnan(param.grad.abs() <= 1.0).all()
742
+ summaries.append(
743
+ HistogramSummary("%s/gradient/%s" % (self._name, tag), param.grad)
744
+ )
745
+ summaries.append(
746
+ HistogramSummary("%s/weight/%s" % (self._name, tag), param.data)
747
+ )
748
+
749
+ return summaries
750
+
751
+ def act_summaries(self) -> List[Summary]:
752
+ return [
753
+ ImageSummary(
754
+ "%s/act_Qattention" % self._name,
755
+ transforms.ToTensor()(
756
+ visualise_voxel(
757
+ self._act_voxel_grid.cpu().numpy(),
758
+ self._act_qvalues.cpu().numpy(),
759
+ self._act_max_coordinate.cpu().numpy(),
760
+ )
761
+ ),
762
+ )
763
+ ]
764
+
765
+ def load_weights(self, savedir: str):
766
+ device = (
767
+ self._device
768
+ if not self._training
769
+ else torch.device("cuda:%d" % self._device)
770
+ )
771
+ state_dict = torch.load(
772
+ os.path.join(savedir, "%s.pt" % self._name), map_location=device
773
+ )
774
+
775
+ # load only keys that are in the current model
776
+ merged_state_dict = self._q.state_dict()
777
+ for k, v in state_dict.items():
778
+ if "_voxelizer" not in k:
779
+ if not self._training:
780
+ k = k.replace("_qnet.module", "_qnet")
781
+
782
+ if k in merged_state_dict:
783
+ merged_state_dict[k] = v
784
+ else:
785
+ logging.warning("key %s not found in checkpoint" % k)
786
+ self._q.load_state_dict(merged_state_dict)
787
+ print("loaded weights from %s" % savedir)
788
+
789
+ def save_weights(self, savedir: str):
790
+ torch.save(self._q.state_dict(), os.path.join(savedir, "%s.pt" % self._name))
external/peract_bimanual/agents/c2farm_lingunet_bc/qattention_stack_agent.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import torch
4
+ from yarr.agents.agent import Agent, ActResult, Summary
5
+
6
+ import numpy as np
7
+
8
+ from helpers import utils
9
+ from agents.c2farm_lingunet_bc.qattention_lingunet_bc_agent import (
10
+ QAttentionLingUNetBCAgent,
11
+ )
12
+
13
+ from scipy.spatial.transform import Rotation
14
+
15
+ NAME = "QAttentionStackAgent"
16
+
17
+
18
+ class QAttentionStackAgent(Agent):
19
+ def __init__(
20
+ self,
21
+ qattention_agents: List[QAttentionLingUNetBCAgent],
22
+ rotation_resolution: float,
23
+ camera_names: List[str],
24
+ rotation_prediction_depth: int = 0,
25
+ ):
26
+ super(QAttentionStackAgent, self).__init__()
27
+ self._qattention_agents = qattention_agents
28
+ self._rotation_resolution = rotation_resolution
29
+ self._camera_names = camera_names
30
+ self._rotation_prediction_depth = rotation_prediction_depth
31
+
32
+ def build(self, training: bool, device=None) -> None:
33
+ self._device = device
34
+ if self._device is None:
35
+ self._device = torch.device("cpu")
36
+ for qa in self._qattention_agents:
37
+ qa.build(training, device)
38
+
39
+ def update(self, step: int, replay_sample: dict) -> dict:
40
+ priorities = 0
41
+ total_losses = 0.0
42
+ for qa in self._qattention_agents:
43
+ update_dict = qa.update(step, replay_sample)
44
+ replay_sample.update(update_dict)
45
+ total_losses += update_dict["total_loss"]
46
+ return {
47
+ "total_losses": total_losses,
48
+ }
49
+
50
+ def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
51
+ observation_elements = {}
52
+ translation_results, rot_grip_results, ignore_collisions_results = [], [], []
53
+ infos = {}
54
+ for depth, qagent in enumerate(self._qattention_agents):
55
+ act_results = qagent.act(step, observation, deterministic)
56
+ attention_coordinate = (
57
+ act_results.observation_elements["attention_coordinate"].cpu().numpy()
58
+ )
59
+ observation_elements[
60
+ "attention_coordinate_layer_%d" % depth
61
+ ] = attention_coordinate[0]
62
+
63
+ translation_idxs, rot_grip_idxs, ignore_collisions_idxs = act_results.action
64
+ translation_results.append(translation_idxs)
65
+ if rot_grip_idxs is not None:
66
+ rot_grip_results.append(rot_grip_idxs)
67
+ if ignore_collisions_idxs is not None:
68
+ ignore_collisions_results.append(ignore_collisions_idxs)
69
+
70
+ observation["attention_coordinate"] = act_results.observation_elements[
71
+ "attention_coordinate"
72
+ ]
73
+ observation["prev_layer_voxel_grid"] = act_results.observation_elements[
74
+ "prev_layer_voxel_grid"
75
+ ]
76
+ observation["prev_layer_bounds"] = act_results.observation_elements[
77
+ "prev_layer_bounds"
78
+ ]
79
+
80
+ for n in self._camera_names:
81
+ px, py = utils.point_to_pixel_index(
82
+ attention_coordinate[0],
83
+ observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy(),
84
+ observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy(),
85
+ )
86
+ pc_t = torch.tensor(
87
+ [[[py, px]]], dtype=torch.float32, device=self._device
88
+ )
89
+ observation["%s_pixel_coord" % n] = pc_t
90
+ observation_elements["%s_pixel_coord" % n] = [py, px]
91
+
92
+ infos.update(act_results.info)
93
+
94
+ rgai = torch.cat(rot_grip_results, 1)[0].cpu().numpy()
95
+ ignore_collisions = float(
96
+ torch.cat(ignore_collisions_results, 1)[0].cpu().numpy()
97
+ )
98
+ observation_elements["trans_action_indicies"] = (
99
+ torch.cat(translation_results, 1)[0].cpu().numpy()
100
+ )
101
+ observation_elements["rot_grip_action_indicies"] = rgai
102
+ continuous_action = np.concatenate(
103
+ [
104
+ act_results.observation_elements["attention_coordinate"]
105
+ .cpu()
106
+ .numpy()[0],
107
+ utils.discrete_euler_to_quaternion(
108
+ rgai[-4:-1], self._rotation_resolution
109
+ ),
110
+ rgai[-1:],
111
+ [ignore_collisions],
112
+ ]
113
+ )
114
+ return ActResult(
115
+ continuous_action, observation_elements=observation_elements, info=infos
116
+ )
117
+
118
+ def update_summaries(self) -> List[Summary]:
119
+ summaries = []
120
+ for qa in self._qattention_agents:
121
+ summaries.extend(qa.update_summaries())
122
+ return summaries
123
+
124
+ def act_summaries(self) -> List[Summary]:
125
+ s = []
126
+ for qa in self._qattention_agents:
127
+ s.extend(qa.act_summaries())
128
+ return s
129
+
130
+ def load_weights(self, savedir: str):
131
+ for qa in self._qattention_agents:
132
+ qa.load_weights(savedir)
133
+
134
+ def save_weights(self, savedir: str):
135
+ for qa in self._qattention_agents:
136
+ qa.save_weights(savedir)
external/peract_bimanual/agents/peract_bc/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ import agents.peract_bc.launch_utils
external/peract_bimanual/agents/peract_bc/launch_utils.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from ARM
2
+ # Source: https://github.com/stepjam/ARM
3
+ # License: https://github.com/stepjam/ARM/LICENSE
4
+
5
+
6
+ from helpers.preprocess_agent import PreprocessAgent
7
+ from agents.peract_bc.perceiver_lang_io import PerceiverVoxelLangEncoder
8
+ from agents.peract_bc.qattention_peract_bc_agent import QAttentionPerActBCAgent
9
+ from agents.peract_bc.qattention_stack_agent import QAttentionStackAgent
10
+
11
+ from omegaconf import DictConfig
12
+
13
+
14
+ def create_agent(cfg: DictConfig):
15
+ LATENT_SIZE = 64
16
+ depth_0bounds = cfg.rlbench.scene_bounds
17
+ cam_resolution = cfg.rlbench.camera_resolution
18
+
19
+ num_rotation_classes = int(360.0 // cfg.method.rotation_resolution)
20
+ qattention_agents = []
21
+ for depth, vox_size in enumerate(cfg.method.voxel_sizes):
22
+ last = depth == len(cfg.method.voxel_sizes) - 1
23
+ perceiver_encoder = PerceiverVoxelLangEncoder(
24
+ depth=cfg.method.transformer_depth,
25
+ iterations=cfg.method.transformer_iterations,
26
+ voxel_size=vox_size,
27
+ initial_dim=3 + 3 + 1 + 3,
28
+ low_dim_size=cfg.method.low_dim_size,
29
+ layer=depth,
30
+ num_rotation_classes=num_rotation_classes if last else 0,
31
+ num_grip_classes=2 if last else 0,
32
+ num_collision_classes=2 if last else 0,
33
+ input_axis=3,
34
+ num_latents=cfg.method.num_latents,
35
+ latent_dim=cfg.method.latent_dim,
36
+ cross_heads=cfg.method.cross_heads,
37
+ latent_heads=cfg.method.latent_heads,
38
+ cross_dim_head=cfg.method.cross_dim_head,
39
+ latent_dim_head=cfg.method.latent_dim_head,
40
+ weight_tie_layers=False,
41
+ activation=cfg.method.activation,
42
+ pos_encoding_with_lang=cfg.method.pos_encoding_with_lang,
43
+ input_dropout=cfg.method.input_dropout,
44
+ attn_dropout=cfg.method.attn_dropout,
45
+ decoder_dropout=cfg.method.decoder_dropout,
46
+ lang_fusion_type=cfg.method.lang_fusion_type,
47
+ voxel_patch_size=cfg.method.voxel_patch_size,
48
+ voxel_patch_stride=cfg.method.voxel_patch_stride,
49
+ no_skip_connection=cfg.method.no_skip_connection,
50
+ no_perceiver=cfg.method.no_perceiver,
51
+ no_language=cfg.method.no_language,
52
+ final_dim=cfg.method.final_dim,
53
+ )
54
+
55
+ qattention_agent = QAttentionPerActBCAgent(
56
+ layer=depth,
57
+ coordinate_bounds=depth_0bounds,
58
+ perceiver_encoder=perceiver_encoder,
59
+ camera_names=cfg.rlbench.cameras,
60
+ voxel_size=vox_size,
61
+ bounds_offset=cfg.method.bounds_offset[depth - 1] if depth > 0 else None,
62
+ image_crop_size=cfg.method.image_crop_size,
63
+ lr=cfg.method.lr,
64
+ training_iterations=cfg.framework.training_iterations,
65
+ lr_scheduler=cfg.method.lr_scheduler,
66
+ num_warmup_steps=cfg.method.num_warmup_steps,
67
+ trans_loss_weight=cfg.method.trans_loss_weight,
68
+ rot_loss_weight=cfg.method.rot_loss_weight,
69
+ grip_loss_weight=cfg.method.grip_loss_weight,
70
+ collision_loss_weight=cfg.method.collision_loss_weight,
71
+ include_low_dim_state=True,
72
+ image_resolution=cam_resolution,
73
+ batch_size=cfg.replay.batch_size,
74
+ voxel_feature_size=3,
75
+ lambda_weight_l2=cfg.method.lambda_weight_l2,
76
+ num_rotation_classes=num_rotation_classes,
77
+ rotation_resolution=cfg.method.rotation_resolution,
78
+ transform_augmentation=cfg.method.transform_augmentation.apply_se3,
79
+ transform_augmentation_xyz=cfg.method.transform_augmentation.aug_xyz,
80
+ transform_augmentation_rpy=cfg.method.transform_augmentation.aug_rpy,
81
+ transform_augmentation_rot_resolution=cfg.method.transform_augmentation.aug_rot_resolution,
82
+ optimizer_type=cfg.method.optimizer,
83
+ num_devices=cfg.ddp.num_devices,
84
+ checkpoint_name_prefix=cfg.framework.checkpoint_name_prefix,
85
+ )
86
+ qattention_agents.append(qattention_agent)
87
+
88
+ rotation_agent = QAttentionStackAgent(
89
+ qattention_agents=qattention_agents,
90
+ rotation_resolution=cfg.method.rotation_resolution,
91
+ camera_names=cfg.rlbench.cameras,
92
+ )
93
+ preprocess_agent = PreprocessAgent(pose_agent=rotation_agent)
94
+ return preprocess_agent
external/peract_bimanual/agents/peract_bc/perceiver_lang_io.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Perceiver IO implementation adpated for manipulation
2
+ # Source: https://github.com/lucidrains/perceiver-pytorch
3
+ # License: https://github.com/lucidrains/perceiver-pytorch/blob/main/LICENSE
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from einops import rearrange
9
+ from einops import repeat
10
+
11
+ from perceiver_pytorch.perceiver_pytorch import cache_fn
12
+ from perceiver_pytorch.perceiver_pytorch import PreNorm, FeedForward, Attention
13
+
14
+ from helpers.network_utils import (
15
+ DenseBlock,
16
+ SpatialSoftmax3D,
17
+ Conv3DBlock,
18
+ Conv3DUpsampleBlock,
19
+ )
20
+
21
+
22
+ # PerceiverIO adapted for 6-DoF manipulation
23
+ class PerceiverVoxelLangEncoder(nn.Module):
24
+ def __init__(
25
+ self,
26
+ depth, # number of self-attention layers
27
+ iterations, # number cross-attention iterations (PerceiverIO uses just 1)
28
+ voxel_size, # N voxels per side (size: N*N*N)
29
+ initial_dim, # 10 dimensions - dimension of the input sequence to be encoded
30
+ low_dim_size, # 4 dimensions - proprioception: {gripper_open, left_finger, right_finger, timestep}
31
+ layer=0,
32
+ num_rotation_classes=72, # 5 degree increments (5*72=360) for each of the 3-axis
33
+ num_grip_classes=2, # open or not open
34
+ num_collision_classes=2, # collisions allowed or not allowed
35
+ input_axis=3, # 3D tensors have 3 axes
36
+ num_latents=512, # number of latent vectors
37
+ im_channels=64, # intermediate channel size
38
+ latent_dim=512, # dimensions of latent vectors
39
+ cross_heads=1, # number of cross-attention heads
40
+ latent_heads=8, # number of latent heads
41
+ cross_dim_head=64,
42
+ latent_dim_head=64,
43
+ activation="relu",
44
+ weight_tie_layers=False,
45
+ pos_encoding_with_lang=True,
46
+ input_dropout=0.1,
47
+ attn_dropout=0.1,
48
+ decoder_dropout=0.0,
49
+ lang_fusion_type="seq",
50
+ voxel_patch_size=9,
51
+ voxel_patch_stride=8,
52
+ no_skip_connection=False,
53
+ no_perceiver=False,
54
+ no_language=False,
55
+ final_dim=64,
56
+ ):
57
+ super().__init__()
58
+ self.depth = depth
59
+ self.layer = layer
60
+ self.init_dim = int(initial_dim)
61
+ self.iterations = iterations
62
+ self.input_axis = input_axis
63
+ self.voxel_size = voxel_size
64
+ self.low_dim_size = low_dim_size
65
+ self.im_channels = im_channels
66
+ self.pos_encoding_with_lang = pos_encoding_with_lang
67
+ self.lang_fusion_type = lang_fusion_type
68
+ self.voxel_patch_size = voxel_patch_size
69
+ self.voxel_patch_stride = voxel_patch_stride
70
+ self.num_rotation_classes = num_rotation_classes
71
+ self.num_grip_classes = num_grip_classes
72
+ self.num_collision_classes = num_collision_classes
73
+ self.final_dim = final_dim
74
+ self.input_dropout = input_dropout
75
+ self.attn_dropout = attn_dropout
76
+ self.decoder_dropout = decoder_dropout
77
+ self.no_skip_connection = no_skip_connection
78
+ self.no_perceiver = no_perceiver
79
+ self.no_language = no_language
80
+
81
+ # patchified input dimensions
82
+ spatial_size = voxel_size // self.voxel_patch_stride # 100/5 = 20
83
+
84
+ # 64 voxel features + 64 proprio features (+ 64 lang goal features if concattenated)
85
+ self.input_dim_before_seq = (
86
+ self.im_channels * 3
87
+ if self.lang_fusion_type == "concat"
88
+ else self.im_channels * 2
89
+ )
90
+
91
+ # CLIP language feature dimensions
92
+ lang_feat_dim, lang_emb_dim, lang_max_seq_len = 1024, 512, 77
93
+
94
+ # learnable positional encoding
95
+ if self.pos_encoding_with_lang:
96
+ self.pos_encoding = nn.Parameter(
97
+ torch.randn(
98
+ 1, lang_max_seq_len + spatial_size**3, self.input_dim_before_seq
99
+ )
100
+ )
101
+ else:
102
+ # assert self.lang_fusion_type == 'concat', 'Only concat is supported for pos encoding without lang.'
103
+ self.pos_encoding = nn.Parameter(
104
+ torch.randn(
105
+ 1,
106
+ spatial_size,
107
+ spatial_size,
108
+ spatial_size,
109
+ self.input_dim_before_seq,
110
+ )
111
+ )
112
+
113
+ # voxel input preprocessing 1x1 conv encoder
114
+ self.input_preprocess = Conv3DBlock(
115
+ self.init_dim,
116
+ self.im_channels,
117
+ kernel_sizes=1,
118
+ strides=1,
119
+ norm=None,
120
+ activation=activation,
121
+ )
122
+
123
+ # patchify conv
124
+ self.patchify = Conv3DBlock(
125
+ self.input_preprocess.out_channels,
126
+ self.im_channels,
127
+ kernel_sizes=self.voxel_patch_size,
128
+ strides=self.voxel_patch_stride,
129
+ norm=None,
130
+ activation=activation,
131
+ )
132
+
133
+ # language preprocess
134
+ if self.lang_fusion_type == "concat":
135
+ self.lang_preprocess = nn.Linear(lang_feat_dim, self.im_channels)
136
+ elif self.lang_fusion_type == "seq":
137
+ self.lang_preprocess = nn.Linear(lang_emb_dim, self.im_channels * 2)
138
+
139
+ # proprioception
140
+ if self.low_dim_size > 0:
141
+ self.proprio_preprocess = DenseBlock(
142
+ self.low_dim_size,
143
+ self.im_channels,
144
+ norm=None,
145
+ activation=activation,
146
+ )
147
+
148
+ # pooling functions
149
+ self.local_maxp = nn.MaxPool3d(3, 2, padding=1)
150
+ self.global_maxp = nn.AdaptiveMaxPool3d(1)
151
+
152
+ # 1st 3D softmax
153
+ self.ss0 = SpatialSoftmax3D(
154
+ self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
155
+ )
156
+ flat_size = self.im_channels * 4
157
+
158
+ # latent vectors (that are randomly initialized)
159
+ self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
160
+
161
+ # encoder cross attention
162
+ self.cross_attend_blocks = nn.ModuleList(
163
+ [
164
+ PreNorm(
165
+ latent_dim,
166
+ Attention(
167
+ latent_dim,
168
+ self.input_dim_before_seq,
169
+ heads=cross_heads,
170
+ dim_head=cross_dim_head,
171
+ dropout=input_dropout,
172
+ ),
173
+ context_dim=self.input_dim_before_seq,
174
+ ),
175
+ PreNorm(latent_dim, FeedForward(latent_dim)),
176
+ ]
177
+ )
178
+
179
+ get_latent_attn = lambda: PreNorm(
180
+ latent_dim,
181
+ Attention(
182
+ latent_dim,
183
+ heads=latent_heads,
184
+ dim_head=latent_dim_head,
185
+ dropout=attn_dropout,
186
+ ),
187
+ )
188
+ get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim))
189
+ get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))
190
+
191
+ # self attention layers
192
+ self.layers = nn.ModuleList([])
193
+ cache_args = {"_cache": weight_tie_layers}
194
+
195
+ for i in range(depth):
196
+ self.layers.append(
197
+ nn.ModuleList(
198
+ [get_latent_attn(**cache_args), get_latent_ff(**cache_args)]
199
+ )
200
+ )
201
+
202
+ # decoder cross attention
203
+ self.decoder_cross_attn = PreNorm(
204
+ self.input_dim_before_seq,
205
+ Attention(
206
+ self.input_dim_before_seq,
207
+ latent_dim,
208
+ heads=cross_heads,
209
+ dim_head=cross_dim_head,
210
+ dropout=decoder_dropout,
211
+ ),
212
+ context_dim=latent_dim,
213
+ )
214
+
215
+ # upsample conv
216
+ self.up0 = Conv3DUpsampleBlock(
217
+ self.input_dim_before_seq,
218
+ self.final_dim,
219
+ kernel_sizes=self.voxel_patch_size,
220
+ strides=self.voxel_patch_stride,
221
+ norm=None,
222
+ activation=activation,
223
+ )
224
+
225
+ # 2nd 3D softmax
226
+ self.ss1 = SpatialSoftmax3D(
227
+ spatial_size, spatial_size, spatial_size, self.input_dim_before_seq
228
+ )
229
+
230
+ flat_size += self.input_dim_before_seq * 4
231
+
232
+ # final 3D softmax
233
+ self.final = Conv3DBlock(
234
+ self.im_channels
235
+ if (self.no_perceiver or self.no_skip_connection)
236
+ else self.im_channels * 2,
237
+ self.im_channels,
238
+ kernel_sizes=3,
239
+ strides=1,
240
+ norm=None,
241
+ activation=activation,
242
+ )
243
+
244
+ self.trans_decoder = Conv3DBlock(
245
+ self.final_dim,
246
+ 1,
247
+ kernel_sizes=3,
248
+ strides=1,
249
+ norm=None,
250
+ activation=None,
251
+ )
252
+
253
+ # rotation, gripper, and collision MLP layers
254
+ if self.num_rotation_classes > 0:
255
+ self.ss_final = SpatialSoftmax3D(
256
+ self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
257
+ )
258
+
259
+ flat_size += self.im_channels * 4
260
+
261
+ self.dense0 = DenseBlock(flat_size, 256, None, activation)
262
+ self.dense1 = DenseBlock(256, self.final_dim, None, activation)
263
+
264
+ self.rot_grip_collision_ff = DenseBlock(
265
+ self.final_dim,
266
+ self.num_rotation_classes * 3
267
+ + self.num_grip_classes
268
+ + self.num_collision_classes,
269
+ None,
270
+ None,
271
+ )
272
+
273
+ def encode_text(self, x):
274
+ with torch.no_grad():
275
+ text_feat, text_emb = self._clip_rn50.encode_text_with_embeddings(x)
276
+
277
+ text_feat = text_feat.detach()
278
+ text_emb = text_emb.detach()
279
+ text_mask = torch.where(x == 0, x, 1) # [1, max_token_len]
280
+ return text_feat, text_emb
281
+
282
+ def forward(
283
+ self,
284
+ ins,
285
+ proprio,
286
+ lang_goal_emb,
287
+ lang_token_embs,
288
+ prev_layer_voxel_grid,
289
+ bounds,
290
+ prev_layer_bounds,
291
+ mask=None,
292
+ ):
293
+ # preprocess input
294
+ d0 = self.input_preprocess(ins) # [B,10,100,100,100] -> [B,64,100,100,100]
295
+
296
+ # aggregated features from 1st softmax and maxpool for MLP decoders
297
+ feats = [self.ss0(d0.contiguous()), self.global_maxp(d0).view(ins.shape[0], -1)]
298
+
299
+ # patchify input (5x5x5 patches)
300
+ ins = self.patchify(d0) # [B,64,100,100,100] -> [B,64,20,20,20]
301
+
302
+ b, c, d, h, w, device = *ins.shape, ins.device
303
+ axis = [d, h, w]
304
+ assert (
305
+ len(axis) == self.input_axis
306
+ ), "input must have the same number of axis as input_axis"
307
+
308
+ # concat proprio
309
+ if self.low_dim_size > 0:
310
+ p = self.proprio_preprocess(proprio) # [B,4] -> [B,64]
311
+ p = p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
312
+ ins = torch.cat([ins, p], dim=1) # [B,128,20,20,20]
313
+
314
+ # language ablation
315
+ if self.no_language:
316
+ lang_goal_emb = torch.zeros_like(lang_goal_emb)
317
+ lang_token_embs = torch.zeros_like(lang_token_embs)
318
+
319
+ # option 1: tile and concat lang goal to input
320
+ if self.lang_fusion_type == "concat":
321
+ lang_emb = lang_goal_emb
322
+ lang_emb = lang_emb.to(dtype=ins.dtype)
323
+ l = self.lang_preprocess(lang_emb)
324
+ l = l.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
325
+ ins = torch.cat([ins, l], dim=1)
326
+
327
+ # channel last
328
+ ins = rearrange(ins, "b d ... -> b ... d") # [B,20,20,20,128]
329
+
330
+ # add pos encoding to grid
331
+ if not self.pos_encoding_with_lang:
332
+ ins = ins + self.pos_encoding
333
+
334
+ ######################## NOTE #############################
335
+ # NOTE: If you add positional encodings ^here the lang embs
336
+ # won't have positional encodings. I accidently forgot
337
+ # to turn this off for all the experiments in the paper.
338
+ # So I guess those models were using language embs
339
+ # as a bag of words :( But it doesn't matter much for
340
+ # RLBench tasks since we don't test for novel instructions
341
+ # at test time anyway. The recommend way is to add
342
+ # positional encodings to the final input sequence
343
+ # fed into the Perceiver Transformer, as done below
344
+ # (and also in the Colab tutorial).
345
+ ###########################################################
346
+
347
+ # concat to channels of and flatten axis
348
+ queries_orig_shape = ins.shape
349
+
350
+ # rearrange input to be channel last
351
+ ins = rearrange(ins, "b ... d -> b (...) d") # [B,8000,128]
352
+ ins_wo_prev_layers = ins
353
+
354
+ # option 2: add lang token embs as a sequence
355
+ if self.lang_fusion_type == "seq":
356
+ l = self.lang_preprocess(lang_token_embs) # [B,77,1024] -> [B,77,128]
357
+ ins = torch.cat((l, ins), dim=1) # [B,8077,128]
358
+
359
+ # add pos encoding to language + flattened grid (the recommended way)
360
+ if self.pos_encoding_with_lang:
361
+ ins = ins + self.pos_encoding
362
+
363
+ # batchify latents
364
+ x = repeat(self.latents, "n d -> b n d", b=b)
365
+
366
+ cross_attn, cross_ff = self.cross_attend_blocks
367
+
368
+ for it in range(self.iterations):
369
+ # encoder cross attention
370
+ x = cross_attn(x, context=ins, mask=mask) + x
371
+ x = cross_ff(x) + x
372
+
373
+ # self-attention layers
374
+ for self_attn, self_ff in self.layers:
375
+ x = self_attn(x) + x
376
+ x = self_ff(x) + x
377
+
378
+ # decoder cross attention
379
+ latents = self.decoder_cross_attn(ins, context=x)
380
+
381
+ # crop out the language part of the output sequence
382
+ if self.lang_fusion_type == "seq":
383
+ latents = latents[:, l.shape[1] :]
384
+
385
+ # reshape back to voxel grid
386
+ latents = latents.view(
387
+ b, *queries_orig_shape[1:-1], latents.shape[-1]
388
+ ) # [B,20,20,20,64]
389
+ latents = rearrange(latents, "b ... d -> b d ...") # [B,64,20,20,20]
390
+
391
+ # aggregated features from 2nd softmax and maxpool for MLP decoders
392
+ feats.extend(
393
+ [self.ss1(latents.contiguous()), self.global_maxp(latents).view(b, -1)]
394
+ )
395
+
396
+ # upsample
397
+ u0 = self.up0(latents)
398
+
399
+ # ablations
400
+ if self.no_skip_connection:
401
+ u = self.final(u0)
402
+ elif self.no_perceiver:
403
+ u = self.final(d0)
404
+ else:
405
+ u = self.final(torch.cat([d0, u0], dim=1))
406
+
407
+ # translation decoder
408
+ trans = self.trans_decoder(u)
409
+
410
+ # rotation, gripper, and collision MLPs
411
+ rot_and_grip_out = None
412
+ if self.num_rotation_classes > 0:
413
+ feats.extend(
414
+ [self.ss_final(u.contiguous()), self.global_maxp(u).view(b, -1)]
415
+ )
416
+
417
+ dense0 = self.dense0(torch.cat(feats, dim=1))
418
+ dense1 = self.dense1(dense0) # [B,72*3+2+2]
419
+
420
+ rot_and_grip_collision_out = self.rot_grip_collision_ff(dense1)
421
+ rot_and_grip_out = rot_and_grip_collision_out[
422
+ :, : -self.num_collision_classes
423
+ ]
424
+ collision_out = rot_and_grip_collision_out[:, -self.num_collision_classes :]
425
+
426
+ return trans, rot_and_grip_out, collision_out
external/peract_bimanual/agents/peract_bc/qattention_peract_bc_agent.py ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torchvision import transforms
11
+ from pytorch3d import transforms as torch3d_tf
12
+ from yarr.agents.agent import (
13
+ Agent,
14
+ ActResult,
15
+ ScalarSummary,
16
+ HistogramSummary,
17
+ ImageSummary,
18
+ Summary,
19
+ )
20
+
21
+ from helpers import utils
22
+ from helpers.utils import visualise_voxel, stack_on_channel
23
+ from voxel.voxel_grid import VoxelGrid
24
+ from voxel.augmentation import apply_se3_augmentation
25
+ from einops import rearrange
26
+ from helpers.clip.core.clip import build_model, load_clip
27
+
28
+ import transformers
29
+ from helpers.optim.lamb import Lamb
30
+
31
+ from torch.nn.parallel import DistributedDataParallel as DDP
32
+
33
+
34
+ class QFunction(nn.Module):
35
+ def __init__(
36
+ self,
37
+ perceiver_encoder: nn.Module,
38
+ voxelizer: VoxelGrid,
39
+ bounds_offset: float,
40
+ rotation_resolution: float,
41
+ device,
42
+ training,
43
+ ):
44
+ super(QFunction, self).__init__()
45
+ self._rotation_resolution = rotation_resolution
46
+ self._voxelizer = voxelizer
47
+ self._bounds_offset = bounds_offset
48
+ self._qnet = perceiver_encoder.to(device)
49
+
50
+ # distributed training
51
+ if training:
52
+ self._qnet = DDP(self._qnet, device_ids=[device])
53
+
54
+ def _argmax_3d(self, tensor_orig):
55
+ b, c, d, h, w = tensor_orig.shape # c will be one
56
+ idxs = tensor_orig.view(b, c, -1).argmax(-1)
57
+ indices = torch.cat([((idxs // h) // d), (idxs // h) % w, idxs % w], 1)
58
+ return indices
59
+
60
+ def choose_highest_action(self, q_trans, q_rot_grip, q_collision):
61
+ coords = self._argmax_3d(q_trans)
62
+ rot_and_grip_indicies = None
63
+ ignore_collision = None
64
+ if q_rot_grip is not None:
65
+ q_rot = torch.stack(
66
+ torch.split(
67
+ q_rot_grip[:, :-2], int(360 // self._rotation_resolution), dim=1
68
+ ),
69
+ dim=1,
70
+ )
71
+ rot_and_grip_indicies = torch.cat(
72
+ [
73
+ q_rot[:, 0:1].argmax(-1),
74
+ q_rot[:, 1:2].argmax(-1),
75
+ q_rot[:, 2:3].argmax(-1),
76
+ q_rot_grip[:, -2:].argmax(-1, keepdim=True),
77
+ ],
78
+ -1,
79
+ )
80
+ ignore_collision = q_collision[:, -2:].argmax(-1, keepdim=True)
81
+ return coords, rot_and_grip_indicies, ignore_collision
82
+
83
+ def forward(
84
+ self,
85
+ rgb_pcd,
86
+ proprio,
87
+ pcd,
88
+ lang_goal_emb,
89
+ lang_token_embs,
90
+ bounds=None,
91
+ prev_bounds=None,
92
+ prev_layer_voxel_grid=None,
93
+ ):
94
+ # rgb_pcd will be list of list (list of [rgb, pcd])
95
+ b = rgb_pcd[0][0].shape[0]
96
+ pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(b, -1, 3) for p in pcd], 1)
97
+
98
+ # flatten RGBs and Pointclouds
99
+ rgb = [rp[0] for rp in rgb_pcd]
100
+ feat_size = rgb[0].shape[1]
101
+ flat_imag_features = torch.cat(
102
+ [p.permute(0, 2, 3, 1).reshape(b, -1, feat_size) for p in rgb], 1
103
+ )
104
+
105
+ # construct voxel grid
106
+ voxel_grid = self._voxelizer.coords_to_bounding_voxel_grid(
107
+ pcd_flat, coord_features=flat_imag_features, coord_bounds=bounds
108
+ )
109
+
110
+ # swap to channels fist
111
+ voxel_grid = voxel_grid.permute(0, 4, 1, 2, 3).detach()
112
+
113
+ # batch bounds if necessary
114
+ if bounds.shape[0] != b:
115
+ bounds = bounds.repeat(b, 1)
116
+
117
+ # forward pass
118
+ q_trans, q_rot_and_grip, q_ignore_collisions = self._qnet(
119
+ voxel_grid,
120
+ proprio,
121
+ lang_goal_emb,
122
+ lang_token_embs,
123
+ prev_layer_voxel_grid,
124
+ bounds,
125
+ prev_bounds,
126
+ )
127
+
128
+ return q_trans, q_rot_and_grip, q_ignore_collisions, voxel_grid
129
+
130
+
131
+ class QAttentionPerActBCAgent(Agent):
132
+ def __init__(
133
+ self,
134
+ layer: int,
135
+ coordinate_bounds: list,
136
+ perceiver_encoder: nn.Module,
137
+ camera_names: list,
138
+ batch_size: int,
139
+ voxel_size: int,
140
+ bounds_offset: float,
141
+ voxel_feature_size: int,
142
+ image_crop_size: int,
143
+ num_rotation_classes: int,
144
+ rotation_resolution: float,
145
+ lr: float = 0.0001,
146
+ lr_scheduler: bool = False,
147
+ training_iterations: int = 100000,
148
+ num_warmup_steps: int = 20000,
149
+ trans_loss_weight: float = 1.0,
150
+ rot_loss_weight: float = 1.0,
151
+ grip_loss_weight: float = 1.0,
152
+ collision_loss_weight: float = 1.0,
153
+ include_low_dim_state: bool = False,
154
+ image_resolution: list = None,
155
+ lambda_weight_l2: float = 0.0,
156
+ transform_augmentation: bool = True,
157
+ transform_augmentation_xyz: list = [0.0, 0.0, 0.0],
158
+ transform_augmentation_rpy: list = [0.0, 0.0, 180.0],
159
+ transform_augmentation_rot_resolution: int = 5,
160
+ optimizer_type: str = "adam",
161
+ num_devices: int = 1,
162
+ checkpoint_name_prefix=None,
163
+ ):
164
+ self._layer = layer
165
+ self._coordinate_bounds = coordinate_bounds
166
+ self._perceiver_encoder = perceiver_encoder
167
+ self._voxel_feature_size = voxel_feature_size
168
+ self._bounds_offset = bounds_offset
169
+ self._image_crop_size = image_crop_size
170
+ self._lr = lr
171
+ self._lr_scheduler = lr_scheduler
172
+ self._training_iterations = training_iterations
173
+ self._num_warmup_steps = num_warmup_steps
174
+ self._trans_loss_weight = trans_loss_weight
175
+ self._rot_loss_weight = rot_loss_weight
176
+ self._grip_loss_weight = grip_loss_weight
177
+ self._collision_loss_weight = collision_loss_weight
178
+ self._include_low_dim_state = include_low_dim_state
179
+ self._image_resolution = image_resolution or [128, 128]
180
+ self._voxel_size = voxel_size
181
+ self._camera_names = camera_names
182
+ self._num_cameras = len(camera_names)
183
+ self._batch_size = batch_size
184
+ self._lambda_weight_l2 = lambda_weight_l2
185
+ self._transform_augmentation = transform_augmentation
186
+ self._transform_augmentation_xyz = torch.from_numpy(
187
+ np.array(transform_augmentation_xyz)
188
+ )
189
+ self._transform_augmentation_rpy = transform_augmentation_rpy
190
+ self._transform_augmentation_rot_resolution = (
191
+ transform_augmentation_rot_resolution
192
+ )
193
+ self._optimizer_type = optimizer_type
194
+ self._num_devices = num_devices
195
+ self._num_rotation_classes = num_rotation_classes
196
+ self._rotation_resolution = rotation_resolution
197
+
198
+ self._cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
199
+ checkpoint_name_prefix = checkpoint_name_prefix or "QAttentionAgent"
200
+ self._name = f"{checkpoint_name_prefix}_layer_{self._layer}"
201
+
202
+ def build(self, training: bool, device: torch.device = None):
203
+ self._training = training
204
+
205
+ if device is None:
206
+ device = torch.device("cpu")
207
+
208
+ self._device = device
209
+
210
+ self._voxelizer = VoxelGrid(
211
+ coord_bounds=self._coordinate_bounds,
212
+ voxel_size=self._voxel_size,
213
+ device=device,
214
+ batch_size=self._batch_size if training else 1,
215
+ feature_size=self._voxel_feature_size,
216
+ max_num_coords=np.prod(self._image_resolution) * self._num_cameras,
217
+ )
218
+
219
+ self._q = (
220
+ QFunction(
221
+ self._perceiver_encoder,
222
+ self._voxelizer,
223
+ self._bounds_offset,
224
+ self._rotation_resolution,
225
+ device,
226
+ training,
227
+ )
228
+ .to(device)
229
+ .train(training)
230
+ )
231
+
232
+ grid_for_crop = (
233
+ torch.arange(0, self._image_crop_size, device=device)
234
+ .unsqueeze(0)
235
+ .repeat(self._image_crop_size, 1)
236
+ .unsqueeze(-1)
237
+ )
238
+ self._grid_for_crop = torch.cat(
239
+ [grid_for_crop.transpose(1, 0), grid_for_crop], dim=2
240
+ ).unsqueeze(0)
241
+
242
+ self._coordinate_bounds = torch.tensor(
243
+ self._coordinate_bounds, device=device
244
+ ).unsqueeze(0)
245
+
246
+ if self._training:
247
+ # optimizer
248
+ if self._optimizer_type == "lamb":
249
+ self._optimizer = Lamb(
250
+ self._q.parameters(),
251
+ lr=self._lr,
252
+ weight_decay=self._lambda_weight_l2,
253
+ betas=(0.9, 0.999),
254
+ adam=False,
255
+ )
256
+ elif self._optimizer_type == "adam":
257
+ self._optimizer = torch.optim.Adam(
258
+ self._q.parameters(),
259
+ lr=self._lr,
260
+ weight_decay=self._lambda_weight_l2,
261
+ )
262
+ else:
263
+ raise Exception("Unknown optimizer type")
264
+
265
+ # learning rate scheduler
266
+ if self._lr_scheduler:
267
+ self._scheduler = (
268
+ transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
269
+ self._optimizer,
270
+ num_warmup_steps=self._num_warmup_steps,
271
+ num_training_steps=self._training_iterations,
272
+ num_cycles=self._training_iterations // 10000,
273
+ )
274
+ )
275
+
276
+ # one-hot zero tensors
277
+ self._action_trans_one_hot_zeros = torch.zeros(
278
+ (
279
+ self._batch_size,
280
+ 1,
281
+ self._voxel_size,
282
+ self._voxel_size,
283
+ self._voxel_size,
284
+ ),
285
+ dtype=int,
286
+ device=device,
287
+ )
288
+ self._action_rot_x_one_hot_zeros = torch.zeros(
289
+ (self._batch_size, self._num_rotation_classes), dtype=int, device=device
290
+ )
291
+ self._action_rot_y_one_hot_zeros = torch.zeros(
292
+ (self._batch_size, self._num_rotation_classes), dtype=int, device=device
293
+ )
294
+ self._action_rot_z_one_hot_zeros = torch.zeros(
295
+ (self._batch_size, self._num_rotation_classes), dtype=int, device=device
296
+ )
297
+ self._action_grip_one_hot_zeros = torch.zeros(
298
+ (self._batch_size, 2), dtype=int, device=device
299
+ )
300
+ self._action_ignore_collisions_one_hot_zeros = torch.zeros(
301
+ (self._batch_size, 2), dtype=int, device=device
302
+ )
303
+
304
+ # print total params
305
+ logging.info(
306
+ "# Q Params: %d"
307
+ % sum(
308
+ p.numel()
309
+ for name, p in self._q.named_parameters()
310
+ if p.requires_grad and "clip" not in name
311
+ )
312
+ )
313
+ else:
314
+ for param in self._q.parameters():
315
+ param.requires_grad = False
316
+
317
+ # load CLIP for encoding language goals during evaluation
318
+ model, _ = load_clip("RN50", jit=False)
319
+ self._clip_rn50 = build_model(model.state_dict())
320
+ self._clip_rn50 = self._clip_rn50.float().to(device)
321
+ self._clip_rn50.eval()
322
+ del model
323
+
324
+ self._voxelizer.to(device)
325
+ self._q.to(device)
326
+
327
+ def _extract_crop(self, pixel_action, observation):
328
+ # Pixel action will now be (B, 2)
329
+ # observation = stack_on_channel(observation)
330
+ h = observation.shape[-1]
331
+ top_left_corner = torch.clamp(
332
+ pixel_action - self._image_crop_size // 2, 0, h - self._image_crop_size
333
+ )
334
+ grid = self._grid_for_crop + top_left_corner.unsqueeze(1)
335
+ grid = ((grid / float(h)) * 2.0) - 1.0 # between -1 and 1
336
+ # Used for cropping the images across a batch
337
+ # swap fro y x, to x, y
338
+ grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1)
339
+ crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True)
340
+ return crop
341
+
342
+ def _preprocess_inputs(self, replay_sample):
343
+ obs = []
344
+ pcds = []
345
+ self._crop_summary = []
346
+ for n in self._camera_names:
347
+ rgb = replay_sample["%s_rgb" % n]
348
+ pcd = replay_sample["%s_point_cloud" % n]
349
+
350
+ obs.append([rgb, pcd])
351
+ pcds.append(pcd)
352
+ return obs, pcds
353
+
354
+ def _act_preprocess_inputs(self, observation):
355
+ obs, pcds = [], []
356
+ for n in self._camera_names:
357
+ rgb = observation["%s_rgb" % n]
358
+ pcd = observation["%s_point_cloud" % n]
359
+
360
+ obs.append([rgb, pcd])
361
+ pcds.append(pcd)
362
+ return obs, pcds
363
+
364
+ def _get_value_from_voxel_index(self, q, voxel_idx):
365
+ b, c, d, h, w = q.shape
366
+ q_trans_flat = q.view(b, c, d * h * w)
367
+ flat_indicies = (
368
+ voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2]
369
+ )[:, None].int()
370
+ highest_idxs = flat_indicies.unsqueeze(-1).repeat(1, c, 1)
371
+ chosen_voxel_values = q_trans_flat.gather(2, highest_idxs)[
372
+ ..., 0
373
+ ] # (B, trans + rot + grip)
374
+ return chosen_voxel_values
375
+
376
+ def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx):
377
+ q_rot = torch.stack(
378
+ torch.split(
379
+ rot_grip_q[:, :-2], int(360 // self._rotation_resolution), dim=1
380
+ ),
381
+ dim=1,
382
+ ) # B, 3, 72
383
+ q_grip = rot_grip_q[:, -2:]
384
+ rot_and_grip_values = torch.cat(
385
+ [
386
+ q_rot[:, 0].gather(1, rot_and_grip_idx[:, 0:1]),
387
+ q_rot[:, 1].gather(1, rot_and_grip_idx[:, 1:2]),
388
+ q_rot[:, 2].gather(1, rot_and_grip_idx[:, 2:3]),
389
+ q_grip.gather(1, rot_and_grip_idx[:, 3:4]),
390
+ ],
391
+ -1,
392
+ )
393
+ return rot_and_grip_values
394
+
395
+ def _celoss(self, pred, labels):
396
+ return self._cross_entropy_loss(pred, labels.argmax(-1))
397
+
398
+ def _softmax_q_trans(self, q):
399
+ q_shape = q.shape
400
+ return F.softmax(q.reshape(q_shape[0], -1), dim=1).reshape(q_shape)
401
+
402
+ def _softmax_q_rot_grip(self, q_rot_grip):
403
+ q_rot_x_flat = q_rot_grip[
404
+ :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
405
+ ]
406
+ q_rot_y_flat = q_rot_grip[
407
+ :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
408
+ ]
409
+ q_rot_z_flat = q_rot_grip[
410
+ :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
411
+ ]
412
+ q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
413
+
414
+ q_rot_x_flat_softmax = F.softmax(q_rot_x_flat, dim=1)
415
+ q_rot_y_flat_softmax = F.softmax(q_rot_y_flat, dim=1)
416
+ q_rot_z_flat_softmax = F.softmax(q_rot_z_flat, dim=1)
417
+ q_grip_flat_softmax = F.softmax(q_grip_flat, dim=1)
418
+
419
+ return torch.cat(
420
+ [
421
+ q_rot_x_flat_softmax,
422
+ q_rot_y_flat_softmax,
423
+ q_rot_z_flat_softmax,
424
+ q_grip_flat_softmax,
425
+ ],
426
+ dim=1,
427
+ )
428
+
429
+ def _softmax_ignore_collision(self, q_collision):
430
+ q_collision_softmax = F.softmax(q_collision, dim=1)
431
+ return q_collision_softmax
432
+
433
+ def update(self, step: int, replay_sample: dict) -> dict:
434
+ action_trans = replay_sample["trans_action_indicies"][
435
+ :, self._layer * 3 : self._layer * 3 + 3
436
+ ].int()
437
+ action_rot_grip = replay_sample["rot_grip_action_indicies"].int()
438
+ action_gripper_pose = replay_sample["gripper_pose"]
439
+ action_ignore_collisions = replay_sample["ignore_collisions"].int()
440
+ lang_goal_emb = replay_sample["lang_goal_emb"].float()
441
+ lang_token_embs = replay_sample["lang_token_embs"].float()
442
+ prev_layer_voxel_grid = replay_sample.get("prev_layer_voxel_grid", None)
443
+ prev_layer_bounds = replay_sample.get("prev_layer_bounds", None)
444
+ device = self._device
445
+
446
+ bounds = self._coordinate_bounds.to(device)
447
+ if self._layer > 0:
448
+ cp = replay_sample["attention_coordinate_layer_%d" % (self._layer - 1)]
449
+ bounds = torch.cat(
450
+ [cp - self._bounds_offset, cp + self._bounds_offset], dim=1
451
+ )
452
+
453
+ proprio = None
454
+ if self._include_low_dim_state:
455
+ proprio = replay_sample["low_dim_state"]
456
+
457
+ obs, pcd = self._preprocess_inputs(replay_sample)
458
+
459
+ # batch size
460
+ bs = pcd[0].shape[0]
461
+
462
+ # SE(3) augmentation of point clouds and actions
463
+ if self._transform_augmentation:
464
+ action_trans, action_rot_grip, pcd = apply_se3_augmentation(
465
+ pcd,
466
+ action_gripper_pose,
467
+ action_trans,
468
+ action_rot_grip,
469
+ bounds,
470
+ self._layer,
471
+ self._transform_augmentation_xyz,
472
+ self._transform_augmentation_rpy,
473
+ self._transform_augmentation_rot_resolution,
474
+ self._voxel_size,
475
+ self._rotation_resolution,
476
+ self._device,
477
+ )
478
+
479
+ # forward pass
480
+ q_trans, q_rot_grip, q_collision, voxel_grid = self._q(
481
+ obs,
482
+ proprio,
483
+ pcd,
484
+ lang_goal_emb,
485
+ lang_token_embs,
486
+ bounds,
487
+ prev_layer_bounds,
488
+ prev_layer_voxel_grid,
489
+ )
490
+
491
+ # argmax to choose best action
492
+ (
493
+ coords,
494
+ rot_and_grip_indicies,
495
+ ignore_collision_indicies,
496
+ ) = self._q.choose_highest_action(q_trans, q_rot_grip, q_collision)
497
+
498
+ q_trans_loss, q_rot_loss, q_grip_loss, q_collision_loss = 0.0, 0.0, 0.0, 0.0
499
+
500
+ # translation one-hot
501
+ action_trans_one_hot = self._action_trans_one_hot_zeros.clone()
502
+ for b in range(bs):
503
+ gt_coord = action_trans[b, :].int()
504
+ action_trans_one_hot[b, :, gt_coord[0], gt_coord[1], gt_coord[2]] = 1
505
+
506
+ # translation loss
507
+ q_trans_flat = q_trans.view(bs, -1)
508
+ action_trans_one_hot_flat = action_trans_one_hot.view(bs, -1)
509
+ q_trans_loss = self._celoss(q_trans_flat, action_trans_one_hot_flat)
510
+
511
+ with_rot_and_grip = rot_and_grip_indicies is not None
512
+ if with_rot_and_grip:
513
+ # rotation, gripper, and collision one-hots
514
+ action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone()
515
+ action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone()
516
+ action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone()
517
+ action_grip_one_hot = self._action_grip_one_hot_zeros.clone()
518
+ action_ignore_collisions_one_hot = (
519
+ self._action_ignore_collisions_one_hot_zeros.clone()
520
+ )
521
+
522
+ for b in range(bs):
523
+ gt_rot_grip = action_rot_grip[b, :].int()
524
+ action_rot_x_one_hot[b, gt_rot_grip[0]] = 1
525
+ action_rot_y_one_hot[b, gt_rot_grip[1]] = 1
526
+ action_rot_z_one_hot[b, gt_rot_grip[2]] = 1
527
+ action_grip_one_hot[b, gt_rot_grip[3]] = 1
528
+
529
+ gt_ignore_collisions = action_ignore_collisions[b, :].int()
530
+ action_ignore_collisions_one_hot[b, gt_ignore_collisions[0]] = 1
531
+
532
+ # flatten predictions
533
+ q_rot_x_flat = q_rot_grip[
534
+ :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
535
+ ]
536
+ q_rot_y_flat = q_rot_grip[
537
+ :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
538
+ ]
539
+ q_rot_z_flat = q_rot_grip[
540
+ :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
541
+ ]
542
+ q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
543
+ q_ignore_collisions_flat = q_collision
544
+
545
+ # rotation loss
546
+ q_rot_loss += self._celoss(q_rot_x_flat, action_rot_x_one_hot)
547
+ q_rot_loss += self._celoss(q_rot_y_flat, action_rot_y_one_hot)
548
+ q_rot_loss += self._celoss(q_rot_z_flat, action_rot_z_one_hot)
549
+
550
+ # gripper loss
551
+ q_grip_loss += self._celoss(q_grip_flat, action_grip_one_hot)
552
+
553
+ # collision loss
554
+ q_collision_loss += self._celoss(
555
+ q_ignore_collisions_flat, action_ignore_collisions_one_hot
556
+ )
557
+
558
+ combined_losses = (
559
+ (q_trans_loss * self._trans_loss_weight)
560
+ + (q_rot_loss * self._rot_loss_weight)
561
+ + (q_grip_loss * self._grip_loss_weight)
562
+ + (q_collision_loss * self._collision_loss_weight)
563
+ )
564
+ total_loss = combined_losses.mean()
565
+
566
+ self._optimizer.zero_grad()
567
+ total_loss.backward()
568
+ self._optimizer.step()
569
+
570
+ self._summaries = {
571
+ "losses/total_loss": total_loss,
572
+ "losses/trans_loss": q_trans_loss.mean(),
573
+ "losses/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
574
+ "losses/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
575
+ "losses/collision_loss": q_collision_loss.mean()
576
+ if with_rot_and_grip
577
+ else 0.0,
578
+ }
579
+
580
+ if self._lr_scheduler:
581
+ self._scheduler.step()
582
+ self._summaries["learning_rate"] = self._scheduler.get_last_lr()[0]
583
+
584
+ self._vis_voxel_grid = voxel_grid[0]
585
+ self._vis_translation_qvalue = self._softmax_q_trans(q_trans[0])
586
+ self._vis_max_coordinate = coords[0]
587
+ self._vis_gt_coordinate = action_trans[0]
588
+
589
+ # Note: PerAct doesn't use multi-layer voxel grids like C2FARM
590
+ # stack prev_layer_voxel_grid(s) from previous layers into a list
591
+ if prev_layer_voxel_grid is None:
592
+ prev_layer_voxel_grid = [voxel_grid]
593
+ else:
594
+ prev_layer_voxel_grid = prev_layer_voxel_grid + [voxel_grid]
595
+
596
+ # stack prev_layer_bound(s) from previous layers into a list
597
+ if prev_layer_bounds is None:
598
+ prev_layer_bounds = [self._coordinate_bounds.repeat(bs, 1)]
599
+ else:
600
+ prev_layer_bounds = prev_layer_bounds + [bounds]
601
+
602
+ return {
603
+ "total_loss": total_loss,
604
+ "prev_layer_voxel_grid": prev_layer_voxel_grid,
605
+ "prev_layer_bounds": prev_layer_bounds,
606
+ }
607
+
608
+ def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
609
+ deterministic = True
610
+ bounds = self._coordinate_bounds
611
+ prev_layer_voxel_grid = observation.get("prev_layer_voxel_grid", None)
612
+ prev_layer_bounds = observation.get("prev_layer_bounds", None)
613
+ lang_goal_tokens = observation.get("lang_goal_tokens", None).long()
614
+
615
+ # extract CLIP language embs
616
+ with torch.no_grad():
617
+ lang_goal_tokens = lang_goal_tokens.to(device=self._device)
618
+ (
619
+ lang_goal_emb,
620
+ lang_token_embs,
621
+ ) = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0])
622
+
623
+ # voxelization resolution
624
+ res = (bounds[:, 3:] - bounds[:, :3]) / self._voxel_size
625
+ max_rot_index = int(360 // self._rotation_resolution)
626
+ proprio = None
627
+
628
+ if self._include_low_dim_state:
629
+ proprio = observation["low_dim_state"]
630
+ proprio = proprio[0].to(self._device)
631
+
632
+ obs, pcd = self._act_preprocess_inputs(observation)
633
+
634
+ # correct batch size and device
635
+ obs = [[o[0][0].to(self._device), o[1][0].to(self._device)] for o in obs]
636
+ pcd = [p[0].to(self._device) for p in pcd]
637
+ lang_goal_emb = lang_goal_emb.to(self._device)
638
+ lang_token_embs = lang_token_embs.to(self._device)
639
+ bounds = torch.as_tensor(bounds, device=self._device)
640
+ prev_layer_voxel_grid = (
641
+ prev_layer_voxel_grid.to(self._device)
642
+ if prev_layer_voxel_grid is not None
643
+ else None
644
+ )
645
+ prev_layer_bounds = (
646
+ prev_layer_bounds.to(self._device)
647
+ if prev_layer_bounds is not None
648
+ else None
649
+ )
650
+
651
+ # inference
652
+ q_trans, q_rot_grip, q_ignore_collisions, vox_grid = self._q(
653
+ obs,
654
+ proprio,
655
+ pcd,
656
+ lang_goal_emb,
657
+ lang_token_embs,
658
+ bounds,
659
+ prev_layer_bounds,
660
+ prev_layer_voxel_grid,
661
+ )
662
+
663
+ # softmax Q predictions
664
+ q_trans = self._softmax_q_trans(q_trans)
665
+ q_rot_grip = (
666
+ self._softmax_q_rot_grip(q_rot_grip)
667
+ if q_rot_grip is not None
668
+ else q_rot_grip
669
+ )
670
+ q_ignore_collisions = (
671
+ self._softmax_ignore_collision(q_ignore_collisions)
672
+ if q_ignore_collisions is not None
673
+ else q_ignore_collisions
674
+ )
675
+
676
+ # argmax Q predictions
677
+ (
678
+ coords,
679
+ rot_and_grip_indicies,
680
+ ignore_collisions,
681
+ ) = self._q.choose_highest_action(q_trans, q_rot_grip, q_ignore_collisions)
682
+
683
+ rot_grip_action = rot_and_grip_indicies if q_rot_grip is not None else None
684
+ ignore_collisions_action = (
685
+ ignore_collisions.int() if ignore_collisions is not None else None
686
+ )
687
+
688
+ coords = coords.int()
689
+ attention_coordinate = bounds[:, :3] + res * coords + res / 2
690
+
691
+ # stack prev_layer_voxel_grid(s) into a list
692
+ # NOTE: PerAct doesn't used multi-layer voxel grids like C2FARM
693
+ if prev_layer_voxel_grid is None:
694
+ prev_layer_voxel_grid = [vox_grid]
695
+ else:
696
+ prev_layer_voxel_grid = prev_layer_voxel_grid + [vox_grid]
697
+
698
+ if prev_layer_bounds is None:
699
+ prev_layer_bounds = [bounds]
700
+ else:
701
+ prev_layer_bounds = prev_layer_bounds + [bounds]
702
+
703
+ observation_elements = {
704
+ "attention_coordinate": attention_coordinate,
705
+ "prev_layer_voxel_grid": prev_layer_voxel_grid,
706
+ "prev_layer_bounds": prev_layer_bounds,
707
+ }
708
+ info = {
709
+ "voxel_grid_depth%d" % self._layer: vox_grid,
710
+ "q_depth%d" % self._layer: q_trans,
711
+ "voxel_idx_depth%d" % self._layer: coords,
712
+ }
713
+ self._act_voxel_grid = vox_grid[0]
714
+ self._act_max_coordinate = coords[0]
715
+ self._act_qvalues = q_trans[0].detach()
716
+ return ActResult(
717
+ (coords, rot_grip_action, ignore_collisions_action),
718
+ observation_elements=observation_elements,
719
+ info=info,
720
+ )
721
+
722
+ def update_summaries(self) -> List[Summary]:
723
+ summaries = [
724
+ ImageSummary(
725
+ "%s/update_qattention" % self._name,
726
+ transforms.ToTensor()(
727
+ visualise_voxel(
728
+ self._vis_voxel_grid.detach().cpu().numpy(),
729
+ self._vis_translation_qvalue.detach().cpu().numpy(),
730
+ self._vis_max_coordinate.detach().cpu().numpy(),
731
+ self._vis_gt_coordinate.detach().cpu().numpy(),
732
+ )
733
+ ),
734
+ )
735
+ ]
736
+
737
+ for n, v in self._summaries.items():
738
+ summaries.append(ScalarSummary("%s/%s" % (self._name, n), v))
739
+
740
+ for name, crop in self._crop_summary:
741
+ crops = (torch.cat(torch.split(crop, 3, dim=1), dim=3) + 1.0) / 2.0
742
+ summaries.extend([ImageSummary("%s/crops/%s" % (self._name, name), crops)])
743
+
744
+ for tag, param in self._q.named_parameters():
745
+ # assert not torch.isnan(param.grad.abs() <= 1.0).all()
746
+ summaries.append(
747
+ HistogramSummary("%s/gradient/%s" % (self._name, tag), param.grad)
748
+ )
749
+ summaries.append(
750
+ HistogramSummary("%s/weight/%s" % (self._name, tag), param.data)
751
+ )
752
+
753
+ return summaries
754
+
755
+ def act_summaries(self) -> List[Summary]:
756
+ return [
757
+ ImageSummary(
758
+ "%s/act_Qattention" % self._name,
759
+ transforms.ToTensor()(
760
+ visualise_voxel(
761
+ self._act_voxel_grid.cpu().numpy(),
762
+ self._act_qvalues.cpu().numpy(),
763
+ self._act_max_coordinate.cpu().numpy(),
764
+ )
765
+ ),
766
+ )
767
+ ]
768
+
769
+ def load_weights(self, savedir: str):
770
+ device = (
771
+ self._device
772
+ if not self._training
773
+ else torch.device("cuda:%d" % self._device)
774
+ )
775
+ weight_file = os.path.join(savedir, "%s.pt" % self._name)
776
+ state_dict = torch.load(weight_file, map_location=device)
777
+
778
+ # load only keys that are in the current model
779
+ merged_state_dict = self._q.state_dict()
780
+ for k, v in state_dict.items():
781
+ if not self._training:
782
+ k = k.replace("_qnet.module", "_qnet")
783
+ if k in merged_state_dict:
784
+ merged_state_dict[k] = v
785
+ else:
786
+ if "_voxelizer" not in k:
787
+ logging.warning("key %s not found in checkpoint" % k)
788
+ if not self._training:
789
+ # reshape voxelizer weights
790
+ b = merged_state_dict["_voxelizer._ones_max_coords"].shape[0]
791
+ merged_state_dict["_voxelizer._ones_max_coords"] = merged_state_dict[
792
+ "_voxelizer._ones_max_coords"
793
+ ][0:1]
794
+ flat_shape = merged_state_dict["_voxelizer._flat_output"].shape[0]
795
+ merged_state_dict["_voxelizer._flat_output"] = merged_state_dict[
796
+ "_voxelizer._flat_output"
797
+ ][0 : flat_shape // b]
798
+ merged_state_dict["_voxelizer._tiled_batch_indices"] = merged_state_dict[
799
+ "_voxelizer._tiled_batch_indices"
800
+ ][0:1]
801
+ merged_state_dict["_voxelizer._index_grid"] = merged_state_dict[
802
+ "_voxelizer._index_grid"
803
+ ][0:1]
804
+ self._q.load_state_dict(merged_state_dict)
805
+ print("loaded weights from %s" % weight_file)
806
+
807
+ def save_weights(self, savedir: str):
808
+ torch.save(self._q.state_dict(), os.path.join(savedir, "%s.pt" % self._name))
external/peract_bimanual/agents/peract_bc/qattention_stack_agent.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import torch
4
+ from yarr.agents.agent import Agent, ActResult, Summary
5
+
6
+ import numpy as np
7
+
8
+ from helpers import utils
9
+ from agents.peract_bc.qattention_peract_bc_agent import QAttentionPerActBCAgent
10
+
11
+ NAME = "QAttentionStackAgent"
12
+
13
+
14
+ class QAttentionStackAgent(Agent):
15
+ def __init__(
16
+ self,
17
+ qattention_agents: List[QAttentionPerActBCAgent],
18
+ rotation_resolution: float,
19
+ camera_names: List[str],
20
+ rotation_prediction_depth: int = 0,
21
+ ):
22
+ super(QAttentionStackAgent, self).__init__()
23
+ self._qattention_agents = qattention_agents
24
+ self._rotation_resolution = rotation_resolution
25
+ self._camera_names = camera_names
26
+ self._rotation_prediction_depth = rotation_prediction_depth
27
+
28
+ def build(self, training: bool, device=None) -> None:
29
+ self._device = device
30
+ if self._device is None:
31
+ self._device = torch.device("cpu")
32
+ for qa in self._qattention_agents:
33
+ qa.build(training, device)
34
+
35
+ def update(self, step: int, replay_sample: dict) -> dict:
36
+ priorities = 0
37
+ total_losses = 0.0
38
+ for qa in self._qattention_agents:
39
+ update_dict = qa.update(step, replay_sample)
40
+ replay_sample.update(update_dict)
41
+ total_losses += update_dict["total_loss"]
42
+ return {
43
+ "total_losses": total_losses,
44
+ }
45
+
46
+ def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
47
+ observation_elements = {}
48
+ translation_results, rot_grip_results, ignore_collisions_results = [], [], []
49
+ infos = {}
50
+ for depth, qagent in enumerate(self._qattention_agents):
51
+ act_results = qagent.act(step, observation, deterministic)
52
+ attention_coordinate = (
53
+ act_results.observation_elements["attention_coordinate"].cpu().numpy()
54
+ )
55
+ observation_elements[
56
+ "attention_coordinate_layer_%d" % depth
57
+ ] = attention_coordinate[0]
58
+
59
+ translation_idxs, rot_grip_idxs, ignore_collisions_idxs = act_results.action
60
+ translation_results.append(translation_idxs)
61
+ if rot_grip_idxs is not None:
62
+ rot_grip_results.append(rot_grip_idxs)
63
+ if ignore_collisions_idxs is not None:
64
+ ignore_collisions_results.append(ignore_collisions_idxs)
65
+
66
+ observation["attention_coordinate"] = act_results.observation_elements[
67
+ "attention_coordinate"
68
+ ]
69
+ observation["prev_layer_voxel_grid"] = act_results.observation_elements[
70
+ "prev_layer_voxel_grid"
71
+ ]
72
+ observation["prev_layer_bounds"] = act_results.observation_elements[
73
+ "prev_layer_bounds"
74
+ ]
75
+
76
+ for n in self._camera_names:
77
+ px, py = utils.point_to_pixel_index(
78
+ attention_coordinate[0],
79
+ observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy(),
80
+ observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy(),
81
+ )
82
+ pc_t = torch.tensor(
83
+ [[[py, px]]], dtype=torch.float32, device=self._device
84
+ )
85
+ observation["%s_pixel_coord" % n] = pc_t
86
+ observation_elements["%s_pixel_coord" % n] = [py, px]
87
+
88
+ infos.update(act_results.info)
89
+
90
+ rgai = torch.cat(rot_grip_results, 1)[0].cpu().numpy()
91
+ ignore_collisions = float(
92
+ torch.cat(ignore_collisions_results, 1)[0].cpu().numpy()
93
+ )
94
+ observation_elements["trans_action_indicies"] = (
95
+ torch.cat(translation_results, 1)[0].cpu().numpy()
96
+ )
97
+ observation_elements["rot_grip_action_indicies"] = rgai
98
+ continuous_action = np.concatenate(
99
+ [
100
+ act_results.observation_elements["attention_coordinate"]
101
+ .cpu()
102
+ .numpy()[0],
103
+ utils.discrete_euler_to_quaternion(
104
+ rgai[-4:-1], self._rotation_resolution
105
+ ),
106
+ rgai[-1:],
107
+ [ignore_collisions],
108
+ ]
109
+ )
110
+ return ActResult(
111
+ continuous_action, observation_elements=observation_elements, info=infos
112
+ )
113
+
114
+ def update_summaries(self) -> List[Summary]:
115
+ summaries = []
116
+ for qa in self._qattention_agents:
117
+ summaries.extend(qa.update_summaries())
118
+ return summaries
119
+
120
+ def act_summaries(self) -> List[Summary]:
121
+ s = []
122
+ for qa in self._qattention_agents:
123
+ s.extend(qa.act_summaries())
124
+ return s
125
+
126
+ def load_weights(self, savedir: str):
127
+ for qa in self._qattention_agents:
128
+ qa.load_weights(savedir)
129
+
130
+ def save_weights(self, savedir: str):
131
+ for qa in self._qattention_agents:
132
+ qa.save_weights(savedir)
external/peract_bimanual/agents/replay_utils.py ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ from rlbench.backend.observation import Observation
6
+ from rlbench.observation_config import ObservationConfig
7
+ import rlbench.utils as rlbench_utils
8
+ from rlbench.demo import Demo
9
+ from yarr.replay_buffer.replay_buffer import ReplayBuffer
10
+
11
+ from helpers import demo_loading_utils, utils
12
+ from helpers import observation_utils
13
+ from helpers.clip.core.clip import tokenize
14
+
15
+
16
+ from yarr.replay_buffer.prioritized_replay_buffer import ObservationElement
17
+ from yarr.replay_buffer.replay_buffer import ReplayElement
18
+ from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer
19
+
20
+
21
+ import torch
22
+ from torch.multiprocessing import Process, Value, Manager
23
+ from helpers.clip.core.clip import build_model, load_clip
24
+ from omegaconf import DictConfig
25
+
26
+
27
+ REWARD_SCALE = 100.0
28
+ LOW_DIM_SIZE = 4
29
+
30
+
31
+ def create_replay(cfg, replay_path):
32
+ if cfg.method.robot_name == "bimanual":
33
+ return create_bimanual_replay(
34
+ cfg.replay.batch_size,
35
+ cfg.replay.timesteps,
36
+ cfg.replay.prioritisation,
37
+ cfg.replay.task_uniform,
38
+ replay_path if cfg.replay.use_disk else None,
39
+ cfg.rlbench.cameras,
40
+ cfg.method.voxel_sizes,
41
+ cfg.rlbench.camera_resolution,
42
+ )
43
+ else:
44
+ return create_unimanual_replay(
45
+ cfg.replay.batch_size,
46
+ cfg.replay.timesteps,
47
+ cfg.replay.prioritisation,
48
+ cfg.replay.task_uniform,
49
+ replay_path if cfg.replay.use_disk else None,
50
+ cfg.rlbench.cameras,
51
+ cfg.method.voxel_sizes,
52
+ cfg.rlbench.camera_resolution,
53
+ )
54
+
55
+
56
+ def create_bimanual_replay(
57
+ batch_size: int,
58
+ timesteps: int,
59
+ prioritisation: bool,
60
+ task_uniform: bool,
61
+ save_dir: str,
62
+ cameras: list,
63
+ voxel_sizes,
64
+ image_size=[128, 128],
65
+ replay_size=3e5,
66
+ ):
67
+ trans_indicies_size = 3 * len(voxel_sizes)
68
+ rot_and_grip_indicies_size = 3 + 1
69
+ gripper_pose_size = 7
70
+ ignore_collisions_size = 1
71
+ max_token_seq_len = 77
72
+ lang_feat_dim = 1024
73
+ lang_emb_dim = 512
74
+
75
+ # low_dim_state
76
+ observation_elements = []
77
+ observation_elements.append(
78
+ ObservationElement("right_low_dim_state", (LOW_DIM_SIZE,), np.float32)
79
+ )
80
+ observation_elements.append(
81
+ ObservationElement("left_low_dim_state", (LOW_DIM_SIZE,), np.float32)
82
+ )
83
+
84
+ # rgb, depth, point cloud, intrinsics, extrinsics
85
+ for cname in cameras:
86
+ observation_elements.append(
87
+ # color, height, width
88
+ ObservationElement(
89
+ "%s_rgb" % cname,
90
+ (
91
+ 3,
92
+ image_size[1],
93
+ image_size[0],
94
+ ),
95
+ np.float32,
96
+ )
97
+ )
98
+ observation_elements.append(
99
+ ObservationElement(
100
+ "%s_point_cloud" % cname, (3, image_size[1], image_size[0]), np.float16
101
+ )
102
+ ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
103
+ observation_elements.append(
104
+ ObservationElement(
105
+ "%s_camera_extrinsics" % cname,
106
+ (
107
+ 4,
108
+ 4,
109
+ ),
110
+ np.float32,
111
+ )
112
+ )
113
+ observation_elements.append(
114
+ ObservationElement(
115
+ "%s_camera_intrinsics" % cname,
116
+ (
117
+ 3,
118
+ 3,
119
+ ),
120
+ np.float32,
121
+ )
122
+ )
123
+
124
+ # discretized translation, discretized rotation, discrete ignore collision, 6-DoF gripper pose, and pre-trained language embeddings
125
+ for robot_name in ["right", "left"]:
126
+ observation_elements.extend(
127
+ [
128
+ ReplayElement(
129
+ f"{robot_name}_trans_action_indicies",
130
+ (trans_indicies_size,),
131
+ np.int32,
132
+ ),
133
+ ReplayElement(
134
+ f"{robot_name}_rot_grip_action_indicies",
135
+ (rot_and_grip_indicies_size,),
136
+ np.int32,
137
+ ),
138
+ ReplayElement(
139
+ f"{robot_name}_ignore_collisions",
140
+ (ignore_collisions_size,),
141
+ np.int32,
142
+ ),
143
+ ReplayElement(
144
+ f"{robot_name}_gripper_pose", (gripper_pose_size,), np.float32
145
+ ),
146
+ ]
147
+ )
148
+
149
+ observation_elements.extend(
150
+ [
151
+ ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
152
+ ReplayElement(
153
+ "lang_token_embs",
154
+ (
155
+ max_token_seq_len,
156
+ lang_emb_dim,
157
+ ),
158
+ np.float32,
159
+ ), # extracted from CLIP's language encoder
160
+ ReplayElement("task", (), str),
161
+ ReplayElement(
162
+ "lang_goal", (1,), object
163
+ ), # language goal string for debugging and visualization
164
+ ]
165
+ )
166
+
167
+ extra_replay_elements = [
168
+ ReplayElement("demo", (), bool),
169
+ ]
170
+
171
+ replay_buffer = TaskUniformReplayBuffer(
172
+ save_dir=save_dir,
173
+ batch_size=batch_size,
174
+ timesteps=timesteps,
175
+ replay_capacity=int(replay_size),
176
+ action_shape=(8 * 2,),
177
+ action_dtype=np.float32,
178
+ reward_shape=(),
179
+ reward_dtype=np.float32,
180
+ update_horizon=1,
181
+ observation_elements=observation_elements,
182
+ extra_replay_elements=extra_replay_elements,
183
+ )
184
+ return replay_buffer
185
+
186
+
187
+ def create_unimanual_replay(
188
+ batch_size: int,
189
+ timesteps: int,
190
+ prioritisation: bool,
191
+ task_uniform: bool,
192
+ save_dir: str,
193
+ cameras: list,
194
+ voxel_sizes,
195
+ image_size=[128, 128],
196
+ replay_size=3e5,
197
+ ):
198
+ trans_indicies_size = 3 * len(voxel_sizes)
199
+ rot_and_grip_indicies_size = 3 + 1
200
+ gripper_pose_size = 7
201
+ ignore_collisions_size = 1
202
+ max_token_seq_len = 77
203
+ lang_feat_dim = 1024
204
+ lang_emb_dim = 512
205
+
206
+ # low_dim_state
207
+ observation_elements = []
208
+ observation_elements.append(
209
+ ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
210
+ )
211
+
212
+ # rgb, depth, point cloud, intrinsics, extrinsics
213
+ for cname in cameras:
214
+ observation_elements.append(
215
+ ObservationElement(
216
+ "%s_rgb" % cname,
217
+ (
218
+ 3,
219
+ *image_size,
220
+ ),
221
+ np.float32,
222
+ )
223
+ )
224
+ observation_elements.append(
225
+ ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32)
226
+ ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
227
+ observation_elements.append(
228
+ ObservationElement(
229
+ "%s_camera_extrinsics" % cname,
230
+ (
231
+ 4,
232
+ 4,
233
+ ),
234
+ np.float32,
235
+ )
236
+ )
237
+ observation_elements.append(
238
+ ObservationElement(
239
+ "%s_camera_intrinsics" % cname,
240
+ (
241
+ 3,
242
+ 3,
243
+ ),
244
+ np.float32,
245
+ )
246
+ )
247
+
248
+ # discretized translation, discretized rotation, discrete ignore collision, 6-DoF gripper pose, and pre-trained language embeddings
249
+ observation_elements.extend(
250
+ [
251
+ ReplayElement("trans_action_indicies", (trans_indicies_size,), np.int32),
252
+ ReplayElement(
253
+ "rot_grip_action_indicies", (rot_and_grip_indicies_size,), np.int32
254
+ ),
255
+ ReplayElement("ignore_collisions", (ignore_collisions_size,), np.int32),
256
+ ReplayElement("gripper_pose", (gripper_pose_size,), np.float32),
257
+ ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
258
+ ReplayElement(
259
+ "lang_token_embs",
260
+ (
261
+ max_token_seq_len,
262
+ lang_emb_dim,
263
+ ),
264
+ np.float32,
265
+ ), # extracted from CLIP's language encoder
266
+ ReplayElement("task", (), str),
267
+ ReplayElement(
268
+ "lang_goal", (1,), object
269
+ ), # language goal string for debugging and visualization
270
+ ]
271
+ )
272
+
273
+ extra_replay_elements = [
274
+ ReplayElement("demo", (), bool),
275
+ ]
276
+
277
+ replay_buffer = TaskUniformReplayBuffer(
278
+ save_dir=save_dir,
279
+ batch_size=batch_size,
280
+ timesteps=timesteps,
281
+ replay_capacity=int(replay_size),
282
+ action_shape=(8,),
283
+ action_dtype=np.float32,
284
+ reward_shape=(),
285
+ reward_dtype=np.float32,
286
+ update_horizon=1,
287
+ observation_elements=observation_elements,
288
+ extra_replay_elements=extra_replay_elements,
289
+ )
290
+ return replay_buffer
291
+
292
+
293
+ def _get_action(
294
+ obs_tp1: Observation,
295
+ obs_tm1: Observation,
296
+ rlbench_scene_bounds: List[float], # metric 3D bounds of the scene
297
+ voxel_sizes: List[int],
298
+ bounds_offset: List[float],
299
+ rotation_resolution: int,
300
+ crop_augmentation: bool,
301
+ ):
302
+ quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
303
+ if quat[-1] < 0:
304
+ quat = -quat
305
+ disc_rot = utils.quaternion_to_discrete_euler(quat, rotation_resolution)
306
+ disc_rot = utils.correct_rotation_instability(disc_rot, rotation_resolution)
307
+
308
+ attention_coordinate = obs_tp1.gripper_pose[:3]
309
+ trans_indicies, attention_coordinates = [], []
310
+ bounds = np.array(rlbench_scene_bounds)
311
+ ignore_collisions = int(obs_tm1.ignore_collisions)
312
+ for depth, vox_size in enumerate(
313
+ voxel_sizes
314
+ ): # only single voxelization-level is used in PerAct
315
+ if depth > 0:
316
+ if crop_augmentation:
317
+ shift = bounds_offset[depth - 1] * 0.75
318
+ attention_coordinate += np.random.uniform(-shift, shift, size=(3,))
319
+ bounds = np.concatenate(
320
+ [
321
+ attention_coordinate - bounds_offset[depth - 1],
322
+ attention_coordinate + bounds_offset[depth - 1],
323
+ ]
324
+ )
325
+ index = utils.point_to_voxel_index(obs_tp1.gripper_pose[:3], vox_size, bounds)
326
+ trans_indicies.extend(index.tolist())
327
+ res = (bounds[3:] - bounds[:3]) / vox_size
328
+ attention_coordinate = bounds[:3] + res * index
329
+ attention_coordinates.append(attention_coordinate)
330
+
331
+ rot_and_grip_indicies = disc_rot.tolist()
332
+ grip = float(obs_tp1.gripper_open)
333
+ rot_and_grip_indicies.extend([int(obs_tp1.gripper_open)])
334
+ return (
335
+ trans_indicies,
336
+ rot_and_grip_indicies,
337
+ ignore_collisions,
338
+ np.concatenate([obs_tp1.gripper_pose, np.array([grip])]),
339
+ attention_coordinates,
340
+ )
341
+
342
+
343
+ def _add_keypoints_to_replay(
344
+ cfg: DictConfig,
345
+ task: str,
346
+ replay: ReplayBuffer,
347
+ inital_obs: Observation,
348
+ demo: Demo,
349
+ episode_keypoints: List[int],
350
+ description: str = "",
351
+ clip_model=None,
352
+ device="cpu",
353
+ ):
354
+ cameras = cfg.rlbench.cameras
355
+ rlbench_scene_bounds = cfg.rlbench.scene_bounds
356
+ voxel_sizes = cfg.method.voxel_sizes
357
+ bounds_offset = cfg.method.bounds_offset
358
+ rotation_resolution = cfg.method.rotation_resolution
359
+ crop_augmentation = cfg.method.crop_augmentation
360
+ robot_name = cfg.method.robot_name
361
+
362
+ prev_action = None
363
+ obs = inital_obs
364
+
365
+ for k, keypoint in enumerate(episode_keypoints):
366
+ obs_tp1 = demo[keypoint]
367
+ obs_tm1 = demo[max(0, keypoint - 1)]
368
+
369
+ if obs_tp1.is_bimanual and robot_name == "bimanual":
370
+ # assert isinstance(obs_tp1, BimanualObservation)
371
+ (
372
+ right_trans_indicies,
373
+ right_rot_grip_indicies,
374
+ right_ignore_collisions,
375
+ right_action,
376
+ right_attention_coordinates,
377
+ ) = _get_action(
378
+ obs_tp1.right,
379
+ obs_tm1.right,
380
+ rlbench_scene_bounds,
381
+ voxel_sizes,
382
+ bounds_offset,
383
+ rotation_resolution,
384
+ crop_augmentation,
385
+ )
386
+
387
+ (
388
+ left_trans_indicies,
389
+ left_rot_grip_indicies,
390
+ left_ignore_collisions,
391
+ left_action,
392
+ left_attention_coordinates,
393
+ ) = _get_action(
394
+ obs_tp1.left,
395
+ obs_tm1.left,
396
+ rlbench_scene_bounds,
397
+ voxel_sizes,
398
+ bounds_offset,
399
+ rotation_resolution,
400
+ crop_augmentation,
401
+ )
402
+
403
+ action = np.append(right_action, left_action)
404
+
405
+ right_ignore_collisions = np.array([right_ignore_collisions])
406
+ left_ignore_collisions = np.array([left_ignore_collisions])
407
+
408
+ elif robot_name == "unimanual":
409
+ (
410
+ trans_indicies,
411
+ rot_grip_indicies,
412
+ ignore_collisions,
413
+ action,
414
+ attention_coordinates,
415
+ ) = _get_action(
416
+ obs_tp1,
417
+ obs_tm1,
418
+ rlbench_scene_bounds,
419
+ voxel_sizes,
420
+ bounds_offset,
421
+ rotation_resolution,
422
+ crop_augmentation,
423
+ )
424
+ gripper_pose = obs_tp1.gripper_pose
425
+ elif obs_tp1.is_bimanual and robot_name == "right":
426
+ (
427
+ trans_indicies,
428
+ rot_grip_indicies,
429
+ ignore_collisions,
430
+ action,
431
+ attention_coordinates,
432
+ ) = _get_action(
433
+ obs_tp1.right,
434
+ obs_tm1.right,
435
+ rlbench_scene_bounds,
436
+ voxel_sizes,
437
+ bounds_offset,
438
+ rotation_resolution,
439
+ crop_augmentation,
440
+ )
441
+ gripper_pose = obs_tp1.right.gripper_pose
442
+ elif obs_tp1.is_bimanual and robot_name == "left":
443
+ (
444
+ trans_indicies,
445
+ rot_grip_indicies,
446
+ ignore_collisions,
447
+ action,
448
+ attention_coordinates,
449
+ ) = _get_action(
450
+ obs_tp1.left,
451
+ obs_tm1.left,
452
+ rlbench_scene_bounds,
453
+ voxel_sizes,
454
+ bounds_offset,
455
+ rotation_resolution,
456
+ crop_augmentation,
457
+ )
458
+ gripper_pose = obs_tp1.left.gripper_pose
459
+ else:
460
+ logging.error("Invalid robot name %s", cfg.method.robot_name)
461
+ raise Exception("Invalid robot name.")
462
+
463
+ terminal = k == len(episode_keypoints) - 1
464
+ reward = float(terminal) * REWARD_SCALE if terminal else 0
465
+
466
+ obs_dict = observation_utils.extract_obs(
467
+ obs,
468
+ t=k,
469
+ prev_action=prev_action,
470
+ cameras=cameras,
471
+ episode_length=cfg.rlbench.episode_length,
472
+ robot_name=robot_name,
473
+ )
474
+ tokens = tokenize([description]).numpy()
475
+ token_tensor = torch.from_numpy(tokens).to(device)
476
+ sentence_emb, token_embs = clip_model.encode_text_with_embeddings(token_tensor)
477
+ obs_dict["lang_goal_emb"] = sentence_emb[0].float().detach().cpu().numpy()
478
+ obs_dict["lang_token_embs"] = token_embs[0].float().detach().cpu().numpy()
479
+
480
+ prev_action = np.copy(action)
481
+
482
+ others = {"demo": True}
483
+ if robot_name == "bimanual":
484
+ final_obs = {
485
+ "right_trans_action_indicies": right_trans_indicies,
486
+ "right_rot_grip_action_indicies": right_rot_grip_indicies,
487
+ "right_gripper_pose": obs_tp1.right.gripper_pose,
488
+ "left_trans_action_indicies": left_trans_indicies,
489
+ "left_rot_grip_action_indicies": left_rot_grip_indicies,
490
+ "left_gripper_pose": obs_tp1.left.gripper_pose,
491
+ "task": task,
492
+ "lang_goal": np.array([description], dtype=object),
493
+ }
494
+ else:
495
+ final_obs = {
496
+ "trans_action_indicies": trans_indicies,
497
+ "rot_grip_action_indicies": rot_grip_indicies,
498
+ "gripper_pose": gripper_pose,
499
+ "task": task,
500
+ "lang_goal": np.array([description], dtype=object),
501
+ }
502
+
503
+ others.update(final_obs)
504
+ others.update(obs_dict)
505
+
506
+ timeout = False
507
+ replay.add(action, reward, terminal, timeout, **others)
508
+ obs = obs_tp1
509
+
510
+ # final step
511
+ obs_dict_tp1 = observation_utils.extract_obs(
512
+ obs_tp1,
513
+ t=k + 1,
514
+ prev_action=prev_action,
515
+ cameras=cameras,
516
+ episode_length=cfg.rlbench.episode_length,
517
+ robot_name=cfg.method.robot_name,
518
+ )
519
+ obs_dict_tp1["lang_goal_emb"] = sentence_emb[0].float().detach().cpu().numpy()
520
+ obs_dict_tp1["lang_token_embs"] = token_embs[0].float().detach().cpu().numpy()
521
+
522
+ obs_dict_tp1.pop("wrist_world_to_cam", None)
523
+ obs_dict_tp1.update(final_obs)
524
+ replay.add_final(**obs_dict_tp1)
525
+
526
+
527
+ def fill_replay(
528
+ cfg: DictConfig,
529
+ obs_config: ObservationConfig,
530
+ rank: int,
531
+ replay: ReplayBuffer,
532
+ task: str,
533
+ clip_model=None,
534
+ device="cpu",
535
+ ):
536
+ num_demos = cfg.rlbench.demos
537
+ demo_augmentation = cfg.method.demo_augmentation
538
+ demo_augmentation_every_n = cfg.method.demo_augmentation_every_n
539
+ keypoint_method = cfg.method.keypoint_method
540
+
541
+ if clip_model is None:
542
+ model, _ = load_clip("RN50", jit=False, device=device)
543
+ clip_model = build_model(model.state_dict())
544
+ clip_model.to(device)
545
+ del model
546
+
547
+ logging.debug("Filling %s replay ..." % task)
548
+ for d_idx in range(num_demos):
549
+ # load demo from disk
550
+ demo = rlbench_utils.get_stored_demos(
551
+ amount=1,
552
+ image_paths=False,
553
+ dataset_root=cfg.rlbench.demo_path,
554
+ variation_number=-1,
555
+ task_name=task,
556
+ obs_config=obs_config,
557
+ random_selection=False,
558
+ from_episode_number=d_idx,
559
+ )[0]
560
+
561
+ descs = demo._observations[0].misc["descriptions"]
562
+
563
+ # extract keypoints (a.k.a keyframes)
564
+ episode_keypoints = demo_loading_utils.keypoint_discovery(
565
+ demo, method=keypoint_method
566
+ )
567
+
568
+ if rank == 0:
569
+ logging.info(
570
+ f"Loading Demo({d_idx}) - found {len(episode_keypoints)} keypoints - {task}"
571
+ )
572
+
573
+ for i in range(len(demo) - 1):
574
+ if not demo_augmentation and i > 0:
575
+ break
576
+ if i % demo_augmentation_every_n != 0:
577
+ continue
578
+
579
+ obs = demo[i]
580
+ desc = descs[0]
581
+ # if our starting point is past one of the keypoints, then remove it
582
+ while len(episode_keypoints) > 0 and i >= episode_keypoints[0]:
583
+ episode_keypoints = episode_keypoints[1:]
584
+ if len(episode_keypoints) == 0:
585
+ break
586
+ _add_keypoints_to_replay(
587
+ cfg,
588
+ task,
589
+ replay,
590
+ obs,
591
+ demo,
592
+ episode_keypoints,
593
+ description=desc,
594
+ clip_model=clip_model,
595
+ device=device,
596
+ )
597
+ logging.debug("Replay %s filled with demos." % task)
598
+
599
+
600
+ def fill_multi_task_replay(
601
+ cfg: DictConfig,
602
+ obs_config: ObservationConfig,
603
+ rank: int,
604
+ replay: ReplayBuffer,
605
+ tasks: List[str],
606
+ clip_model=None,
607
+ ):
608
+ tasks = cfg.rlbench.tasks
609
+
610
+ manager = Manager()
611
+ store = manager.dict()
612
+
613
+ # create a MP dict for storing indicies
614
+ # TODO(mohit): this shouldn't be initialized here
615
+ del replay._task_idxs
616
+ task_idxs = manager.dict()
617
+ replay._task_idxs = task_idxs
618
+ replay._create_storage(store)
619
+ replay.add_count = Value("i", 0)
620
+
621
+ # fill replay buffer in parallel across tasks
622
+ max_parallel_processes = cfg.replay.max_parallel_processes
623
+ processes = []
624
+ n = np.arange(len(tasks))
625
+ split_n = utils.split_list(n, max_parallel_processes)
626
+ for split in split_n:
627
+ for e_idx, task_idx in enumerate(split):
628
+ task = tasks[int(task_idx)]
629
+ model_device = torch.device(
630
+ "cuda:%s" % (e_idx % torch.cuda.device_count())
631
+ if torch.cuda.is_available()
632
+ else "cpu"
633
+ )
634
+ p = Process(
635
+ target=fill_replay,
636
+ args=(cfg, obs_config, rank, replay, task, clip_model, model_device),
637
+ )
638
+
639
+ p.start()
640
+ processes.append(p)
641
+
642
+ for p in processes:
643
+ p.join()
external/peract_bimanual/agents/rvt/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ import agents.rvt.launch_utils
external/peract_bimanual/agents/rvt/launch_utils.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+ import torch
4
+ import numpy as np
5
+
6
+ from omegaconf import DictConfig
7
+
8
+ from yarr.agents.agent import Agent
9
+ from yarr.agents.agent import ActResult
10
+ from yarr.agents.agent import Summary
11
+ from yarr.agents.agent import ScalarSummary
12
+
13
+
14
+ from torch.nn.parallel import DistributedDataParallel as DDP
15
+
16
+ from helpers.preprocess_agent import PreprocessAgent
17
+
18
+
19
+ from rvt.mvt.mvt import MVT
20
+ from rvt.models import rvt_agent
21
+ from rvt.utils.peract_utils import (
22
+ CAMERAS,
23
+ SCENE_BOUNDS,
24
+ IMAGE_SIZE,
25
+ DATA_FOLDER,
26
+ )
27
+
28
+
29
+ import rvt.config as exp_cfg_mod
30
+ import rvt.models.rvt_agent as rvt_agent
31
+ import rvt.mvt.config as mvt_cfg_mod
32
+
33
+
34
+ def create_agent(cfg: DictConfig):
35
+ exp_cfg = exp_cfg_mod.get_cfg_defaults()
36
+ exp_cfg.bs = cfg.replay.batch_size
37
+ exp_cfg.tasks = ",".join(cfg.rlbench.tasks)
38
+
39
+ exp_cfg.freeze()
40
+
41
+ mvt_cfg = mvt_cfg_mod.get_cfg_defaults()
42
+ mvt_cfg.proprio_dim = cfg.method.low_dim_size
43
+ mvt_cfg.freeze()
44
+
45
+ agent = RVTAgentWrapper(
46
+ cfg.framework.checkpoint_name_prefix, cfg.rlbench, mvt_cfg, exp_cfg
47
+ )
48
+
49
+ preprocess_agent = PreprocessAgent(pose_agent=agent)
50
+ return preprocess_agent
51
+
52
+
53
+ class RVTAgentWrapper(Agent):
54
+ def __init__(self, checkpoint_name_prefix, rlbench_cfg, mvt_cfg, exp_cfg):
55
+ self._checkpoint_filename = f"{checkpoint_name_prefix}.pt"
56
+ self.rvt_agent = None
57
+ self.rlbench_cfg = rlbench_cfg
58
+ self.mvt_cfg = mvt_cfg
59
+ self.exp_cfg = exp_cfg
60
+ self._summaries = {}
61
+
62
+ def build(self, training: bool, device=None) -> None:
63
+ import torch
64
+
65
+ torch.cuda.set_device(device)
66
+ torch.cuda.empty_cache()
67
+
68
+ if isinstance(device, int):
69
+ device = f"cuda:{device}"
70
+
71
+ rvt = MVT(
72
+ renderer_device=device,
73
+ **self.mvt_cfg,
74
+ )
75
+ rvt = rvt.to(device)
76
+
77
+ if training:
78
+ rvt = DDP(rvt, device_ids=[device])
79
+
80
+ self.rvt_agent = rvt_agent.RVTAgent(
81
+ network=rvt,
82
+ # image_resolution=self.rlbench_cfg.camera_resolution,
83
+ add_lang=self.mvt_cfg.add_lang,
84
+ scene_bounds=self.rlbench_cfg.scene_bounds,
85
+ cameras=self.rlbench_cfg.cameras,
86
+ log_dir="/tmp/eval_run",
87
+ **self.exp_cfg.peract,
88
+ **self.exp_cfg.rvt,
89
+ )
90
+
91
+ self.rvt_agent.build(training, device)
92
+
93
+ def update(self, step: int, replay_sample: dict) -> dict:
94
+ for k, v in replay_sample.items():
95
+ replay_sample[k] = v.unsqueeze(1)
96
+ # RVT is based on the PerAct's Colab version.
97
+ replay_sample["lang_goal_embs"] = replay_sample["lang_token_embs"]
98
+ replay_sample["tasks"] = self.exp_cfg.tasks.split(",")
99
+
100
+ update_dict = self.rvt_agent.update(step, replay_sample)
101
+
102
+ for key, val in self.rvt_agent.loss_log.items():
103
+ self._summaries[key] = np.mean(np.array(val))
104
+
105
+ return {
106
+ "total_losses": update_dict["total_loss"],
107
+ }
108
+
109
+ return result
110
+
111
+ def act(self, step: int, observation: dict, deterministic: bool) -> ActResult:
112
+ return self.rvt_agent.act(step, observation, deterministic)
113
+
114
+ def reset(self) -> None:
115
+ self.rvt_agent.reset()
116
+
117
+ def update_summaries(self) -> List[Summary]:
118
+ summaries = []
119
+ for k, v in self._summaries.items():
120
+ summaries.append(ScalarSummary(f"RVT/{k}", v))
121
+ return summaries
122
+
123
+ def act_summaries(self) -> List[Summary]:
124
+ return []
125
+
126
+ def load_weights(self, savedir: str) -> None:
127
+ """
128
+ copied from RVT
129
+ """
130
+ device = torch.device("cuda:0")
131
+ weight_file = os.path.join(savedir, self._checkpoint_filename)
132
+ state_dict = torch.load(weight_file, map_location=device)
133
+
134
+ model = self.rvt_agent._network
135
+ optimizer = self.rvt_agent._optimizer
136
+ lr_sched = self.rvt_agent._lr_sched
137
+
138
+ if isinstance(model, DDP):
139
+ model = model.module
140
+
141
+ model.load_state_dict(state_dict["model_state"])
142
+ optimizer.load_state_dict(state_dict["optimizer_state"])
143
+ lr_sched.load_state_dict(state_dict["lr_sched_state"])
144
+
145
+ return self.rvt_agent.load_clip()
146
+
147
+ def save_weights(self, savedir: str) -> None:
148
+ os.makedirs(savedir, exist_ok=True)
149
+
150
+ weight_file = os.path.join(savedir, self._checkpoint_filename)
151
+
152
+ model = self.rvt_agent._network
153
+ optimizer = self.rvt_agent._optimizer
154
+ lr_sched = self.rvt_agent._lr_sched
155
+
156
+ if isinstance(model, DDP):
157
+ model = model.module
158
+
159
+ model_state = model.state_dict()
160
+
161
+ torch.save(
162
+ {
163
+ "model_state": model_state,
164
+ "optimizer_state": optimizer.state_dict(),
165
+ "lr_sched_state": lr_sched.state_dict(),
166
+ },
167
+ weight_file,
168
+ )
external/peract_bimanual/conf/config.yaml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ddp:
2
+ master_addr: "localhost"
3
+ master_port: "0"
4
+ num_devices: 1
5
+
6
+ rlbench:
7
+ task_name: "multi"
8
+ tasks: [open_drawer,slide_block_to_color_target]
9
+ demos: 100
10
+ demo_path: /my/demo/path
11
+ episode_length: 25
12
+ cameras: ["over_shoulder_left", "over_shoulder_right", "overhead", "wrist_right", "wrist_left", "front"]
13
+ camera_resolution: [128, 128]
14
+ scene_bounds: [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6]
15
+ include_lang_goal_in_obs: True
16
+
17
+ replay:
18
+ batch_size: 8
19
+ timesteps: 1
20
+ prioritisation: False
21
+ task_uniform: True # uniform sampling of tasks for multi-task buffers
22
+ use_disk: True
23
+ path: '/tmp/arm/replay' # only used when use_disk is True.
24
+ max_parallel_processes: 32
25
+
26
+ framework:
27
+ log_freq: 100
28
+ save_freq: 100
29
+ train_envs: 1
30
+ replay_ratio: ${replay.batch_size}
31
+ transitions_before_train: 200
32
+ tensorboard_logging: True
33
+ csv_logging: True
34
+ training_iterations: 40000
35
+ gpu: 0
36
+ env_gpu: 0
37
+ logdir: '/tmp/arm_test/'
38
+ logging_level: 20 # https://docs.python.org/3/library/logging.html#levels
39
+ seeds: 1
40
+ start_seed: 0
41
+ load_existing_weights: True
42
+ num_weights_to_keep: 60 # older checkpoints will be deleted chronologically
43
+ num_workers: 0
44
+ record_every_n: 5
45
+ checkpoint_name_prefix: "checkpoint"
46
+
47
+ defaults:
48
+ - method: PERACT_BC
49
+
50
+ hydra:
51
+ run:
52
+ dir: ${framework.logdir}/${rlbench.task_name}/${method.name}
external/peract_bimanual/conf/eval.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - method: PERACT_BC
3
+
4
+
5
+ rlbench:
6
+ task_name: "multi"
7
+ tasks: [open_drawer,slide_block_to_color_target]
8
+ demo_path: /my/demo/path
9
+ episode_length: 25
10
+ cameras: ["over_shoulder_left", "over_shoulder_right", "overhead", "wrist_right", "wrist_left", "front"]
11
+ camera_resolution: [128, 128]
12
+ scene_bounds: [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6]
13
+ include_lang_goal_in_obs: True
14
+ time_in_state: True
15
+ headless: True
16
+ gripper_mode: 'Discrete'
17
+ arm_action_mode: 'EndEffectorPoseViaPlanning'
18
+ action_mode: 'MoveArmThenGripper'
19
+
20
+ framework:
21
+ tensorboard_logging: True
22
+ csv_logging: True
23
+ gpu: 0
24
+ logdir: '/tmp/arm_test/'
25
+ start_seed: 0
26
+ record_every_n: 5
27
+
28
+ eval_envs: 1
29
+ eval_from_eps_number: 0
30
+ eval_episodes: 5
31
+ eval_type: 'last' # or 'best', 'missing', or 'last'
32
+ eval_save_metrics: True
33
+
34
+ cinematic_recorder:
35
+ enabled: False
36
+ camera_resolution: [1280, 720]
37
+ fps: 30
38
+ rotate_speed: 0.005
39
+ save_path: '/tmp/videos/'
external/peract_bimanual/conf/hydra/job_logging/custom.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: 1
2
+ formatters:
3
+ simple:
4
+ format: '[%(levelname)s] - %(message)s'
5
+ handlers:
6
+ rich_console:
7
+ class: rich.logging.RichHandler
8
+ root:
9
+ handlers: [rich_console]
10
+
11
+
12
+ disable_existing_loggers: false
external/peract_bimanual/conf/method/ACT_BC_LANG.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ name: 'ACT_BC_LANG'
4
+
5
+ # Agent
6
+ robot_name: 'bimanual'
7
+ agent_type: 'bimanual'
8
+
9
+
10
+ train_demo_path: "/home/markus/rlbench_data_v2_128/train/"
11
+
12
+ activation: lrelu
13
+ lr: 1e-4
14
+ weight_decay: 0.000001
15
+ grad_clip: 0.1
16
+ demo_augmentation: True
17
+ demo_augmentation_every_n: 10
18
+
19
+ prev_action_horizon: 1
20
+ next_action_horizon: 10
21
+
22
+ # hyperparameters
23
+ lr_backbone: 1e-5
24
+ backbone: resnet18
25
+ dilation: False
26
+ position_embedding: sine
27
+ kl_weight: 100
28
+ chunk_size: ${method.next_action_horizon}
29
+
30
+ # transformer
31
+ input_dim: 16 # 7 revolute joints + 1 gripper joints
32
+ enc_layers: 4
33
+ dec_layers: 7
34
+ dim_feedforward: 3200
35
+ hidden_dim: 512
36
+ dropout: 0.1
37
+ nheads: 8
38
+ num_queries: ${method.next_action_horizon}
39
+ pre_norm: False
40
+
41
+ # unused
42
+ masks: False
43
+
44
+ # legacy
45
+ camera_names: ${rlbench.cameras}
46
+
47
+ # ..todo:: also set the following
48
+
49
+ +rlbench.episode_length: 400
50
+ +rlbench.arm_action_mode: JointPosition
51
+ +rlbench.action_mode: JointPositionActionMode
external/peract_bimanual/conf/method/ARM.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ name: 'ARM'
4
+ activation: lrelu
5
+ q_conf: True
6
+ alpha: 0.05
7
+ alpha_lr: 0.0001
8
+ alpha_auto_tune: False
9
+ next_best_pose_critic_lr: 0.0025
10
+ next_best_pose_actor_lr: 0.001
11
+ next_best_pose_critic_weight_decay: 0.00001
12
+ next_best_pose_actor_weight_decay: 0.00001
13
+ crop_shape: [16, 16]
14
+ next_best_pose_tau: 0.005
15
+ next_best_pose_critic_grad_clip: 5
16
+ next_best_pose_actor_grad_clip: 5
17
+ qattention_grad_clip: 5
18
+ qattention_tau: 0.005
19
+ qattention_lr: 0.0005
20
+ qattention_weight_decay: 0.00001
21
+ qattention_lambda_qreg: 0.0000001
22
+
23
+ demo_augmentation: True
24
+ demo_augmentation_every_n: 10
external/peract_bimanual/conf/method/BC_LANG.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ name: 'BC_LANG'
4
+ activation: lrelu
5
+ lr: 0.0005
6
+ weight_decay: 0.000001
7
+ grad_clip: 0.1
8
+ demo_augmentation: True
9
+ demo_augmentation_every_n: 10
external/peract_bimanual/conf/method/BIMANUAL_PERACT.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ name: 'BIMANUAL_PERACT'
4
+
5
+ # Agent
6
+ robot_name: 'bimanual'
7
+ agent_type: 'bimanual'
8
+
9
+
10
+ # Voxelization
11
+ image_crop_size: 64
12
+ bounds_offset: [0.15]
13
+ voxel_sizes: [100]
14
+ include_prev_layer: False
15
+
16
+ # Perceiver
17
+ num_latents: 2048
18
+ latent_dim: 512
19
+ transformer_depth: 6
20
+ transformer_iterations: 1
21
+ cross_heads: 1
22
+ cross_dim_head: 64
23
+ latent_heads: 8
24
+ latent_dim_head: 64
25
+ pos_encoding_with_lang: True
26
+ conv_downsample: True
27
+ lang_fusion_type: 'seq' # or 'concat'
28
+ voxel_patch_size: 5
29
+ voxel_patch_stride: 5
30
+ final_dim: 64
31
+ low_dim_size: 8
32
+
33
+
34
+ # Training
35
+ input_dropout: 0.1
36
+ attn_dropout: 0.1
37
+ decoder_dropout: 0.0
38
+
39
+ lr: 0.0005
40
+ lr_scheduler: False
41
+ num_warmup_steps: 3000
42
+ optimizer: 'lamb' # or 'adam'
43
+
44
+ lambda_weight_l2: 0.000001
45
+ trans_loss_weight: 1.0
46
+ rot_loss_weight: 1.0
47
+ grip_loss_weight: 1.0
48
+ collision_loss_weight: 1.0
49
+ rotation_resolution: 5
50
+
51
+ # Network
52
+ activation: lrelu
53
+ norm: None
54
+
55
+ # Augmentation
56
+ crop_augmentation: True
57
+ transform_augmentation:
58
+ apply_se3: True
59
+ aug_xyz: [0.125, 0.125, 0.125]
60
+ aug_rpy: [0.0, 0.0, 45.0]
61
+ aug_rot_resolution: ${method.rotation_resolution}
62
+
63
+ demo_augmentation: True
64
+ demo_augmentation_every_n: 10
65
+
66
+ # Ablations
67
+ no_skip_connection: False
68
+ no_perceiver: False
69
+ no_language: False
70
+ keypoint_method: 'heuristic'
external/peract_bimanual/conf/method/C2FARM_LINGUNET_BC.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ name: 'C2FARM_LINGUNET_BC'
4
+
5
+ # Voxelization
6
+ image_crop_size: 64
7
+ bounds_offset: [0.15]
8
+ voxel_sizes: [32, 32]
9
+ include_prev_layer: False
10
+
11
+ # Training
12
+ lr: 0.0005
13
+ lr_scheduler: False
14
+ num_warmup_steps: 10000
15
+
16
+ lambda_weight_l2: 0.000001
17
+ trans_loss_weight: 1.0
18
+ rot_loss_weight: 1.0
19
+ grip_loss_weight: 1.0
20
+ collision_loss_weight: 1.0
21
+ rotation_resolution: 5
22
+
23
+ # Network
24
+ activation: lrelu
25
+ norm: None
26
+
27
+ # Augmentation
28
+ crop_augmentation: True
29
+ transform_augmentation:
30
+ apply_se3: True
31
+ aug_xyz: [0.125, 0.125, 0.125]
32
+ aug_rpy: [0.0, 0.0, 45.0]
33
+ aug_rot_resolution: ${method.rotation_resolution}
34
+
35
+ demo_augmentation: True
36
+ demo_augmentation_every_n: 10
37
+ exploration_strategy: gaussian
38
+
39
+ # Ablations
40
+ keypoint_method: 'heuristic'
external/peract_bimanual/conf/method/PERACT_BC.yaml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+
3
+ name: 'PERACT_BC'
4
+
5
+ # Agent
6
+ agent_type: 'leader_follower'
7
+ robot_name: 'bimanual'
8
+
9
+ # Voxelization
10
+ image_crop_size: 64
11
+ bounds_offset: [0.15]
12
+ voxel_sizes: [100]
13
+ include_prev_layer: False
14
+
15
+ # Perceiver
16
+ num_latents: 2048
17
+ latent_dim: 512
18
+ transformer_depth: 6
19
+ transformer_iterations: 1
20
+ cross_heads: 1
21
+ cross_dim_head: 64
22
+ latent_heads: 8
23
+ latent_dim_head: 64
24
+ pos_encoding_with_lang: True
25
+ conv_downsample: True
26
+ lang_fusion_type: 'seq' # or 'concat'
27
+ voxel_patch_size: 5
28
+ voxel_patch_stride: 5
29
+ final_dim: 64
30
+ low_dim_size: 4
31
+
32
+ # Training
33
+ input_dropout: 0.1
34
+ attn_dropout: 0.1
35
+ decoder_dropout: 0.0
36
+
37
+ lr: 0.0005
38
+ lr_scheduler: False
39
+ num_warmup_steps: 3000
40
+ optimizer: 'lamb' # or 'adam'
41
+
42
+ lambda_weight_l2: 0.000001
43
+ trans_loss_weight: 1.0
44
+ rot_loss_weight: 1.0
45
+ grip_loss_weight: 1.0
46
+ collision_loss_weight: 1.0
47
+ rotation_resolution: 5
48
+
49
+ # Network
50
+ activation: lrelu
51
+ norm: None
52
+
53
+ # Augmentation
54
+ crop_augmentation: True
55
+ transform_augmentation:
56
+ apply_se3: True
57
+ aug_xyz: [0.125, 0.125, 0.125]
58
+ aug_rpy: [0.0, 0.0, 45.0]
59
+ aug_rot_resolution: ${method.rotation_resolution}
60
+
61
+ demo_augmentation: True
62
+ demo_augmentation_every_n: 10
63
+
64
+ # Ablations
65
+ no_skip_connection: False
66
+ no_perceiver: False
67
+ no_language: False
68
+ keypoint_method: 'heuristic'