diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..a1756f5b220600b99b6b2ea7fa92a4a8acba46a4 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ACT_DP_multitask/detr/models/mr_mg/media/model.gif filter=lfs diff=lfs merge=lfs -text
diff --git a/ACT_DP_multitask/README.md b/ACT_DP_multitask/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4e37f4582871f4b531e82eb67f6dddac6b93749e
--- /dev/null
+++ b/ACT_DP_multitask/README.md
@@ -0,0 +1,16 @@
+### Install
+```
+cd policy/ACT-DP-TP
+cd detr
+pip install -e .
+cd ..
+cd Cosmos-Tokenizer
+pip install -e .
+#upload policy/ACT-DP-TP/Cosmos-Tokenizer/pretrained_ckpts
+```
+### Command
+```
+#data_dir: policy/ACT-DP-TP/data_zarr
+cd policy/ACT-DP-TP
+bash scripts/act_dp_tp/train.sh bottle_adjust 300 20 20 0
+```
diff --git a/ACT_DP_multitask/base.yaml b/ACT_DP_multitask/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..01b47f3ae774ad8fd58008b60fe3134be004dd05
--- /dev/null
+++ b/ACT_DP_multitask/base.yaml
@@ -0,0 +1,71 @@
+common:
+ # The number of historical images
+ img_history_size: 2
+ # The number of future actions to predict
+ action_chunk_size: 64
+ # The number of cameras to be used in the model
+ num_cameras: 3
+ # Dimension for state/action, we use the same space for both state and action
+ # This MUST be equal to configs/state_vec.py
+ state_dim: 128
+
+
+dataset:
+ # We will extract the data from raw dataset
+ # and store them in the disk buffer by producer
+ # When training, we will read the data
+ # randomly from the buffer by consumer
+ # The producer will replace the data which has been
+ # read by the consumer with new data
+
+ # The path to the buffer (at least 400GB)
+ buf_path: /path/to/buffer
+ # The number of chunks in the buffer
+ buf_num_chunks: 512
+ # The number of samples (step rather than episode) in each chunk
+ buf_chunk_size: 512
+
+ # We will filter the episodes with length less than `epsd_len_thresh_low`
+ epsd_len_thresh_low: 32
+ # For those more than `epsd_len_thresh_high`,
+ # we will randomly sample `epsd_len_thresh_high` steps each time we load the episode
+ # to better balance the training datasets
+ epsd_len_thresh_high: 2048
+ # How to fit the image size
+ image_aspect_ratio: pad
+ # Maximum number of language tokens
+ tokenizer_max_length: 1024
+
+model:
+ # Config for condition adaptors
+ lang_adaptor: mlp2x_gelu
+ img_adaptor: mlp2x_gelu
+ state_adaptor: mlp3x_gelu
+ lang_token_dim: 4096
+ img_token_dim: 1152
+ # Dim of action or proprioception vector
+ # A `state` refers to an action or a proprioception vector
+ state_token_dim: 128
+ # Config for RDT structure
+ rdt:
+ # 1B: num_head 32 hidden_size 2048
+ hidden_size: 2048
+ depth: 28
+ num_heads: 32
+ cond_pos_embed_type: multimodal
+ # For noise scheduler
+ noise_scheduler:
+ type: ddpm
+ num_train_timesteps: 1000
+ num_inference_timesteps: 5
+ beta_schedule: squaredcos_cap_v2 # Critical choice
+ prediction_type: sample
+ clip_sample: False
+ # For EMA (params averaging)
+ # We do not use EMA currently
+ ema:
+ update_after_step: 0
+ inv_gamma: 1.0
+ power: 0.75
+ min_value: 0.0
+ max_value: 0.9999
diff --git a/ACT_DP_multitask/detr/LICENSE b/ACT_DP_multitask/detr/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..b1395e94b016dd1b95b4c7e3ed493e1d0b342917
--- /dev/null
+++ b/ACT_DP_multitask/detr/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2020 - present, Facebook, Inc
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/ACT_DP_multitask/detr/README.md b/ACT_DP_multitask/detr/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..500b1b8d01108f8ff99b2c505a58cdd43a546fee
--- /dev/null
+++ b/ACT_DP_multitask/detr/README.md
@@ -0,0 +1,9 @@
+This part of the codebase is modified from DETR https://github.com/facebookresearch/detr under APACHE 2.0.
+
+ @article{Carion2020EndtoEndOD,
+ title={End-to-End Object Detection with Transformers},
+ author={Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko},
+ journal={ArXiv},
+ year={2020},
+ volume={abs/2005.12872}
+ }
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/__pycache__/main.cpython-310.pyc b/ACT_DP_multitask/detr/__pycache__/main.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c35be3fc0b1fd6b108e52648f6223818f7959dc
Binary files /dev/null and b/ACT_DP_multitask/detr/__pycache__/main.cpython-310.pyc differ
diff --git a/ACT_DP_multitask/detr/__pycache__/main.cpython-37.pyc b/ACT_DP_multitask/detr/__pycache__/main.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..598253e19babf5a1f2807b842afe1fc88f3332bc
Binary files /dev/null and b/ACT_DP_multitask/detr/__pycache__/main.cpython-37.pyc differ
diff --git a/ACT_DP_multitask/detr/detr.egg-info/PKG-INFO b/ACT_DP_multitask/detr/detr.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..595dd87f3aa7838e4a3dc6787e6b130830496b7c
--- /dev/null
+++ b/ACT_DP_multitask/detr/detr.egg-info/PKG-INFO
@@ -0,0 +1,17 @@
+Metadata-Version: 2.2
+Name: detr
+Version: 0.0.0
+License: MIT License
+License-File: LICENSE
+Dynamic: description
+Dynamic: license
+
+This part of the codebase is modified from DETR https://github.com/facebookresearch/detr under APACHE 2.0.
+
+ @article{Carion2020EndtoEndOD,
+ title={End-to-End Object Detection with Transformers},
+ author={Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko},
+ journal={ArXiv},
+ year={2020},
+ volume={abs/2005.12872}
+ }
diff --git a/ACT_DP_multitask/detr/detr.egg-info/SOURCES.txt b/ACT_DP_multitask/detr/detr.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..effae1d8c053e3533179a4afa121adc37953335b
--- /dev/null
+++ b/ACT_DP_multitask/detr/detr.egg-info/SOURCES.txt
@@ -0,0 +1,37 @@
+LICENSE
+README.md
+setup.py
+detr.egg-info/PKG-INFO
+detr.egg-info/SOURCES.txt
+detr.egg-info/dependency_links.txt
+detr.egg-info/top_level.txt
+models/__init__.py
+models/backbone.py
+models/detr_vae.py
+models/detr_vae_nfp.py
+models/position_encoding.py
+models/transformer.py
+models/vision_transformer.py
+models/mask_former/__init__.py
+models/mask_former/config.py
+models/mask_former/mask_former_model.py
+models/mask_former/test_time_augmentation.py
+models/mask_former/modeling/__init__.py
+models/mask_former/modeling/criterion.py
+models/mask_former/modeling/matcher.py
+models/mask_former/modeling/backbone/__init__.py
+models/mask_former/modeling/backbone/swin.py
+models/mask_former/modeling/heads/__init__.py
+models/mask_former/modeling/heads/mask_former_head.py
+models/mask_former/modeling/heads/per_pixel_baseline.py
+models/mask_former/modeling/heads/pixel_decoder.py
+models/mask_former/modeling/transformer/__init__.py
+models/mask_former/modeling/transformer/position_encoding.py
+models/mask_former/modeling/transformer/transformer.py
+models/mask_former/modeling/transformer/transformer_predictor.py
+models/mask_former/utils/__init__.py
+models/mask_former/utils/misc.py
+util/__init__.py
+util/box_ops.py
+util/misc.py
+util/plot_utils.py
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/detr.egg-info/dependency_links.txt b/ACT_DP_multitask/detr/detr.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/ACT_DP_multitask/detr/detr.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/ACT_DP_multitask/detr/detr.egg-info/top_level.txt b/ACT_DP_multitask/detr/detr.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c406055464420fa96c35835aa35153fb064acf68
--- /dev/null
+++ b/ACT_DP_multitask/detr/detr.egg-info/top_level.txt
@@ -0,0 +1,2 @@
+models
+util
diff --git a/ACT_DP_multitask/detr/main.py b/ACT_DP_multitask/detr/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..c34ae360374046d5bb0958bf6f87c259a5c7ce51
--- /dev/null
+++ b/ACT_DP_multitask/detr/main.py
@@ -0,0 +1,763 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import argparse
+from pathlib import Path
+import os
+import numpy as np
+import torch
+from .models import *
+
+import IPython
+
+e = IPython.embed
+
+
+def get_args_parser():
+ parser = argparse.ArgumentParser("Set transformer detector", add_help=False)
+ parser.add_argument("--ckpt_path", type=str, default='policy/ACT_DP_multitask/checkpoints/real_fintune_50_2000/act_dp')
+ parser.add_argument("--eval_ckpts", default=0, type=int, help="eval_ckpts")
+ parser.add_argument("--eval_video_log", action="store_true")
+ parser.add_argument("--action_interval", default=1, type=int)
+ parser.add_argument("--lr", default=1e-4, type=float) # will be overridden
+ parser.add_argument("--lr_backbone", default=1e-5, type=float) # will be overridden
+ parser.add_argument(
+ "--lr_schedule_type", default="constant", type=str, help="lr_schedule_type"
+ )
+ parser.add_argument(
+ "--num_episodes", type=int, help="num_epochs", default=0, required=False
+ )
+ parser.add_argument("--batch_size", default=2, type=int) # not used
+ parser.add_argument(
+ "--samples_per_epoch",
+ default=1,
+ type=int,
+ help="samples_per_epoch",
+ required=False,
+ )
+ parser.add_argument("--weight_decay", default=1e-4, type=float)
+ parser.add_argument("--epochs", default=300, type=int) # not used
+ parser.add_argument("--lr_drop", default=200, type=int) # not used
+ parser.add_argument(
+ "--clip_max_norm",
+ default=0.1,
+ type=float, # not used
+ help="gradient clipping max norm",
+ )
+ parser.add_argument("--norm_type", default="meanstd", type=str, help="norm_type")
+ parser.add_argument(
+ "--num_train_steps", default=50, type=int, help="num_train_steps"
+ )
+ parser.add_argument(
+ "--num_inference_steps", default=10, type=int, help="num_inference_steps"
+ )
+ parser.add_argument(
+ "--schedule_type", default="DDIM", type=str, help="scheduler_type"
+ )
+ parser.add_argument(
+ "--imitate_weight", default=1, type=int, help="imitate Weight", required=False
+ )
+ parser.add_argument(
+ "--prediction_type", default="sample", type=str, help="prediction_type"
+ )
+ parser.add_argument(
+ "--beta_schedule", default="squaredcos_cap_v2", type=str, help="prediction_type"
+ )
+ parser.add_argument(
+ "--diffusion_timestep_type",
+ default="cat",
+ type=str,
+ help="diffusion_timestep_type, cat or add, how to combine timestep",
+ )
+ parser.add_argument(
+ "--condition_type",
+ default="cross_attention",
+ type=str,
+ help="diffusion_condition_type, cross_attention or adaLN, how to combine observation condition",
+ )
+ parser.add_argument("--attention_type", default="v0", help="decoder attention type")
+ parser.add_argument(
+ "--causal_mask", action="store_true", help="use causal mask for diffusion"
+ )
+ parser.add_argument("--loss_type", default="l2", type=str, help="loss_type")
+ parser.add_argument(
+ "--disable_vae_latent",
+ action="store_true",
+ help="Use VAE latent space by default",
+ )
+ parser.add_argument(
+ "--disable_resnet",
+ action="store_true",
+ help="Use resnet to encode obs image by default",
+ )
+ parser.add_argument(
+ "--disable_scale",
+ action="store_true",
+ help="scale model up",
+ )
+ parser.add_argument(
+ "--inference_num_queries",
+ default=0,
+ type=int,
+ help="inference_num_queries",
+ required=False,
+ ) # predict_frame
+ parser.add_argument(
+ "--disable_resize", action="store_true", help="if resize jpeg image"
+ )
+ parser.add_argument(
+ "--share_decoder", action="store_true", help="jpeg and action share decoder"
+ )
+ parser.add_argument(
+ "--resize_rate",
+ default=1,
+ type=int,
+ help="resize rate for pixel prediction",
+ required=False,
+ )
+ parser.add_argument(
+ "--image_downsample_rate",
+ default=1,
+ type=int,
+ help="image_downsample_rate",
+ required=False,
+ )
+ parser.add_argument(
+ "--temporal_downsample_rate",
+ default=1,
+ type=int,
+ help="temporal_downsample_rate",
+ required=False,
+ )
+ # Model parameters external
+ parser.add_argument("--test_num", default=50, type=int, help="test_num")
+ parser.add_argument("--save_episode", action="store_true")
+ parser.add_argument(
+ "--depth_mode",
+ default="None",
+ type=str,
+ help="use depth/depth+coordinate/None. ALL/Single/None",
+ )
+ parser.add_argument(
+ "--pc_mode", default="pc_camera", type=str, help="pc_world/pc_camera"
+ )
+ parser.add_argument(
+ "--disable_multi_view", action="store_true", help="Use multi-view rgb images"
+ )
+ # * Backbone
+ parser.add_argument(
+ "--backbone",
+ default="resnet18",
+ type=str, # will be overridden
+ help="Name of the convolutional backbone to use",
+ )
+ parser.add_argument(
+ "--dilation",
+ action="store_true",
+ help="If true, we replace stride with dilation in the last convolutional block (DC5)",
+ )
+ parser.add_argument(
+ "--position_embedding",
+ default="sine",
+ type=str,
+ choices=("sine", "learned"),
+ help="Type of positional embedding to use on top of the image features",
+ )
+ parser.add_argument(
+ "--camera_names",
+ default=[],
+ type=list, # will be overridden
+ help="A list of camera names",
+ )
+
+ # * Transformer
+ parser.add_argument(
+ "--enc_layers",
+ default=4,
+ type=int, # will be overridden
+ help="Number of encoding layers in the transformer",
+ )
+ parser.add_argument(
+ "--dec_layers",
+ default=6,
+ type=int, # will be overridden
+ help="Number of decoding layers in the transformer",
+ )
+ parser.add_argument(
+ "--dim_feedforward",
+ default=2048,
+ type=int, # will be overridden
+ help="Intermediate size of the feedforward layers in the transformer blocks",
+ )
+ parser.add_argument(
+ "--hidden_dim",
+ default=256,
+ type=int, # will be overridden
+ help="Size of the embeddings (dimension of the transformer)",
+ )
+ parser.add_argument(
+ "--dropout", default=0.1, type=float, help="Dropout applied in the transformer"
+ )
+ parser.add_argument(
+ "--nheads",
+ default=8,
+ type=int, # will be overridden
+ help="Number of attention heads inside the transformer's attentions",
+ )
+ parser.add_argument(
+ "--num_queries",
+ default=400,
+ type=int, # will be overridden
+ help="Number of query slots",
+ )
+ parser.add_argument("--pre_norm", action="store_true")
+
+ # # * Segmentation
+ parser.add_argument(
+ "--masks",
+ action="store_true",
+ help="Train segmentation head if the flag is provided",
+ )
+
+ # repeat args in imitate_episodes just to avoid error. Will not be used
+ parser.add_argument("--eval", action="store_true")
+ parser.add_argument("--onscreen_render", action="store_true")
+ parser.add_argument(
+ "--ckpt_dir", action="store", type=str, help="ckpt_dir", required=False
+ )
+ parser.add_argument(
+ "--policy_class",
+ action="store",
+ type=str,
+ help="policy_class, capitalize",
+ required=False,
+ )
+ parser.add_argument(
+ "--task_name", action="store", type=str, help="task_name", required=False
+ )
+ parser.add_argument("--seed", action="store", type=int, help="seed", required=False)
+ parser.add_argument(
+ "--num_epochs", action="store", type=int, help="num_epochs", required=False
+ )
+ parser.add_argument(
+ "--kl_weight", action="store", type=int, help="KL Weight", required=False
+ )
+ parser.add_argument(
+ "--save_epoch",
+ action="store",
+ type=int,
+ help="save_epoch",
+ default=500,
+ required=False,
+ )
+ parser.add_argument(
+ "--chunk_size", action="store", type=int, help="chunk_size", required=False
+ )
+ parser.add_argument(
+ "--history_step", default=0, type=int, help="history_step", required=False
+ )
+ parser.add_argument(
+ "--predict_frame", default=0, type=int, help="predict_frame", required=False
+ )
+ # add image_width and image_height
+ parser.add_argument(
+ "--image_width", default=320, type=int, help="image_width", required=False
+ )
+ parser.add_argument(
+ "--image_height", default=240, type=int, help="image_height", required=False
+ )
+ parser.add_argument(
+ "--predict_only_last", action="store_true"
+ ) # only predict the last #predict_frame frame
+ parser.add_argument("--temporal_agg", action="store_true")
+ # visual tokenizer
+ parser.add_argument(
+ "--tokenizer_model_type",
+ default="DV",
+ type=str,
+ help="tokenizer_model_type, DV,CV,DI,CI",
+ )
+ parser.add_argument(
+ "--tokenizer_model_temporal_rate",
+ default=8,
+ type=int,
+ help="tokenizer_model_temporal_rate, 4,8",
+ )
+ parser.add_argument(
+ "--tokenizer_model_spatial_rate",
+ default=16,
+ type=int,
+ help="tokenizer_model_spatial_rate, 8,16",
+ )
+ parser.add_argument(
+ "--tokenizer_model_name",
+ default="Cosmos-Tokenizer-DV4x8x8",
+ type=str,
+ help="tokenizer_model_name",
+ )
+ parser.add_argument(
+ "--prediction_weight",
+ default=1,
+ type=float,
+ help="pred token Weight",
+ required=False,
+ )
+ parser.add_argument(
+ "--token_dim", default=6, type=int, help="token_dim", required=False
+ ) # token_pe_type
+ parser.add_argument(
+ "--patch_size", default=5, type=int, help="patch_size", required=False
+ ) # token_pe_type
+ parser.add_argument(
+ "--token_pe_type",
+ default="learned",
+ type=str,
+ help="token_pe_type",
+ required=False,
+ )
+ parser.add_argument("--nf", action="store_true")
+ parser.add_argument("--pretrain", action="store_true", required=False)
+ parser.add_argument("--is_wandb", action="store_true")
+ parser.add_argument("--mae", action="store_true")
+ # parser.add_argument('--seg', action='store_true')
+ # parser.add_argument('--seg_next', action='store_true')
+
+ # parameters for distributed training
+ parser.add_argument(
+ "--resume",
+ default="",
+ type=str,
+ metavar="PATH",
+ help="path to latest checkpoint (default: none)",
+ )
+ parser.add_argument(
+ "--world-size",
+ default=-1,
+ type=int,
+ help="number of nodes for distributed training",
+ )
+ parser.add_argument(
+ "--rank", default=-1, type=int, help="node rank for distributed training"
+ )
+ parser.add_argument(
+ "--dist-url",
+ default="tcp://224.66.41.62:23456",
+ type=str,
+ help="url used to set up distributed training",
+ )
+ parser.add_argument(
+ "--dist-backend", default="nccl", type=str, help="distributed backend"
+ )
+ # parser.add_argument(
+ # "--seed", default=None, type=int, help="seed for initializing training. "
+ # )
+ parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.")
+ parser.add_argument(
+ "--multiprocessing-distributed",
+ action="store_true",
+ help="Use multi-processing distributed training to launch "
+ "N processes per node, which has N GPUs. This is the "
+ "fastest way to use PyTorch for either single node or "
+ "multi node data parallel training",
+ )
+ parser.add_argument(
+ "-j",
+ "--workers",
+ default=32,
+ type=int,
+ metavar="N",
+ help="number of data loading workers (default: 32)",
+ )
+
+ return parser
+
+
+def build_ACT_model_and_optimizer(args_override):
+ parser = argparse.ArgumentParser(
+ "DETR training and evaluation script", parents=[get_args_parser()]
+ )
+ args = parser.parse_args()
+
+ for k, v in args_override.items():
+ setattr(args, k, v)
+
+ if args_override["segmentation"]:
+ model = build_ACT_Seg_model(args)
+ else:
+ model = build_ACT_model(args)
+ model.cuda()
+
+ param_dicts = [
+ {
+ "params": [
+ p
+ for n, p in model.named_parameters()
+ if "backbone" not in n and p.requires_grad
+ ]
+ },
+ {
+ "params": [
+ p
+ for n, p in model.named_parameters()
+ if "backbone" in n and p.requires_grad
+ ],
+ "lr": args.lr_backbone,
+ },
+ ]
+ optimizer = torch.optim.AdamW(
+ param_dicts, lr=args.lr, weight_decay=args.weight_decay
+ )
+
+ return model, optimizer
+
+
+def build_ACTDiffusion_model_and_optimizer(args_override):
+ parser = argparse.ArgumentParser(
+ "DETR training and evaluation script", parents=[get_args_parser()]
+ )
+ args = parser.parse_args()
+ for k, v in args_override.items():
+ setattr(args, k, v)
+ # print('args',args) # get
+ model = build_ACTDiffusion_model(args)
+ model.cuda()
+
+ param_dicts = [
+ {
+ "params": [
+ p
+ for n, p in model.named_parameters()
+ if "backbone" not in n and p.requires_grad
+ ]
+ },
+ {
+ "params": [
+ p
+ for n, p in model.named_parameters()
+ if "backbone" in n and p.requires_grad
+ ],
+ "lr": args.lr_backbone,
+ },
+ ]
+ optimizer = torch.optim.AdamW(
+ param_dicts, lr=args.lr, weight_decay=args.weight_decay
+ )
+
+ return model, optimizer
+
+
+def build_ACTDiffusion_tactile_model_and_optimizer(args_override):
+ parser = argparse.ArgumentParser(
+ "DETR training and evaluation script", parents=[get_args_parser()]
+ )
+ args = parser.parse_args()
+ for k, v in args_override.items():
+ setattr(args, k, v)
+ # print('args',args) # get
+ model = build_ACTDiffusion_tactile_model(args)
+ model.cuda()
+
+ param_dicts = [
+ {
+ "params": [
+ p
+ for n, p in model.named_parameters()
+ if "backbone" not in n and p.requires_grad
+ ]
+ },
+ {
+ "params": [
+ p
+ for n, p in model.named_parameters()
+ if "backbone" in n and p.requires_grad
+ ],
+ "lr": args.lr_backbone,
+ },
+ ]
+ optimizer = torch.optim.AdamW(
+ param_dicts, lr=args.lr, weight_decay=args.weight_decay
+ )
+
+ return model, optimizer
+
+
def build_diffusion_tp_model_and_optimizer(args_override):
    """Build the ACT-Diffusion 'tp' model and move it to the GPU.

    NOTE(review): despite the name, no optimizer is created here — only the
    model is returned; the caller is expected to build its own optimizer.
    """
    parser = argparse.ArgumentParser(
        "DETR training and evaluation script", parents=[get_args_parser()]
    )
    args = parser.parse_args()
    for key, value in args_override.items():
        setattr(args, key, value)

    model = build_ACTDiffusion_tp_model(args)
    model.cuda()
    return model
+
+
def build_diffusion_pp_model_and_optimizer(args_override):
    """Build the ACT-Diffusion 'pp' model and move it to the GPU.

    NOTE(review): despite the name, no optimizer is created here — only the
    model is returned, mirroring the 'tp' builder above.
    """
    parser = argparse.ArgumentParser(
        "DETR training and evaluation script", parents=[get_args_parser()]
    )
    args = parser.parse_args()
    for key, value in args_override.items():
        setattr(args, key, value)

    model = build_ACTDiffusion_pp_model(args)
    model.cuda()
    return model
+
+
+# discard
+
+
def build_ACT_NF_model_and_optimizer(args_override):
    """Build the (legacy) ACT-NF model plus its AdamW optimizer.

    Marked "discard" in this file — kept for backward compatibility.

    Returns:
        Tuple ``(model, optimizer)``.
    """
    parser = argparse.ArgumentParser(
        "DETR training and evaluation script", parents=[get_args_parser()]
    )
    args = parser.parse_args()
    for key, value in args_override.items():
        setattr(args, key, value)

    model = build_ACT_NF_model(args)
    model.cuda()

    # Separate learning rate for the vision backbone.
    backbone_params, head_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (backbone_params if "backbone" in name else head_params).append(param)

    optimizer = torch.optim.AdamW(
        [
            {"params": head_params},
            {"params": backbone_params, "lr": args.lr_backbone},
        ],
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    return model, optimizer
+
+
def build_ACT_Dino_model_and_optimizer(args_override):
    """Build the (legacy) ACT model with a DINO visual encoder plus AdamW.

    Returns:
        Tuple ``(model, optimizer)``.
    """
    parser = argparse.ArgumentParser(
        "DETR training and evaluation script", parents=[get_args_parser()]
    )
    args = parser.parse_args()
    for key, value in args_override.items():
        setattr(args, key, value)

    model = build_ACT_dino_model(args)
    model.cuda()

    # Separate learning rate for the vision backbone.
    backbone_params, head_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (backbone_params if "backbone" in name else head_params).append(param)

    optimizer = torch.optim.AdamW(
        [
            {"params": head_params},
            {"params": backbone_params, "lr": args.lr_backbone},
        ],
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    return model, optimizer
+
+
def build_ACT_jpeg_model_and_optimizer(args_override):
    """Build the (legacy) JPEG-variant ACT model plus its AdamW optimizer.

    Returns:
        Tuple ``(model, optimizer)``.
    """
    parser = argparse.ArgumentParser(
        "DETR training and evaluation script", parents=[get_args_parser()]
    )
    args = parser.parse_args()
    for key, value in args_override.items():
        setattr(args, key, value)

    model = build_ACT_jpeg_model(args)
    model.cuda()

    # Separate learning rate for the vision backbone.
    backbone_params, head_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (backbone_params if "backbone" in name else head_params).append(param)

    optimizer = torch.optim.AdamW(
        [
            {"params": head_params},
            {"params": backbone_params, "lr": args.lr_backbone},
        ],
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    return model, optimizer
+
+
def build_ACT_jpeg_diffusion_model_and_optimizer(args_override):
    """Build the (legacy) JPEG + diffusion ACT model plus its AdamW optimizer.

    Returns:
        Tuple ``(model, optimizer)``.
    """
    parser = argparse.ArgumentParser(
        "DETR training and evaluation script", parents=[get_args_parser()]
    )
    args = parser.parse_args()
    for key, value in args_override.items():
        setattr(args, key, value)

    model = build_ACT_jpeg_diffusion_model(args)
    model.cuda()

    # Separate learning rate for the vision backbone.
    backbone_params, head_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (backbone_params if "backbone" in name else head_params).append(param)

    optimizer = torch.optim.AdamW(
        [
            {"params": head_params},
            {"params": backbone_params, "lr": args.lr_backbone},
        ],
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    return model, optimizer
+
+
def build_ACT_jpeg_diffusion_seperate_model_and_optimizer(args_override):
    """Build the (legacy) JPEG + separate-diffusion ACT model plus AdamW.

    ("seperate" spelling kept — it matches the builder exported from
    ``models.__init__``.)

    Returns:
        Tuple ``(model, optimizer)``.
    """
    parser = argparse.ArgumentParser(
        "DETR training and evaluation script", parents=[get_args_parser()]
    )
    args = parser.parse_args()
    for key, value in args_override.items():
        setattr(args, key, value)

    model = build_ACT_jpeg_diffusion_seperate_model(args)
    model.cuda()

    # Separate learning rate for the vision backbone.
    backbone_params, head_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (backbone_params if "backbone" in name else head_params).append(param)

    optimizer = torch.optim.AdamW(
        [
            {"params": head_params},
            {"params": backbone_params, "lr": args.lr_backbone},
        ],
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    return model, optimizer
+
+
def build_nf_diffusion_seperate_model_and_optimizer(args_override):
    """Build the (legacy) NF + separate-diffusion model plus its AdamW optimizer.

    Returns:
        Tuple ``(model, optimizer)``.
    """
    parser = argparse.ArgumentParser(
        "DETR training and evaluation script", parents=[get_args_parser()]
    )
    args = parser.parse_args()
    for key, value in args_override.items():
        setattr(args, key, value)

    model = build_nf_diffusion_seperate_model(args)
    model.cuda()

    # Separate learning rate for the vision backbone.
    backbone_params, head_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (backbone_params if "backbone" in name else head_params).append(param)

    optimizer = torch.optim.AdamW(
        [
            {"params": head_params},
            {"params": backbone_params, "lr": args.lr_backbone},
        ],
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    return model, optimizer
+
+
def build_CNNMLP_model_and_optimizer(args_override):
    """Build the CNN+MLP baseline policy plus its AdamW optimizer.

    Returns:
        Tuple ``(model, optimizer)``.
    """
    parser = argparse.ArgumentParser(
        "DETR training and evaluation script", parents=[get_args_parser()]
    )
    args = parser.parse_args()
    for key, value in args_override.items():
        setattr(args, key, value)

    model = build_CNNMLP_model(args)
    model.cuda()

    # Separate learning rate for the vision backbone.
    backbone_params, head_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (backbone_params if "backbone" in name else head_params).append(param)

    optimizer = torch.optim.AdamW(
        [
            {"params": head_params},
            {"params": backbone_params, "lr": args.lr_backbone},
        ],
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    return model, optimizer
+
diff --git a/ACT_DP_multitask/detr/models/__init__.py b/ACT_DP_multitask/detr/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecb30ad65567274fba7157b568ac0ff042d6a2cf
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/__init__.py
@@ -0,0 +1,60 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+from .detr_vae import build as build_vae
+from .detr_vae import build_seg as build_vae_seg
+from .detr_vae_nfp import build as build_vae_nfp
+from .detr_vae import build_cnnmlp as build_cnnmlp
+from .detr_vae import build_dino as build_dino
+from .detr_vae import build_jpeg as build_jpeg
+from .detr_vae import build_jpeg_diffusion as build_jpeg_diffusion
+from .detr_vae import build_jpeg_diffusion_seperate as build_jpeg_diffusion_seperate
+from .detr_vae import build_nf_diffusion_seperate as build_nf_diffusion_seperate
+from .detr_vae import build_diffusion as build_diffusion
+from .detr_vae import build_diffusion_tp as build_diffusion_tp
+from .detr_vae import build_diffusion_tp_with_dual_visual_token as build_diffusion_tp_with_dual_visual_token
+from .detr_vae import build_diffusion_pp as build_diffusion_pp
+from .detr_vae import build_diffusion_tactile as build_diffusion_tactile
+
def build_ACT_model(args):
    """Build the vanilla ACT (CVAE-style) policy; thin alias over detr_vae.build."""
    return build_vae(args)
+
def build_CNNMLP_model(args):
    """Build the CNN+MLP baseline; thin alias over detr_vae.build_cnnmlp."""
    return build_cnnmlp(args)
+
def build_ACTDiffusion_model(args):
    """Build the ACT-Diffusion policy; thin alias over detr_vae.build_diffusion."""
    return build_diffusion(args)
+
def build_ACTDiffusion_tactile_model(args):
    """Build the tactile ACT-Diffusion policy; alias over detr_vae.build_diffusion_tactile."""
    return build_diffusion_tactile(args)
+
def build_ACTDiffusion_tp_model(args):
    """Build the ACT-Diffusion 'tp' variant.

    When ``diffusion_timestep_type == 'vis_cat'`` the variant that also feeds
    tokenizer visual features to the decoder / action head is used instead.
    """
    if args.diffusion_timestep_type == 'vis_cat': # HARDCODE whether use tokenizer feature for decoder & action prediction
        print('Using dual visual token for decoder and action prediction')
        return build_diffusion_tp_with_dual_visual_token(args)
    else:
        return build_diffusion_tp(args)
+
def build_ACTDiffusion_pp_model(args):
    """Build the ACT-Diffusion 'pp' variant; alias over detr_vae.build_diffusion_pp."""
    return build_diffusion_pp(args)
+
+# discard
def build_ACT_NF_model(args):
    """Legacy (discarded) builder: ACT with next-frame prediction; alias over detr_vae_nfp.build."""
    return build_vae_nfp(args)
+
def build_ACT_Seg_model(args):
    """Legacy (discarded) builder: segmentation variant; alias over detr_vae.build_seg."""
    return build_vae_seg(args)
+
def build_ACT_dino_model(args):
    """Legacy (discarded) builder: DINO-encoder variant; alias over detr_vae.build_dino."""
    return build_dino(args)
+
def build_ACT_jpeg_model(args):
    """Legacy (discarded) builder: JPEG variant; alias over detr_vae.build_jpeg."""
    return build_jpeg(args)
+
def build_ACT_jpeg_diffusion_model(args):
    """Legacy (discarded) builder: JPEG + diffusion variant; alias over detr_vae.build_jpeg_diffusion."""
    return build_jpeg_diffusion(args)
+
def build_ACT_jpeg_diffusion_seperate_model(args):
    """Legacy (discarded) builder: JPEG + separate diffusion head ("seperate" sic)."""
    return build_jpeg_diffusion_seperate(args)
+
def build_nf_diffusion_seperate_model(args):
    """Legacy (discarded) builder: NF + separate diffusion head ("seperate" sic)."""
    return build_nf_diffusion_seperate(args)
+
diff --git a/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ff241e6845d924770d91e296912b754e9bb1037
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-37.pyc b/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f15f86d73e1cfd2c04218d8ab24b5388bac2890f
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-37.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-38.pyc b/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..75f1fbcf5fb261eb909e31debcfb0500be3f0857
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3057fddd7d741d12c184edcd2ee025c9b5f420f2
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-310.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-37.pyc b/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..def5d907456559aed4e23290080c1afaace64fb3
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-37.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-38.pyc b/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d9fd912bbdd116f42baa238c0ce16f455a093efc
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-38.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d163d23c7998a92d59b27ad4ff8dd5106b232f8
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-310.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-37.pyc b/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e8788a652254702f82e6b06e34ac510ddee44bc7
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-37.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-38.pyc b/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d30aa9eb4e5fa98a78e531e94b1dc39b08073ef
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-38.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..144a41ad3f0275e9817bc17cf6f00e5fc58706d9
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-310.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-37.pyc b/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f1d9982b9c3a5358b5933b69137a0fcffc65d192
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-37.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-38.pyc b/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aad89d2ddf1fcaea00dbb7db2fd026a0685cfb45
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-38.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31228f52a00e0f77bf929b0d0edaa3c87da75d12
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-310.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-37.pyc b/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b0a4a6e6597403eadb101c3d570b8f24f4eebcef
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-37.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-38.pyc b/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ed830ec82cfc73f6bd2061da7689dbc780f9d06
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-38.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/resnet_film.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/resnet_film.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ebb3e72a11b2f267f60b4307ff087167dd1f2ad4
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/resnet_film.cpython-310.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..444f11624b6fe0783733bec98d810f1627f996aa
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-310.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-37.pyc b/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3b7fd65409dbb49db23166019d58475f20f9518
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-37.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-38.pyc b/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..515275a7f9e090bf6407253ae2f7183b6831e1c0
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-38.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..761b345f1415a83be930b2a1c14c974d1016921f
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-310.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-37.pyc b/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2eead987d8c11019e247f723f3e176b98340b352
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-37.pyc differ
diff --git a/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-38.pyc b/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0342901b46494e03dde2ec0a1425f321ca02ac42
Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-38.pyc differ
diff --git a/ACT_DP_multitask/detr/models/backbone.py b/ACT_DP_multitask/detr/models/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..82fb532b5db2e03b93c83b7bea0b608a5f4d57eb
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/backbone.py
@@ -0,0 +1,209 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Backbone modules.
+"""
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn
+from torchvision.models._utils import IntermediateLayerGetter
+from typing import Dict, List
+from typing import Any, Dict, List, Mapping, Optional
+from ..util.misc import NestedTensor, is_main_process
+
+from .position_encoding import build_position_encoding
+from .resnet_film import resnet18 as resnet18_film
+from .resnet_film import resnet34 as resnet34_film
+import IPython
+e = IPython.embed
+
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rqsrt,
    without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
    produce nans.

    All four tensors are registered as buffers, so they are saved/loaded with
    the state dict but never updated by an optimizer.
    """

    def __init__(self, n):
        super().__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Regular BatchNorm checkpoints carry a num_batches_tracked entry
        # that this frozen variant does not use — drop it if present.
        state_dict.pop(prefix + 'num_batches_tracked', None)
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # Pre-reshape all per-channel tensors to (1, C, 1, 1) up front so the
        # whole expression stays fuser-friendly.
        eps = 1e-5
        scale = self.weight.reshape(1, -1, 1, 1) * (
            self.running_var.reshape(1, -1, 1, 1) + eps
        ).rsqrt()
        shift = self.bias.reshape(1, -1, 1, 1) - self.running_mean.reshape(1, -1, 1, 1) * scale
        return x * scale + shift
+
+
class BackboneBase(nn.Module):
    """Wraps a torchvision ResNet so calling it yields a dict of feature maps.

    Uses torchvision's IntermediateLayerGetter to expose either only
    ``layer4`` (key "0") or all four residual stages (keys "0".."3") when
    ``return_interm_layers`` is set.
    """

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        # NOTE(review): the original DETR freezing logic is commented out, so
        # `train_backbone` is currently ignored and every backbone parameter
        # keeps its default requires_grad state — confirm this is intentional.
        # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
        #     if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
        #         parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        # Channel count of the last returned feature map (caller-supplied).
        self.num_channels = num_channels

    def forward(self, tensor):
        # Returns {stage_name: feature_map} with plain tensors; the original
        # DETR NestedTensor/mask handling is disabled (kept below for reference).
        xs = self.body(tensor)
        return xs
        # out: Dict[str, NestedTensor] = {}
        # for name, x in xs.items():
        #     m = tensor_list.mask
        #     assert m is not None
        #     mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
        #     out[name] = NestedTensor(x, mask)
        # return out
+
+
class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm (see FrozenBatchNorm2d above)."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        # pretrained=is_main_process(): only the main (rank-0) process triggers
        # the weight download. NOTE(review): with this pattern the other ranks
        # construct randomly initialized weights — confirm DDP broadcasts
        # parameters afterwards.
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
        # ResNet-18/34 end with 512 channels; deeper torchvision ResNets with 2048.
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
+
+# ==== ResNet Backbone ====
# ==== ResNet Backbone ====
class ResNetFilmBackbone(nn.Module):
    """ResNet-18/34 backbone with optional FiLM conditioning on a task embedding.

    For each ResNet layer listed in ``film_config['use_in_layers']``, a linear
    head maps the task embedding to per-block (gamma, beta) FiLM parameters
    which are consumed inside ``resnet_film``'s forward pass.

    film_config keys read here: 'use' (bool), 'use_in_layers' (list of layer
    indices), 'task_embedding_dim' (int).
    """

    def __init__(self, embedding_name: str, pretrained: bool = False,
                 film_config: Optional[Mapping[str, Any]] = None):
        super().__init__()
        self._pretrained = pretrained
        weights = 'IMAGENET1K_V1' if pretrained else None
        if embedding_name in ('resnet34_film', 'resnet34'):
            backbone = resnet34_film(weights=weights, film_config=film_config, pretrained=pretrained)
            embedding_dim = 512
        elif embedding_name in ('resnet18_film', 'resnet18'):
            backbone = resnet18_film(weights=weights, film_config=film_config, pretrained=pretrained)
            embedding_dim = 512
        else:
            raise NotImplementedError

        self.resnet_film_model = backbone
        self._embedding_dim = embedding_dim
        # Strip the classification head: keep spatial feature maps only.
        self.resnet_film_model.fc = nn.Identity()
        self.resnet_film_model.avgpool = nn.Identity()

        self.num_channels = self._embedding_dim

        # FiLM config: one linear generator per conditioned layer, None elsewhere
        # (indices must stay aligned with resnet layers, hence the None padding).
        self.film_config = film_config
        if film_config is not None and film_config['use']:
            film_models = []
            for layer_idx, num_blocks in enumerate(self.resnet_film_model.layers):
                if layer_idx in film_config['use_in_layers']:
                    num_planes = self.resnet_film_model.film_planes[layer_idx]
                    film_model_layer = nn.Linear(
                        film_config['task_embedding_dim'], num_blocks * 2 * num_planes)
                else:
                    film_model_layer = None
                film_models.append(film_model_layer)

            self.film_models = nn.ModuleList(film_models)

    def forward(self, x, texts: Optional[List[str]] = None, task_emb: Optional[torch.Tensor] = None, **kwargs):
        # NOTE(review): `texts` and **kwargs are accepted but never used here.
        film_outputs = None
        if self.film_config is not None and self.film_config['use']:
            # Per-layer FiLM parameters computed from the task embedding
            # (None for layers without a generator).
            film_outputs = []
            for layer_idx, num_blocks in enumerate(self.resnet_film_model.layers):
                if self.film_config['use'] and self.film_models[layer_idx] is not None:
                    film_features = self.film_models[layer_idx](task_emb)
                else:
                    film_features = None
                film_outputs.append(film_features)
        return self.resnet_film_model(x, film_features=film_outputs, flatten=False)

    @property
    def embed_dim(self):
        # Output channel dimension of the backbone (512 for resnet18/34).
        return self._embedding_dim
+
+
+# class Joiner(nn.Sequential):
+# def __init__(self, backbone, position_embedding):
+# super().__init__(backbone, position_embedding)
+
+# def forward(self, tensor_list: NestedTensor, task_emb:NestedTensor):
+# xs = self[0](tensor_list)
+# out: List[NestedTensor] = []
+# pos = []
+# for name, x in xs.items():
+# out.append(x)
+# # position encoding
+# pos.append(self[1](x).to(x.dtype))
+
+# return out, pos
+
class Joiner(nn.Sequential):
    """Pairs a backbone with its positional-encoding module.

    Calling the joiner runs the backbone and then computes a positional
    encoding for each returned feature map.
    """

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor, task_emb: Optional[Any] = None):
        backbone, pos_encoder = self[0], self[1]
        if task_emb is None:
            feature_maps = backbone(tensor_list)
        else:
            # FiLM backbones take the task embedding and return a bare tensor;
            # wrap it in a one-entry dict to mimic IntermediateLayerGetter.
            feature_maps = {'0': backbone(tensor_list, task_emb=task_emb)}
        out: List[NestedTensor] = []
        pos = []
        for feature in feature_maps.values():
            out.append(feature)
            pos.append(pos_encoder(feature).to(feature.dtype))
        return out, pos
+
def build_backbone(args):
    """Construct the standard (non-FiLM) ResNet backbone joined with its
    positional encoding.

    Reads args.backbone, args.lr_backbone, args.masks, args.dilation.
    """
    position_embedding = build_position_encoding(args)
    # NOTE(review): train_backbone is passed through but currently unused —
    # the freezing loop in BackboneBase is commented out.
    train_backbone = args.lr_backbone > 0
    return_interm_layers = args.masks
    backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model
+
def build_film_backbone(args):
    """Construct the FiLM-conditioned ResNet backbone joined with its
    positional encoding.

    The FiLM configuration (conditioned layers 1-3, 512-dim task embedding)
    is hard-coded here rather than taken from args.
    """
    position_embedding = build_position_encoding(args)
    film_config = {
        'use': True,
        'use_in_layers': [1, 2, 3],
        'task_embedding_dim': 512,
        'film_planes': [64, 128, 256, 512],
    }
    backbone = ResNetFilmBackbone(args.backbone, film_config=film_config)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/detr_vae.py b/ACT_DP_multitask/detr/models/detr_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..78fbbe3913bbed57b9a29f9df9b3490a578476c2
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/detr_vae.py
@@ -0,0 +1,3193 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+DETR model and criterion classes.
+"""
+import torch
+from torch import nn
+from torch.autograd import Variable
+from .backbone import build_backbone, build_film_backbone
+from .transformer import *
+from .vision_transformer import (
+ Block,
+ get_2d_sincos_pos_embed,
+ get_2d_sincos_pos_embed_v2,
+)
+from einops import rearrange
+import numpy as np
+
+import IPython
+
+e = IPython.embed
+
+
class FourierFeatureMapping(nn.Module):
    def __init__(self, input_dim, mapping_size, scale=10.0):
        """
        Args:
            input_dim (int): input dimension.
            mapping_size (int): Fourier Features output dimension.
            scale (float): scale factor for frequencies.
        """
        super().__init__()
        # BUGFIX: register the random projection as a buffer so it follows the
        # module across .cuda()/.to() calls — as a plain tensor attribute it
        # stayed on the CPU and the matmul in forward() crashed on GPU.
        # persistent=False keeps it out of the state dict, matching the old
        # behavior (checkpoints saved before this fix still load cleanly).
        self.register_buffer("B", torch.randn((mapping_size, input_dim)) * scale,
                             persistent=False)

    def forward(self, x):
        """
        Args:
            x (Tensor): [batch_size, input_dim]
        Returns:
            Tensor: Fourier Features [batch_size, mapping_size * 2]
        """
        x_proj = 2 * torch.pi * x @ self.B.T  # [batch_size, mapping_size]
        return torch.cat(
            [torch.sin(x_proj), torch.cos(x_proj)], dim=-1
        )  # Concatenate sin and cos
+
+
class MLPWithFourierFeatures(nn.Module):
    """A small MLP that first lifts its input through a Fourier feature map."""

    def __init__(self, input_dim, mapping_size, hidden_dim, output_dim):
        """
        Args:
            input_dim (int): input dimension.
            mapping_size (int): Fourier Features output dimension.
            hidden_dim (int): hidden layer dimension of the MLP.
            output_dim (int): MLP output dimension.
        """
        super().__init__()
        self.fourier_mapping = FourierFeatureMapping(input_dim, mapping_size)
        # The mapping emits sin and cos halves, hence the 2x input width.
        self.mlp = nn.Sequential(
            nn.Linear(mapping_size * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): [batch_size, input_dim]
        Returns:
            Tensor: MLP output, [batch_size, output_dim]
        """
        return self.mlp(self.fourier_mapping(x))
+
+
def get_spatial_temporal_positional_encoding(height, width, seq_len, dim):
    """
    Generate spatial-temporal positional encoding for video data.

    The embedding dimension is split in half: the first ``dim // 2`` channels
    carry the spatial (H/W) encoding, the remainder the temporal one.

    Args:
        height (int): Height of the video frames.
        width (int): Width of the video frames.
        seq_len (int): Number of frames in the video.
        dim (int): Embedding dimension (must be even).

    Returns:
        torch.Tensor: Positional encoding with shape (seq_len, height, width, dim).
    """
    # NOTE(review): relies on `math` being available via the wildcard
    # `from .transformer import *` at module top — confirm.
    assert dim % 2 == 0, "Embedding dimension (dim) must be even."

    spatial_dim = dim // 2
    temporal_dim = dim - spatial_dim

    # Temporal encoding (drop the leading batch dim of the sinusoid table)
    temporal_encoding = get_sinusoid_encoding_table(seq_len, temporal_dim)[
        0
    ]  # Shape: (seq_len, temporal_dim)

    # Spatial encoding
    position_h = torch.arange(height, dtype=torch.float32).unsqueeze(1)  # (height, 1)
    position_w = torch.arange(width, dtype=torch.float32).unsqueeze(1)  # (width, 1)
    div_term_h = torch.exp(
        -torch.arange(0, spatial_dim, 2, dtype=torch.float32)
        * (math.log(10000.0) / spatial_dim)
    ).unsqueeze(0)
    div_term_w = torch.exp(
        -torch.arange(0, spatial_dim, 2, dtype=torch.float32)
        * (math.log(10000.0) / spatial_dim)
    ).unsqueeze(0)

    # NOTE(review): H uses only sin and W only cos (rather than interleaved
    # sin/cos per axis) — presumably intentional; confirm.
    spatial_encoding_h = torch.sin(
        position_h * div_term_h
    )  # (height, spatial_dim // 2)
    spatial_encoding_w = torch.cos(position_w * div_term_w)  # (width, spatial_dim // 2)

    # Combine H and W spatial encodings by broadcasting each across the other axis
    spatial_encoding_h = spatial_encoding_h.unsqueeze(1).expand(
        -1, width, -1
    )  # (height, width, spatial_dim // 2)
    spatial_encoding_w = spatial_encoding_w.unsqueeze(0).expand(
        height, -1, -1
    )  # (height, width, spatial_dim // 2)
    spatial_encoding = torch.cat(
        [spatial_encoding_h, spatial_encoding_w], dim=-1
    )  # (height, width, spatial_dim)
    spatial_encoding = spatial_encoding.unsqueeze(0).repeat(
        seq_len, 1, 1, 1
    )  # (seq_len, height, width, spatial_dim)

    # Broadcast the temporal encoding over the spatial grid
    temporal_encoding = (
        temporal_encoding.unsqueeze(1).unsqueeze(1).repeat(1, height, width, 1)
    )  # (seq_len, height, width, temporal_dim)
    pos_encoding = torch.cat(
        [spatial_encoding, temporal_encoding], dim=-1
    )  # Combine spatial and temporal

    return pos_encoding  # (seq_len, height, width, dim)
+
+
def get_sinusoid_encoding_table(n_position, d_hid):
    """Classic transformer sinusoid table.

    Entry (pos, j) is sin/cos of pos / 10000^(2*(j//2)/d_hid) — sin on even
    channels, cos on odd ones.

    Returns:
        torch.FloatTensor of shape (1, n_position, d_hid).
    """
    positions = np.arange(n_position, dtype=np.float64)[:, None]  # (N, 1)
    dim_idx = np.arange(d_hid)[None, :]                           # (1, D)
    table = positions / np.power(10000, 2 * (dim_idx // 2) / d_hid)
    table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
    table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1
    return torch.FloatTensor(table).unsqueeze(0)
+
+
def reparametrize(mu, logvar):
    """VAE reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).

    Modernized: the original used the long-deprecated
    ``torch.autograd.Variable`` / ``std.data.new(...).normal_()`` idiom;
    ``torch.randn_like`` draws the same-shaped standard-normal noise on the
    correct device/dtype. Gradients flow through ``mu`` and ``logvar`` only.

    Args:
        mu: latent mean, any shape.
        logvar: log-variance, same shape as ``mu``.

    Returns:
        Sampled latent tensor with the shape of ``mu``.
    """
    std = torch.exp(0.5 * logvar)  # sigma = exp(logvar / 2)
    eps = torch.randn_like(std)
    return mu + std * eps
+
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: a list of positions to be encoded: size (M,)
    out: (M, D) — first D/2 columns are sines, last D/2 are cosines
    """
    assert embed_dim % 2 == 0
    # Frequencies: 1 / 10000^(k / (D/2)) for k = 0..D/2-1
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**omega  # (D/2,)

    if not isinstance(pos, np.ndarray):
        pos = np.array(pos, dtype=np.float64)
    angles = np.outer(pos.reshape(-1), omega)  # (M, D/2)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
+
+
def get_nd_sincos_pos_embed_from_grid(embed_dim, grid_sizes):  # useful
    """
    embed_dim: output dimension for each position
    grid_sizes: the grids sizes in each dimension (K,). for example [T, C, P] n_his, n_view, n_patch
        NOTE(review): must be a *tuple* — it is concatenated with (embed_dim,) below.
    out: (grid_sizes[0], ..., grid_sizes[K-1], D) for example (T, C, P, D)

    The embedding dimension is split evenly among the axes whose size is > 1;
    axes of size 1 receive no positional signal. Any leftover channels (from
    rounding dim_for_each_grid down to an even number) stay zero.
    """
    num_sizes = len(grid_sizes)
    # For grid size of 1, we do not need to add any positional embedding
    num_valid_sizes = len([x for x in grid_sizes if x > 1])
    emb = np.zeros(grid_sizes + (embed_dim,))
    # Uniformly divide the embedding dimension for each grid size
    dim_for_each_grid = embed_dim // num_valid_sizes
    # To make it even (the 1-D sincos helper requires an even dim)
    if dim_for_each_grid % 2 != 0:
        dim_for_each_grid -= 1
    valid_size_idx = 0
    for size_idx in range(num_sizes):
        grid_size = grid_sizes[size_idx]
        if grid_size <= 1:
            continue
        pos = np.arange(grid_size)
        # Broadcast the 1-D embedding along every axis except this one.
        posemb_shape = [1] * len(grid_sizes) + [dim_for_each_grid]
        posemb_shape[size_idx] = -1
        emb[
            ...,
            valid_size_idx
            * dim_for_each_grid : (valid_size_idx + 1)
            * dim_for_each_grid,
        ] += get_1d_sincos_pos_embed_from_grid(dim_for_each_grid, pos).reshape(
            posemb_shape
        )
        valid_size_idx += 1
    return emb
+
+
+class DETRVAE(nn.Module):
+ """This is the DETR module that performs object detection"""
+
    def __init__(
        self, backbones, transformer, encoder, state_dim, num_queries, camera_names
    ):
        """Initializes the model.
        Parameters:
            backbones: torch modules of the image backbones (one per camera),
                or None to use raw robot/env state inputs. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            encoder: transformer encoder used as the CVAE style encoder over
                (CLS, qpos, action sequence)
            state_dim: robot state dimension of the environment
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
            DETR can detect in a single image. For COCO, we recommend 100 queries.
            camera_names: names of the cameras; images arrive in this order
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        hidden_dim = transformer.d_model
        # Decoder heads: per-query action and "is padding" logit
        self.action_head = nn.Linear(hidden_dim, state_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        if backbones is not None:
            # 1x1 conv projects backbone channels down to the transformer width
            self.input_proj = nn.Conv2d(
                backbones[0].num_channels, hidden_dim, kernel_size=1
            )
            self.backbones = nn.ModuleList(backbones)
            # NOTE(review): 14 (qpos dim) is hard-coded here rather than
            # derived from state_dim — confirm for non-14-dof robots.
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32  # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim)  # extra cls token embedding
        self.encoder_action_proj = nn.Linear(
            14, hidden_dim
        )  # project action to embedding
        self.encoder_joint_proj = nn.Linear(14, hidden_dim)  # project qpos to embedding
        self.latent_proj = nn.Linear(
            hidden_dim, self.latent_dim * 2
        )  # project hidden state to latent mean and log-variance
        self.register_buffer(
            "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
        )  # [CLS], qpos, a_seq
        # decoder extra parameters
        self.latent_out_proj = nn.Linear(
            self.latent_dim, hidden_dim
        )  # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(
            2, hidden_dim
        )  # learned position embedding for proprio and latent
+
+ def forward(self, qpos, image, env_state, actions=None, is_pad=None):
+ """
+ qpos: batch, qpos_dim
+ image: batch, num_cam, channel, height, width
+ env_state: None
+ actions: batch, seq, action_dim
+ """
+ is_training = actions is not None # train or val
+ bs, _ = qpos.shape
+ ### Obtain latent z from action sequence
+ if is_training:
+ # project action sequence to embedding dim, and concat with a CLS token
+ action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
+ qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
+ qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
+ cls_embed = self.cls_embed.weight # (1, hidden_dim)
+ cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
+ bs, 1, 1
+ ) # (bs, 1, hidden_dim)
+ encoder_input = torch.cat(
+ [cls_embed, qpos_embed, action_embed], axis=1
+ ) # (bs, seq+1, hidden_dim)
+ encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
+ # do not mask cls token
+ cls_joint_is_pad = torch.full((bs, 2), False).to(
+ qpos.device
+ ) # False: not a padding
+ is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
+ # obtain position embedding
+ pos_embed = self.pos_table.clone().detach()
+ pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
+ # query model
+ encoder_output = self.encoder(
+ encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
+ )
+ encoder_output = encoder_output[0] # take cls output only
+ latent_info = self.latent_proj(encoder_output)
+ mu = latent_info[:, : self.latent_dim]
+ logvar = latent_info[:, self.latent_dim :]
+ latent_sample = reparametrize(mu, logvar)
+ latent_input = self.latent_out_proj(latent_sample)
+ else:
+ mu = logvar = None
+ latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
+ qpos.device
+ )
+ latent_input = self.latent_out_proj(latent_sample)
+
+ if self.backbones is not None:
+ # Image observation features and position embeddings
+ all_cam_features = []
+ all_cam_pos = []
+ for cam_id, cam_name in enumerate(self.camera_names):
+ features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
+ features = features[0] # take the last layer feature
+ pos = pos[0]
+ all_cam_features.append(self.input_proj(features))
+ all_cam_pos.append(pos)
+ # proprioception features
+ proprio_input = self.input_proj_robot_state(qpos)
+ # fold camera dimension into width dimension
+ src = torch.cat(all_cam_features, axis=3)
+ pos = torch.cat(all_cam_pos, axis=3)
+ hs = self.transformer(
+ src,
+ None,
+ self.query_embed.weight,
+ pos,
+ latent_input,
+ proprio_input,
+ self.additional_pos_embed.weight,
+ )[0]
+ else:
+ qpos = self.input_proj_robot_state(qpos)
+ env_state = self.input_proj_env_state(env_state)
+ transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
+ hs = self.transformer(
+ transformer_input, None, self.query_embed.weight, self.pos.weight
+ )[0]
+ a_hat = self.action_head(hs)
+ is_pad_hat = self.is_pad_head(hs)
+ return a_hat, is_pad_hat, [mu, logvar]
+
+
+class DETRVAE_Denoise(nn.Module):
+    """DETR-VAE variant whose transformer denoises a noisy action sequence
+    (diffusion-style): ``forward`` receives noisy actions plus denoise steps
+    and the head predicts the action/noise target. Supports frame stacking
+    (``history_step``), optional FiLM conditioning and a language task
+    embedding for multi-task training.
+    """
+
+    def __init__(
+        self,
+        backbones,
+        transformer,
+        encoder,
+        state_dim,
+        num_queries,
+        camera_names,
+        history_step,
+        disable_vae_latent,
+        is_multi_task=False,
+        use_film=False
+    ):
+        """Initializes the model.
+        Parameters:
+            backbones: torch module of the backbone to be used. See backbone.py
+            transformer: torch module of the transformer architecture. See transformer.py
+            encoder: transformer encoder used as the (optional) CVAE encoder
+            state_dim: robot state dimension of the environment
+            num_queries: action-chunk length (number of decoder queries)
+            camera_names: list of camera identifiers
+            history_step: number of extra past frames stacked with the current one
+            disable_vae_latent: if True, always use a zero latent (no CVAE)
+            is_multi_task: if True, reserve one extra positional slot for the
+                language/task token
+            use_film: if True, condition the (shared) backbone on task_emb via FiLM
+        """
+        super().__init__()
+        self.num_queries = num_queries
+        self.camera_names = camera_names
+        self.transformer = transformer
+        self.encoder = encoder
+        self.disable_vae_latent = disable_vae_latent
+        hidden_dim = transformer.d_model  # 512
+        self.hidden_dim = hidden_dim
+        # self.action_head = nn.Linear(hidden_dim, state_dim)
+        self.action_head = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.SiLU(),
+            nn.Linear(hidden_dim, state_dim),  # TODO add layers
+        )
+        self.is_pad_head = nn.Linear(hidden_dim, 1)
+        self.query_embed = nn.Embedding(num_queries, hidden_dim)
+
+        self.multi_task = is_multi_task
+        self.use_film = use_film
+
+        print('use_film', use_film)
+        print('is_multi_task', is_multi_task)
+
+        if backbones is not None:
+            # Channels are multiplied because history frames are stacked along
+            # the channel axis before this projection (see forward()).
+            self.input_proj = nn.Conv2d(
+                backbones[0].num_channels * (history_step + 1) , hidden_dim, kernel_size=1
+            )
+            self.backbones = nn.ModuleList(backbones)
+            self.input_proj_robot_state = nn.Linear(14, hidden_dim)  # proprioception
+        else:
+            # input_dim = 14 + 7 # robot_state + env_state
+            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
+            self.input_proj_env_state = nn.Linear(7, hidden_dim)
+            self.pos = torch.nn.Embedding(2, hidden_dim)
+            self.backbones = None
+
+        # encoder extra parameters (CVAE encoder side)
+        self.latent_dim = 32  # final size of latent z # TODO tune
+        self.cls_embed = nn.Embedding(1, hidden_dim)  # extra cls token embedding
+        self.encoder_action_proj = nn.Linear(
+            14, hidden_dim
+        )  # project action to embedding
+        self.encoder_joint_proj = nn.Linear(14, hidden_dim)  # project qpos to embedding
+        self.latent_proj = nn.Linear(
+            hidden_dim, self.latent_dim * 2
+        )  # project hidden state to latent mean and log-variance
+        self.register_buffer(
+            "pos_table", get_sinusoid_encoding_table(1 + 1 + history_step + num_queries, hidden_dim)
+        )  # [CLS], qpos (1 + history frames), a_seq
+        self.history_step = history_step
+        # decoder extra parameters: vae latent to decoder token space
+        self.latent_out_proj = nn.Linear(
+            self.latent_dim, hidden_dim
+        )  # project latent sample to embedding
+        if self.multi_task:  # include proprio and latent and language token
+            self.additional_pos_embed = nn.Embedding(
+                2 + history_step + 1, hidden_dim
+            )
+        else:
+            self.additional_pos_embed = nn.Embedding(
+                2 + history_step, hidden_dim
+            )  # learned position embedding for proprio and latent
+
+        self.proj_text_emb = nn.Linear(4096, hidden_dim)  # TODO hard code text embedding dim
+    # src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None, noisy_actions = None, denoise_steps=None
+    def forward(
+        self,
+        qpos,
+        image,
+        env_state,
+        actions=None,
+        is_pad=None,
+        denoise_steps=None,
+        is_training=True,
+        task_emb=None
+    ):
+        """
+        qpos: batch, 1+history, qpos_dim
+        image: batch, 1+history, num_cam, channel, height, width
+        env_state: None
+        actions: batch, seq, action_dim (noisy action to denoise)
+        denoise_steps: the diffusion step(s) of the denoise process
+        task_emb: optional language/task embedding (dim 4096 before projection)
+        Returns (a_hat, is_pad_hat, [mu, logvar]).
+        """
+        # is_training = actions is not None # train or val
+        bs = qpos.shape[0]
+        ### Obtain latent z from action sequence (only when the VAE is enabled)
+        if is_training and self.disable_vae_latent == False:
+            # project action sequence to embedding dim, and concat with a CLS token
+            action_embed = self.encoder_action_proj(actions)  # (bs, seq, hidden_dim)
+            qpos_embed = self.encoder_joint_proj(qpos)  # (bs, 1+history, hidden_dim)
+            qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, 1+history, hidden_dim)
+            cls_embed = self.cls_embed.weight  # (1, hidden_dim)
+            cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
+                bs, 1, 1
+            )  # (bs, 1, hidden_dim)
+            # vae encoder; qpos_embed[:,0] undoes the unsqueeze above,
+            # contributing 1+history qpos tokens
+            # print(cls_embed.shape, qpos_embed.shape, action_embed.shape)
+            encoder_input = torch.cat(
+                [cls_embed, qpos_embed[:,0], action_embed], axis=1
+            )  # (bs, 1 + (1+history) + seq, hidden_dim)
+            encoder_input = encoder_input.permute(1, 0, 2)  # (seq+2+history, bs, hidden_dim)
+            # do not mask cls token
+            # NOTE(review): this unmasked prefix is 2 wide, but the prefix of
+            # encoder_input is 2 + history tokens — looks mismatched when
+            # history_step > 0; confirm against the encoder's mask handling.
+            cls_joint_is_pad = torch.full((bs, 2), False).to(
+                qpos.device
+            )  # False: not a padding
+            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+2)
+            # obtain position embedding (detached: pos_table is a fixed buffer)
+            pos_embed = self.pos_table.clone().detach()
+            pos_embed = pos_embed.permute(1, 0, 2)  # (seq+..., 1, hidden_dim)
+            # query model
+            encoder_output = self.encoder(
+                encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
+            )
+            encoder_output = encoder_output[0]  # take cls output only
+            latent_info = self.latent_proj(encoder_output)
+            mu = latent_info[:, : self.latent_dim]
+            logvar = latent_info[:, self.latent_dim :]
+            latent_sample = reparametrize(mu, logvar)
+            latent_input = self.latent_out_proj(
+                latent_sample
+            )  # get latent input for decoder TODO do we need this?
+            task_emb = self.proj_text_emb(task_emb) if task_emb is not None else None
+        else:  # if latent is disabled, just set it to zero (training or val)
+            mu = logvar = None
+            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
+                qpos.device
+            )
+            # latent_sample = torch.randn([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
+            latent_input = self.latent_out_proj(latent_sample)
+            task_emb = self.proj_text_emb(task_emb) if task_emb is not None else None
+
+        if self.backbones is not None:
+            # proprioception features
+            proprio_input = self.input_proj_robot_state(qpos)
+            # Image observation features and position embeddings B T N_view C H W
+            image = image.view( -1, *image.shape[2:])  # frame stack shape: batch*T', num_cam, 3, height, width
+            all_cam_features = []  # no multiview
+            for cam_id, cam_name in enumerate(self.camera_names):
+                if self.use_film:
+                    # FiLM path shares backbones[0] across cameras.
+                    features, pos = self.backbones[0](image[:, cam_id], task_emb=task_emb)  # HARDCODED
+                else:
+                    features, pos = self.backbones[cam_id](
+                        image[:, cam_id]
+                    )  # shape: batch*T', C, H, W
+                features = features[0]  # take the last layer feature
+                features = features.view(
+                    bs, -1, *features.shape[-2:]
+                )  # shape: batch, T'*C, H, W (history folded into channels)
+                all_cam_features.append(self.input_proj(features))
+            src = torch.stack(all_cam_features, axis=-3)  # shape: batch, D, N_view, H, W
+            pos = get_nd_sincos_pos_embed_from_grid(
+                self.hidden_dim, src.shape[2:]
+            )  # N_view, H, W, D numpy
+            pos = torch.from_numpy(pos).to(src.device).unsqueeze(0).float()
+            # 1, N_view, H, W, D
+            src = rearrange(src,"b d n_view h w -> b d h (w n_view)",)  #
+            pos = rearrange(pos,"b n_view h w d -> b d h (w n_view)",)  # will be b d h*w*n_view in the following transformer
+
+            # print('proprio_input shape', proprio_input.shape)
+            hs = self.transformer(
+                src,
+                None,
+                self.query_embed.weight,
+                pos,
+                latent_input,
+                proprio_input,
+                self.additional_pos_embed.weight,
+                actions,
+                denoise_steps,
+                task_emb=task_emb,
+            )[0]
+        else:
+            qpos = self.input_proj_robot_state(qpos)
+            env_state = self.input_proj_env_state(env_state)
+            transformer_input = torch.cat([qpos, env_state], axis=1)  # seq length = 2
+            hs = self.transformer(
+                transformer_input,
+                None,
+                self.query_embed.weight,
+                self.pos.weight,
+                denoise_steps,
+            )[0]
+        a_hat = self.action_head(hs)  # predict action or noise (module output)
+        is_pad_hat = self.is_pad_head(hs)
+        return a_hat, is_pad_hat, [mu, logvar]
+
+
+class DETRVAE_Denoise_Token_Prediction(nn.Module):
+    """Denoising DETR-VAE that jointly predicts an action chunk AND future
+    visual tokens: alongside noisy actions, the transformer denoises noisy
+    image tokens (from a pretrained tokenizer), and ``token_head`` maps its
+    output back to patchified token space.
+    """
+
+    def __init__(
+        self,
+        backbones,
+        transformer,
+        encoder,
+        state_dim,
+        num_queries,
+        camera_names,
+        history_step,
+        predict_frame,
+        image_downsample_rate,
+        temporal_downsample_rate,
+        disable_vae_latent,
+        disable_resnet,
+        patch_size=5,
+        token_pe_type="learned",
+        image_height=480,
+        image_width=640,
+        tokenizer_model_temporal_rate=8,
+        tokenizer_model_spatial_rate=16,
+        resize_rate=1,
+    ):
+        """Initializes the model.
+        Parameters:
+            backbones: torch module of the backbone to be used. See backbone.py
+            transformer: torch module of the transformer architecture. See transformer.py
+            encoder: transformer encoder used as the (optional) CVAE encoder
+            state_dim: robot state dimension of the environment
+            num_queries: action-chunk length (number of decoder queries)
+            camera_names: list of camera identifiers; only camera_names[0] is
+                used for future-token prediction
+            history_step: number of extra past frames stacked with the current one
+            predict_frame: number of future frames whose tokens are predicted
+            image_downsample_rate / temporal_downsample_rate: spatial/temporal
+                subsampling applied before tokenization
+            disable_vae_latent: if True, always use a zero latent
+            disable_resnet: if True, consume pre-tokenized images instead of
+                running a resnet backbone
+            patch_size: side length of the square token patches
+            token_pe_type: "learned" or fixed sin-cos PE for predicted tokens
+            image_height/image_width: raw camera resolution
+            tokenizer_model_temporal_rate / tokenizer_model_spatial_rate:
+                compression rates of the pretrained tokenizer
+            resize_rate: extra spatial resize factor applied before tokenizing
+        """
+        super().__init__()
+        self.num_queries = num_queries
+        self.camera_names = camera_names
+        self.future_camera_names = [camera_names[0]]  # TODO attention
+        self.transformer = transformer
+        self.transformer_share_decoder = transformer.share_decoder
+        self.predict_only_last = transformer.predict_only_last
+        self.encoder = encoder
+        self.disable_vae_latent = disable_vae_latent
+        self.disable_resnet = disable_resnet
+        hidden_dim = transformer.d_model  # 512
+        self.hidden_dim = hidden_dim
+        token_dim = transformer.token_dim
+        # Action head; state_dim = action_dim
+        # self.action_head = nn.Linear(hidden_dim, state_dim) # replace MLP?
+        self.action_head = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.SiLU(),
+            nn.Linear(hidden_dim, state_dim),
+        )
+
+        # Maps decoder output back to a flattened token patch.
+        self.token_head = nn.Sequential(
+            nn.Linear(
+                hidden_dim, hidden_dim
+            ),  # Hardcode patch size * patch size * patch dim
+            nn.SiLU(),
+            nn.Linear(hidden_dim, token_dim * patch_size * patch_size),
+        )
+
+        self.is_pad_head = nn.Linear(hidden_dim, 1)
+        self.query_embed = nn.Embedding(num_queries, hidden_dim)  # for decoder as PE
+        self.diffusion_timestep_type = transformer.diffusion_timestep_type
+        if backbones is not None and disable_resnet == False:
+            self.input_proj = nn.Conv2d(
+                backbones[0].num_channels * (history_step + 1),
+                hidden_dim,
+                kernel_size=1,
+            )  # = MLP: c h w -> c' h' w', history frames stacked in channels
+            self.backbones = nn.ModuleList(backbones)  # N encoders for N views
+            self.input_proj_robot_state = nn.Linear(14, hidden_dim)  # proprioception
+        elif backbones is not None and disable_resnet == True:
+            self.pos_encoder = backbones[0][1]  # for 2D PositionEmbedding
+            self.backbones = None  # HardCode
+            # NOTE(review): forward() below indexes self.backbones, which is
+            # None on this path — confirm this configuration is ever exercised.
+            # Hardcode: divide the latent feature representation into non-overlapping patches
+            self.input_proj_token = nn.Conv2d(
+                token_dim, hidden_dim, kernel_size=5, stride=5, bias=False
+            )  # Hardcode deal with image token patches
+            self.input_proj_robot_state = nn.Linear(14, hidden_dim)  # proprioception
+        else:
+            # input_dim = 14 + 7 # robot_state + env_state
+            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
+            self.input_proj_env_state = nn.Linear(7, hidden_dim)
+            self.pos = torch.nn.Embedding(2, hidden_dim)
+            self.backbones = None
+
+        # encoder extra parameters (CVAE encoder side)
+        self.latent_dim = 32  # final size of latent z # TODO tune
+        self.cls_embed = nn.Embedding(1, hidden_dim)  # extra cls token embedding
+        self.encoder_action_proj = nn.Linear(
+            14, hidden_dim
+        )  # project action to embedding TODO action dim = 14/16
+        self.encoder_joint_proj = nn.Linear(14, hidden_dim)  # project qpos to embedding
+        self.latent_proj = nn.Linear(
+            hidden_dim, self.latent_dim * 2
+        )  # project hidden state to latent mean and log-variance
+        self.register_buffer(
+            "pos_table",
+            get_sinusoid_encoding_table(1 + 1 + history_step + num_queries, hidden_dim),
+        )  # [CLS], qpos (1 + history frames), a_seq
+
+        # decoder extra parameters: vae latent to decoder token space
+        self.history_step = history_step
+        self.latent_out_proj = nn.Linear(
+            self.latent_dim, hidden_dim
+        )  # project latent sample to embedding
+        self.additional_pos_embed = nn.Embedding(
+            2 + history_step, hidden_dim
+        )  # learned position embedding for proprio and latent
+        self.denoise_step_pos_embed = nn.Embedding(
+            1, hidden_dim
+        )  # learned position embedding for denoise step
+        # settings for token prediction
+        if self.predict_only_last:
+            self.num_temporal_token = 1
+        else:
+            # Align the temporal token count with the tokenizer's output.
+            # print('predict frame', predict_frame, 'temporal_downsample_rate', temporal_downsample_rate, 'tokenizer_model_temporal_rate', tokenizer_model_temporal_rate)
+            self.num_temporal_token = math.ceil(
+                predict_frame
+                // temporal_downsample_rate
+                / tokenizer_model_temporal_rate
+            )
+        self.patch_size = patch_size
+        # Token-map height/width after tokenizer compression and patching.
+        self.image_h = (
+            image_height
+            // image_downsample_rate
+            // tokenizer_model_spatial_rate
+            // resize_rate
+            // self.patch_size
+        )  #
+        self.image_w = (
+            image_width
+            // image_downsample_rate
+            // tokenizer_model_spatial_rate
+            // resize_rate
+            // self.patch_size
+        )  #
+        self.num_pred_token_per_timestep = (
+            self.image_h * self.image_w * len(self.future_camera_names)
+        )  #
+        self.token_shape = (
+            self.num_temporal_token,
+            self.image_h,
+            self.image_w,
+            self.patch_size,
+        )
+        if token_pe_type == "learned":
+            self.query_embed_token = nn.Embedding(
+                self.num_temporal_token * self.num_pred_token_per_timestep, hidden_dim
+            )  # for decoder as PE TODO replace with temporal spatial PE
+        print(
+            "predict token shape",  # B T' N_view D*P*P H' W'
+            (
+                self.num_pred_token_per_timestep,  # H' W' N_view
+                self.num_temporal_token,  # T'
+                self.image_h,
+                self.image_w,
+                self.patch_size,
+            ),
+        )
+        # Fixed sin-cos PE alternative; kept as a plain tensor and moved to
+        # the right device on use (see forward()).
+        query_embed_token_fixed = get_nd_sincos_pos_embed_from_grid(
+            hidden_dim,
+            (self.num_temporal_token, len(self.future_camera_names), self.image_h, self.image_w),
+        )
+        self.query_embed_token_fixed = (
+            torch.from_numpy(query_embed_token_fixed).view(-1, hidden_dim).float()
+        )
+        self.token_pe_type = token_pe_type
+
+    def forward(
+        self,
+        qpos,
+        current_image,
+        env_state,
+        actions,
+        is_pad,
+        noisy_actions,
+        noise_tokens,
+        is_tokens_pad,
+        denoise_steps,
+    ):
+        # qpos: batch, 1+history, qpos_dim
+        # current_image: 1. batch, T', num_cam, 3, height, width - image_norm, T' = 1+history
+        #                2. batch, T', num_cam*6/16, height', width' - image_token
+        # env_state: None
+        # actions: batch, seq, action_dim - clean action for vae encoder
+        # is_pad: batch, seq - for vae encoder
+        # noisy_actions: batch, seq, action_dim - noisy action for denoise
+        # noise_tokens: batch, seq, num_cam, 6/16, height', width' - noise token for denoise
+        # is_tokens_pad: batch, seq
+        # denoise_steps: int, the step of denoise (per batch element)
+        # Returns (a_hat, is_pad_hat, pred_token, [mu, logvar]); mu/logvar are
+        # always None here (the VAE branch is not used in this class).
+        is_training = actions is not None  # train or val
+        bs = qpos.shape[0]
+        is_actions_pad = is_pad
+
+        ### Latent z is fixed to zero in this variant (no CVAE encoding).
+        mu = logvar = None
+        latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
+            qpos.device
+        )
+        latent_input = self.latent_out_proj(
+            latent_sample
+        )  # TODO maybe add tactile input
+
+        # proprioception features
+        proprio_input = self.input_proj_robot_state(qpos)
+        # Image observation features and position embeddings
+        image = current_image[
+            0
+        ]  # shape: batch, T', num_cam, 3, height, width (image_norm)
+        image = image.view(
+            -1, *image.shape[2:]
+        )  # frame stack shape: batch*T', num_cam, 3, height, width
+        all_cam_features = []
+        for cam_id, cam_name in enumerate(self.camera_names):
+            features, pos = self.backbones[cam_id](
+                image[:, cam_id]
+            )  # shape: batch*T', C, H, W
+            features = features[0]  # take the last layer feature
+            features = features.view(
+                bs, -1, *features.shape[-2:]
+            )  # shape: batch, T'*C, H, W (history folded into channels)
+            all_cam_features.append(self.input_proj(features))
+        src = torch.stack(all_cam_features, axis=-3)  # shape: batch, D, N_view, H, W
+        pos = get_nd_sincos_pos_embed_from_grid(
+            self.hidden_dim, src.shape[2:]
+        )  # N_view, H, W, D numpy
+        pos = (
+            torch.from_numpy(pos).to(src.device).unsqueeze(0).float()
+        )  # 1, N_view, H, W, D
+        src = rearrange(
+            src,
+            "b d n_view h w -> b d h (w n_view)",
+        )
+        pos = rearrange(
+            pos,
+            "b n_view h w d -> b d h (w n_view)",
+        )  # will be b d h*w*n_view in the following transformer
+
+        # Patchify noise tokens: fold patch_size x patch_size blocks into channels.
+        noise_tokens = (
+            rearrange(
+                noise_tokens,
+                "b s n d (ph p1) (pw p2) -> b s n (d p1 p2) ph pw",
+                p1=self.patch_size,
+                p2=self.patch_size,
+            )
+            if noise_tokens is not None
+            else None
+        )
+
+        if is_tokens_pad is not None:  # B T -> B T' N_view*H'*W' -> B T'*N_view*H'*W'
+            is_tokens_pad = (
+                is_tokens_pad.unsqueeze(2)
+                .repeat(1, 1, self.num_pred_token_per_timestep)
+                .reshape(bs, -1)
+            )
+        else:
+            is_tokens_pad = torch.zeros(
+                bs,
+                self.num_pred_token_per_timestep * self.num_temporal_token,
+                dtype=torch.bool,
+            ).to(qpos.device)
+        # Action queries are never padded here.
+        is_pad = torch.zeros(bs, self.num_queries, dtype=torch.bool).to(qpos.device)
+
+        if self.token_pe_type == "learned":
+            query_embed_token = self.query_embed_token.weight
+        else:
+            query_embed_token = self.query_embed_token_fixed.to(qpos.device)
+        # print('detr query_embed_token', query_embed_token.shape)
+        hs_action, hs_token = self.transformer(
+            src,  # observed image features
+            None,
+            self.query_embed.weight,  # action token PE
+            query_embed_token,  # future token PE
+            pos,  # observed image feature PE
+            latent_input,
+            proprio_input,
+            self.additional_pos_embed.weight,  # latent & proprio token PE
+            noisy_actions,
+            noise_tokens,
+            denoise_steps,
+            self.denoise_step_pos_embed.weight,  # denoise step token PE
+            is_actions_pad,
+            is_tokens_pad,
+        )
+
+        a_hat = self.action_head(hs_action)
+        is_pad_a_hat = self.is_pad_head(hs_action)
+        pred_token = (
+            self.token_head(hs_token) if hs_token is not None else None
+        )  # B T'*N*H'*W' token_dim*self.patch_size*self.patch_size
+        is_pad_token_hat = self.is_pad_head(hs_token) if hs_token is not None else None
+        is_pad_hat = (
+            torch.cat([is_pad_a_hat, is_pad_token_hat], axis=1)
+            if is_pad_token_hat is not None
+            else is_pad_a_hat
+        )
+
+        # Un-patchify: restore B T' N_view C H' W' token maps.
+        pred_token = (
+            rearrange(
+                pred_token,
+                "b (t n hp wp) (c ph pw) -> b t n c (hp ph) (wp pw)",
+                t=self.num_temporal_token,
+                n=len(self.future_camera_names),
+                hp=self.image_h,
+                wp=self.image_w,
+                ph=self.patch_size,
+                pw=self.patch_size,
+            )
+            if pred_token is not None
+            else None
+        )
+
+        return a_hat, is_pad_hat, pred_token, [mu, logvar]
+
+
+class DETRVAE_Denoise_Tactile(nn.Module):
+    """Denoising DETR policy that fuses camera features with tactile-sensor
+    features: both streams get N-D sin-cos positional embeddings and are fed
+    to the transformer, which denoises the noisy action sequence.
+    """
+
+    def __init__(self,
+        backbones,
+        transformer,
+        tactile_encoder,
+        state_dim,
+        num_queries,
+        camera_names,):
+        # backbones: one visual encoder per camera (see backbone.py)
+        # transformer: denoising transformer; consumes image + tactile streams
+        # tactile_encoder: module mapping raw tactile images to feature maps
+        # state_dim: action dimension predicted per query
+        # num_queries: action-chunk length (number of decoder queries)
+        # camera_names: list of camera identifiers
+        super().__init__()
+        self.num_queries = num_queries
+        self.camera_names = camera_names
+        hidden_dim = transformer.d_model  # 512
+        self.hidden_dim = hidden_dim
+        self.tactile_dim = tactile_encoder.tactile_dim
+
+        # tokenize input observation
+        self.input_proj_robot_state = nn.Linear(14, hidden_dim)  # proprioception
+        self.input_proj = nn.Conv2d(
+            backbones[0].num_channels,
+            hidden_dim,
+            kernel_size=1,
+        )  # = MLP: c h w -> c' h' w'
+        self.backbones = nn.ModuleList(backbones)  # N encoders for N views
+        self.tactile_encoder = tactile_encoder
+        self.transformer = transformer
+
+        # output heads
+        self.action_head = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.SiLU(),
+            nn.Linear(hidden_dim, state_dim),
+        )
+        self.is_pad_head = nn.Linear(hidden_dim, 1)
+        self.query_embed = nn.Embedding(num_queries, hidden_dim)  # for decoder as PE
+        self.additional_pos_embed = nn.Embedding(
+            1, hidden_dim
+        )  # learned position embedding for proprio
+
+    def forward(
+        self,
+        qpos,
+        image,
+        tactile,
+        env_state,
+        actions=None,
+        is_pad=None,
+        denoise_steps=None,
+        is_training=True,
+    ):
+        """
+        qpos: batch, 1, qpos_dim
+        image: batch, 1, num_cam, channel, height, width
+        tactile: batch, 1, 4, 3, 960, 960
+        env_state: None
+        actions: batch, seq, action_dim (noisy action to denoise)
+        denoise_steps: the diffusion step(s) of the denoise process
+        Returns (a_hat, is_pad_hat); no VAE latent in this variant.
+        """
+        bs = qpos.shape[0]
+
+        # proprioception features
+        proprio_input = self.input_proj_robot_state(qpos)  # B 1 D
+        # Image observation features and position embeddings
+        image = image.view( -1, *image.shape[2:])  # frame stack shape: batch*T', num_cam, 3, height, width
+        all_cam_features = []
+        for cam_id, cam_name in enumerate(self.camera_names):
+            features, pos = self.backbones[cam_id](
+                image[:, cam_id]
+            )  # shape: batch*T', C, H, W
+            features = features[0]  # take the last layer feature
+            features = features.view(
+                bs, -1, *features.shape[-2:]
+            )  # shape: batch, T'*C, H, W
+            all_cam_features.append(self.input_proj(features))
+        src = torch.stack(all_cam_features, axis=-3)  # shape: batch, D, N_view, H, W
+        pos = get_nd_sincos_pos_embed_from_grid(
+            self.hidden_dim, src.shape[2:]
+        )
+        pos = torch.from_numpy(pos).to(src.device).unsqueeze(0).float()
+        src = rearrange(src,"b d n_view h w -> b d h (w n_view)",)
+        pos = rearrange(pos,"b n_view h w d -> b d h (w n_view)",)  # will be b d h*w*n_view in the following transformer
+        # tactile features, same fold-views-into-width treatment as images
+        tactile_input = self.tactile_encoder(tactile)  # batch, d, 4, h, w
+        # print('detr tactile_input', tactile_input.shape)
+        tactile_pos = get_nd_sincos_pos_embed_from_grid(
+            self.hidden_dim, tactile_input.shape[2:]
+        )  # 4,H,W,d -> 1,4,h,w,d
+        tactile_pos = torch.from_numpy(tactile_pos).to(src.device).unsqueeze(0).float()
+        tactile_input = rearrange(tactile_input,"b d n_view h w -> b d h (w n_view)",)
+        tactile_pos = rearrange(tactile_pos,"b n_view h w d -> b d h (w n_view)",)  # will be b d h*w*n_view in the following transformer
+        # tactile_pos = tactile_pos.permute(0, 3, 1, 2) # batch, d, 4, h, w
+        hs = self.transformer(
+            src,  # B D h w
+            None,
+            self.query_embed.weight,  # H D
+            pos,  # B D h w
+            tactile_input,  # B D h 4*w
+            tactile_pos,  # 1 D h 4*w
+            proprio_input,  # B 1 D
+            self.additional_pos_embed.weight,  # D
+            actions,  # B H D
+            denoise_steps,  # B
+        )[0]
+        a_hat = self.action_head(hs)  # predict action or noise (module output)
+        is_pad_hat = self.is_pad_head(hs)
+        return a_hat, is_pad_hat
+
+class DETRVAE_Denoise_Token_Prediction_Dual_Visual_Token(nn.Module):
+ """This is the DETR module that performs object detection"""
+
+ def __init__(
+ self,
+ backbones,
+ transformer,
+ encoder,
+ state_dim,
+ num_queries,
+ camera_names,
+ history_step,
+ predict_frame,
+ image_downsample_rate,
+ temporal_downsample_rate,
+ disable_vae_latent,
+ disable_resnet,
+ patch_size=5,
+ token_pe_type="learned",
+ image_height=480,
+ image_width=640,
+ tokenizer_model_temporal_rate=8,
+ tokenizer_model_spatial_rate=16,
+ ):
+ """Initializes the model.
+ both resnet feature and token feature are used for token prediction & action prediction
+ Parameters:
+ backbones: torch module of the backbone to be used. See backbone.py
+ transformer: torch module of the transformer architecture. See transformer.py
+ state_dim: robot state dimension of the environment
+ num_queries: number of object queries, ie detection slot. This is the maximal number of objects
+ DETR can detect in a single image. For COCO, we recommend 100 queries.
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+ """
+ super().__init__()
+ self.num_queries = num_queries
+ self.camera_names = camera_names
+ self.transformer = transformer
+ self.transformer_share_decoder = transformer.share_decoder
+ self.predict_only_last = transformer.predict_only_last
+ self.encoder = encoder
+ self.disable_vae_latent = disable_vae_latent
+ self.disable_resnet = disable_resnet
+ hidden_dim = transformer.d_model # 512
+ self.hidden_dim = hidden_dim
+ self.action_head = nn.Linear(hidden_dim, state_dim) # TODO replace MLP
+ if self.transformer_share_decoder == False:
+ self.token_head = nn.Linear(
+ hidden_dim, 6 * patch_size * patch_size
+ ) # HardCode TODO replace MLP
+ else: # default share decoder
+ self.token_head = nn.Sequential(
+ nn.Linear(
+ hidden_dim, hidden_dim
+ ), # Hardcode patch size * path size * patch dim
+ nn.SiLU(),
+ nn.Linear(hidden_dim, 6 * patch_size * patch_size),
+ )
+
+ self.is_pad_head = nn.Linear(hidden_dim, 1)
+ self.query_embed = nn.Embedding(num_queries, hidden_dim) # for decoder as PE
+ self.diffusion_timestep_type = transformer.diffusion_timestep_type
+ if backbones is not None: # resnet encoder to extract visual feature
+ self.backbones = nn.ModuleList(backbones)
+ self.input_proj = nn.Conv2d(
+ backbones[0].num_channels, hidden_dim, kernel_size=1
+ )
+ self.input_proj_token = nn.Conv2d(
+ 6, hidden_dim, kernel_size=patch_size, stride=patch_size, bias=False
+ ) # Hardcode deal with image token patch size = 5
+ self.input_proj_robot_state = nn.Linear(14, hidden_dim) # proprioception
+ else:
+ # input_dim = 14 + 7 # robot_state + env_state
+ self.input_proj_robot_state = nn.Linear(14, hidden_dim)
+ self.input_proj_env_state = nn.Linear(7, hidden_dim)
+ self.pos = torch.nn.Embedding(2, hidden_dim)
+ self.backbones = None
+
+ # encoder extra parameters
+ self.latent_dim = 32 # final size of latent z # TODO tune
+ self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
+ self.encoder_action_proj = nn.Linear(
+ 14, hidden_dim
+ ) # project action to embedding
+ self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
+ self.latent_proj = nn.Linear(
+ hidden_dim, self.latent_dim * 2
+ ) # project hidden state to latent std, var
+ self.register_buffer(
+ "pos_table",
+ get_sinusoid_encoding_table(1 + 1 + history_step + num_queries, hidden_dim),
+ ) # [CLS], qpos, a_seq
+
+ # decoder extra parameters vae latent to decoder token space
+ self.history_step = history_step
+ self.latent_out_proj = nn.Linear(
+ self.latent_dim, hidden_dim
+ ) # project latent sample to embedding
+ self.additional_pos_embed = nn.Embedding(
+ 2 + history_step, hidden_dim
+ ) # learned position embedding for proprio and latent and denpoe
+ self.denoise_step_pos_embed = nn.Embedding(
+ 1, hidden_dim
+ ) # learned position embedding for denoise step
+ # setting for token prediction hyper
+ if self.predict_only_last:
+ self.num_temporal_token = 1
+ else:
+ self.num_temporal_token = math.ceil(
+ predict_frame
+ // temporal_downsample_rate
+ / tokenizer_model_temporal_rate
+ ) # align with tokenizer output
+ self.patch_size = patch_size
+ self.image_h = (
+ image_height
+ // image_downsample_rate
+ // tokenizer_model_spatial_rate
+ // self.patch_size
+ ) #
+ self.image_w = (
+ image_width
+ // image_downsample_rate
+ // tokenizer_model_spatial_rate
+ // self.patch_size
+ ) #
+ self.num_pred_token_per_timestep = (
+ self.image_h * self.image_w * len(camera_names)
+ )
+ # setting for token prediction position embedding
+ if token_pe_type == "learned":
+ self.current_embed_token = nn.Embedding(
+ 1 * self.num_pred_token_per_timestep, hidden_dim
+ ) # for current image token encoder
+ self.query_embed_token = nn.Embedding(
+ self.num_temporal_token * self.num_pred_token_per_timestep, hidden_dim
+ ) # for decoder as PE TODO replace with temporal spatial PE
+
+ spatial_temporal_pe = get_spatial_temporal_positional_encoding(
+ self.image_h, self.image_w, self.num_temporal_token + 1, hidden_dim
+ )
+ self.current_embed_token_fixed = spatial_temporal_pe[0].view(
+ -1, hidden_dim
+ ) # for current image token encoder
+ self.query_embed_token_fixed = spatial_temporal_pe[1:].view(
+ -1, hidden_dim
+ ) # for decoder as PE # (seq_len, height, width, dim) -> (seq_len*height*width, dim)
+
+ self.token_pe_type = token_pe_type
+ print(
+ "predict token shape",
+ (
+ self.num_pred_token_per_timestep,
+ self.num_temporal_token,
+ self.image_h,
+ self.image_w,
+ self.patch_size,
+ ),
+ )
+
def forward(
    self,
    qpos,
    current_image,
    env_state,
    actions,
    is_pad,
    noisy_actions,
    noise_tokens,
    is_tokens_pad,
    denoise_steps,
):
    """Run one denoising pass that jointly predicts an action chunk and future visual tokens.

    Shapes below follow the original author's notes — TODO(review): confirm against callers.

    Args:
        qpos: (batch, 1 + history, qpos_dim) proprioceptive state history.
        current_image: pair of tensors:
            [0]: (batch, T', num_cam, 3, height, width) normalized images, T' = 1 + history (default 1).
            [1]: (batch, T', num_cam, 6, height', width') current visual tokens
                 (current_image_norm, current_image_tokens).
        env_state: unused here (kept for interface compatibility).
        actions: (batch, seq, action_dim) clean actions for the VAE encoder, or None at inference.
        is_pad: (batch, seq) padding mask for the VAE encoder.
        noisy_actions: (batch, seq, action_dim) noisy actions to denoise.
        noise_tokens: (batch, seq, num_cam, 6, height', width') noisy tokens to denoise.
        is_tokens_pad: (batch, seq) padding mask for token targets, or None.
        denoise_steps: int diffusion timestep for this denoising call.

    Returns:
        Tuple (a_hat, is_pad_hat, pred_token, [mu, logvar]):
        denoised actions, concatenated pad logits for actions+tokens, predicted
        token images (un-patchified), and the CVAE posterior parameters
        (both None when the VAE latent is disabled or at inference).
    """
    is_training = actions is not None  # train or val
    bs = qpos.shape[0]
    # Keep the raw action mask: `is_pad` is extended with CLS/qpos slots below.
    is_actions_pad = is_pad

    ### Obtain latent z from action sequence (CVAE encoder)
    if is_training and not self.disable_vae_latent:
        # project action sequence to embedding dim, and concat with a CLS token
        action_embed = self.encoder_action_proj(actions)  # (bs, seq, hidden_dim)
        qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
        qpos_embed = (
            torch.unsqueeze(qpos_embed, axis=1)
            if len(qpos_embed.shape) == 2
            else qpos_embed
        )  # (bs, 1, hidden_dim)
        cls_embed = self.cls_embed.weight  # (1, hidden_dim)
        cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
            bs, 1, 1
        )  # (bs, 1, hidden_dim)
        encoder_input = torch.cat(
            [cls_embed, qpos_embed, action_embed], axis=1
        )  # (bs, seq+1, hidden_dim)
        encoder_input = encoder_input.permute(1, 0, 2)  # (seq+1, bs, hidden_dim)
        # do not mask cls token (nor the qpos/history slots)
        cls_joint_is_pad = torch.full((bs, 2 + self.history_step), False).to(
            qpos.device
        )  # False: not a padding

        is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
        # obtain position embedding (fixed sinusoidal table; no gradient)
        pos_embed = self.pos_table.clone().detach()
        pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
        # query model
        encoder_output = self.encoder(
            encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
        )
        encoder_output = encoder_output[0]  # take cls output only
        latent_info = self.latent_proj(encoder_output)  # predict mu and logvar
        mu = latent_info[:, : self.latent_dim]
        logvar = latent_info[:, self.latent_dim :]
        latent_sample = reparametrize(mu, logvar)
        latent_input = self.latent_out_proj(latent_sample)
    else:
        # Inference / VAE disabled: use a zero latent so the decoder input is deterministic.
        mu = logvar = None
        latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
            qpos.device
        )
        latent_input = self.latent_out_proj(latent_sample)

    # Image observation features and position embeddings for resnet visual encoder
    all_cam_features = []
    all_cam_pos = []
    if self.backbones is not None:
        # NOTE(review): `src`/`pos` are only defined on this path — the transformer
        # call below assumes a backbone is always present for this class.
        image = current_image[0]  # shape: batch, T', num_cam, 3, height, width
        for t in range(self.history_step + 1):  # only current frame
            for cam_id, cam_name in enumerate(self.camera_names):
                features, pos = self.backbones[0](
                    image[:, t, cam_id]
                )  # resnet feature and pos
                features = features[0]
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)  # resnet feature and pos
        src = torch.cat(
            all_cam_features, axis=3
        )  # B C H W*num_view*T # fold camera dimension into width dimension
        pos = torch.cat(all_cam_pos, axis=3)  # B C H W*num_view*T

    # Image observation features and position embeddings for visual tokenizer encoder, only 1 frame
    all_token_features = []
    all_token_pos = []
    current_visual_token = current_image[
        1
    ]  # shape: batch, T', num_cam, 6, height', width' - image_token
    for cam_id, cam_name in enumerate(self.camera_names):
        # Only the most recent frame (-1) of the token stream is encoded.
        token_features = current_visual_token[:, -1, cam_id]  # B C H W
        token_features = self.input_proj_token(token_features)  # B C H W
        token_pos = self.current_embed_token_fixed.to(
            qpos.device
        )  # height*width, C — fixed spatial-temporal PE slice for the current frame
        all_token_features.append(token_features)
        all_token_pos.append(token_pos)
    addition_visual_token = torch.cat(
        all_token_features, axis=3
    )  # B C H W*num_view
    addition_visual_token_pos = torch.cat(all_token_pos, axis=0)  # H*W*num_view C

    proprio_input = self.input_proj_robot_state(qpos)

    # Patchify the noisy tokens: fold patch_size x patch_size blocks into channels.
    noise_tokens = rearrange(
        noise_tokens,
        "b s n d (ph p1) (pw p2) -> b s n (d p1 p2) ph pw",
        p1=self.patch_size,
        p2=self.patch_size,
    )  # b seq_len num_cam 6*patch_size*patch_size height width

    if (
        is_tokens_pad is not None
    ):  # B seq_len -> B seq_len*self.num_pred_token_per_timestep useless for decoder
        # Expand the per-frame mask so every token of a padded frame is masked.
        is_tokens_pad = (
            is_tokens_pad.unsqueeze(2)
            .repeat(1, 1, self.num_pred_token_per_timestep)
            .reshape(bs, -1)
        )
    else:
        is_tokens_pad = torch.zeros(
            bs,
            self.num_pred_token_per_timestep * self.num_temporal_token,
            dtype=torch.bool,
        ).to(qpos.device)

    if self.token_pe_type == "learned":
        query_embed_token = self.query_embed_token.weight
    else:
        query_embed_token = self.query_embed_token_fixed.to(
            qpos.device
        )  # consider the visual token from current frame
    hs_action, hs_token = self.transformer(
        src,
        None,
        self.query_embed.weight,
        query_embed_token,
        pos,
        latent_input,
        proprio_input,
        self.additional_pos_embed.weight,
        noisy_actions,
        noise_tokens,
        denoise_steps,
        self.denoise_step_pos_embed.weight,
        is_actions_pad,
        is_tokens_pad,
        addition_visual_token,
        addition_visual_token_pos,
    )

    a_hat = self.action_head(hs_action)
    is_pad_a_hat = self.is_pad_head(hs_action)
    pred_token = self.token_head(
        hs_token
    )  # B T'*N*H'*W' 6*self.patch_size*self.patch_size
    is_pad_token_hat = self.is_pad_head(hs_token)
    is_pad_hat = torch.cat([is_pad_a_hat, is_pad_token_hat], axis=1)

    # Un-patchify: (tokens, patch-pixels) back to per-frame, per-camera images.
    pred_token = rearrange(
        pred_token,
        "b (t n hp wp) (c ph pw) -> b t n c (hp ph) (wp pw)",
        t=self.num_temporal_token,
        n=len(self.camera_names),
        hp=self.image_h,
        wp=self.image_w,
        ph=self.patch_size,
        pw=self.patch_size,
    )

    ## if no visal observation,
    # qpos = self.input_proj_robot_state(qpos)
    # env_state = self.input_proj_env_state(env_state)
    # transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
    # hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
    return a_hat, is_pad_hat, pred_token, [mu, logvar]
+
+
class DETRVAE_Denoise_Pixel_Prediction(nn.Module):
    """CVAE-style DETR that denoises an action chunk and predicts future pixels.

    Alongside the action queries, the transformer decoder carries pixel queries
    whose outputs are un-patchified into low-resolution future frames.
    """

    def __init__(
        self,
        backbones,
        transformer,
        encoder,
        state_dim,
        num_queries,
        camera_names,
        history_step,
        predict_frame,
        image_downsample_rate,
        temporal_downsample_rate,
        disable_vae_latent,
        disable_resnet,
        patch_size=5,
        token_pe_type="learned",
        image_height=480,
        image_width=640,
        resize_rate=8,
    ):
        """Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            encoder: transformer encoder used as the CVAE style encoder
            state_dim: robot state dimension of the environment
            num_queries: number of action queries (the length of the predicted chunk)
            camera_names: list of camera identifiers
            history_step: number of extra history observations
            predict_frame: number of raw future frames to predict
            image_downsample_rate / resize_rate: spatial reductions applied before patchify
            temporal_downsample_rate: temporal reduction applied to predict_frame
            disable_vae_latent: if True, always feed a zero latent to the decoder
            disable_resnet: if True, consume pre-tokenized images instead of ResNet features
            patch_size: side length of the square pixel patches
            token_pe_type: "learned" or fixed positional embedding for pixel queries
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.transformer_share_decoder = transformer.share_decoder
        self.predict_only_last = transformer.predict_only_last
        self.encoder = encoder
        self.disable_vae_latent = disable_vae_latent
        self.disable_resnet = disable_resnet
        hidden_dim = transformer.d_model  # 512
        self.hidden_dim = hidden_dim
        self.action_head = nn.Linear(hidden_dim, state_dim)  # TODO replace MLP
        if not self.transformer_share_decoder:
            self.pixel_head = nn.Linear(
                hidden_dim, 3 * patch_size * patch_size
            )  # HardCode TODO replace MLP
        else:
            # Shared decoder needs a slightly deeper head: patch_size * patch_size * 3 output
            self.pixel_head = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, 3 * patch_size * patch_size),
            )

        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)  # for decoder as PE
        self.diffusion_timestep_type = transformer.diffusion_timestep_type
        if backbones is not None and not disable_resnet:
            self.input_proj = nn.Conv2d(
                backbones[0].num_channels, hidden_dim, kernel_size=1
            )
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)  # proprioception
        elif backbones is not None and disable_resnet:
            self.pos_encoder = backbones[0][1]  # for 2D PositionEmbedding
            self.backbones = None  # HardCDcode
            # Hardcode: divide the latent feature representation into non-overlapping patches
            self.input_proj_token = nn.Conv2d(
                6, hidden_dim, kernel_size=5, stride=5, bias=False
            )  # Hardcode deal with image token patches
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)  # proprioception
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32  # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim)  # extra cls token embedding
        self.encoder_action_proj = nn.Linear(
            14, hidden_dim
        )  # project action to embedding
        self.encoder_joint_proj = nn.Linear(14, hidden_dim)  # project qpos to embedding
        self.latent_proj = nn.Linear(
            hidden_dim, self.latent_dim * 2
        )  # project hidden state to latent std, var
        self.register_buffer(
            "pos_table",
            get_sinusoid_encoding_table(1 + 1 + history_step + num_queries, hidden_dim),
        )  # [CLS], qpos (+history), a_seq

        # decoder extra parameters: map VAE latent into decoder token space
        self.history_step = history_step
        self.latent_out_proj = nn.Linear(
            self.latent_dim, hidden_dim
        )  # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(
            2 + history_step, hidden_dim
        )  # learned position embedding for proprio and latent
        self.denoise_step_pos_embed = nn.Embedding(
            1, hidden_dim
        )  # learned position embedding for denoise step
        # setting for pixel prediction
        if self.predict_only_last:
            self.num_frames = 1
        else:
            # FIX(review): true division — with `//` the ceil was a no-op, so a
            # non-divisible predict_frame silently dropped the trailing partial frame.
            # Matches the ceil(.../...) pattern used by the token-prediction variant.
            self.num_frames = math.ceil(predict_frame / temporal_downsample_rate)
        self.patch_size = patch_size
        self.image_h = (
            image_height // image_downsample_rate // resize_rate // self.patch_size
        )  # patches per column
        self.image_w = (
            image_width // image_downsample_rate // resize_rate // self.patch_size
        )  # patches per row
        self.num_pred_pixel_per_timestep = (
            self.image_h * self.image_w * len(camera_names)
        )

        if token_pe_type == "learned":
            self.query_embed_token = nn.Embedding(
                self.num_frames * self.num_pred_pixel_per_timestep, hidden_dim
            )  # for decoder as PE TODO replace with temporal spatial PE
        self.query_embed_token_fixed = get_spatial_temporal_positional_encoding(
            self.image_h, self.image_w, self.num_frames, hidden_dim
        ).view(
            -1, hidden_dim
        )  # for decoder as PE # (seq_len, height, width, dim) -> (seq_len*height*width, dim)
        self.token_pe_type = token_pe_type
        print(
            "predict pixel shape",
            (
                self.num_frames,
                self.num_pred_pixel_per_timestep,
                self.image_h,
                self.image_w,
                self.patch_size,
            ),
        )

    def forward(
        self,
        qpos,
        current_image,
        env_state,
        actions,
        is_pad,
        noisy_actions,
        noise_tokens,
        is_tokens_pad,
        denoise_steps,
    ):
        """Denoise actions and predict future pixels in a single pass.

        Args (shapes per the original author's notes; TODO confirm against callers):
            qpos: (batch, 1 + history, qpos_dim)
            current_image: [0] (batch, T', num_cam, 3, H, W) normalized images;
                           [1] (batch, T', num_cam, 6, H', W') image tokens
            env_state: unused
            actions: (batch, seq, action_dim) clean actions for the VAE encoder, or None
            is_pad: (batch, seq) action padding mask
            noisy_actions: (batch, seq, action_dim)
            noise_tokens: (batch, seq, num_cam, C, H', W') noisy pixel targets
            is_tokens_pad: (batch, seq) frame padding mask, or None
            denoise_steps: diffusion timestep

        Returns:
            (a_hat, is_pad_hat, pred_images, [mu, logvar])
        """
        is_training = actions is not None  # train or val
        bs = qpos.shape[0]
        is_actions_pad = is_pad  # keep original action mask for the transformer

        ### Obtain latent z from action sequence (CVAE encoder)
        if is_training and not self.disable_vae_latent:
            # project action sequence to embedding dim, and concat with a CLS token
            action_embed = self.encoder_action_proj(actions)  # (bs, seq, hidden_dim)
            qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
            qpos_embed = (
                torch.unsqueeze(qpos_embed, axis=1)
                if len(qpos_embed.shape) == 2
                else qpos_embed
            )  # (bs, 1, hidden_dim)
            cls_embed = self.cls_embed.weight  # (1, hidden_dim)
            cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
                bs, 1, 1
            )  # (bs, 1, hidden_dim)
            encoder_input = torch.cat(
                [cls_embed, qpos_embed, action_embed], axis=1
            )  # (bs, seq+1, hidden_dim)
            encoder_input = encoder_input.permute(1, 0, 2)  # (seq+1, bs, hidden_dim)
            # do not mask cls token (nor qpos/history slots)
            cls_joint_is_pad = torch.full((bs, 2 + self.history_step), False).to(
                qpos.device
            )  # False: not a padding

            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
            # obtain position embedding (fixed sinusoidal, no gradient)
            pos_embed = self.pos_table.clone().detach()
            pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
            # query model
            encoder_output = self.encoder(
                encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
            )
            encoder_output = encoder_output[0]  # take cls output only
            latent_info = self.latent_proj(encoder_output)  # predict mu and logvar
            mu = latent_info[:, : self.latent_dim]
            logvar = latent_info[:, self.latent_dim :]
            latent_sample = reparametrize(mu, logvar)
            latent_input = self.latent_out_proj(latent_sample)
        else:
            mu = logvar = None
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
                qpos.device
            )
            latent_input = self.latent_out_proj(latent_sample)

        # Image observation features and position embeddings
        all_cam_features = []
        all_cam_pos = []
        if self.backbones is not None and not self.disable_resnet:
            image = current_image[0]  # shape: batch, T', num_cam, 3, height, width
            for t in range(self.history_step + 1):
                for cam_id, cam_name in enumerate(self.camera_names):
                    features, pos = self.backbones[0](image[:, t, cam_id])
                    features = features[0]
                    pos = pos[0]
                    all_cam_features.append(self.input_proj(features))
                    all_cam_pos.append(pos)
        else:
            # Tokenized-input path (disable_resnet): consume pre-computed image tokens.
            image = current_image[
                1
            ]  # shape: batch, T', num_cam, 6, height', width' hardcode, consistent with token prediction
            for t in range(self.history_step + 1):
                for cam_id, cam_name in enumerate(self.camera_names):
                    features = image[:, t, cam_id]  # B C H W
                    # FIX(review): was `self.input_proj`, which is never created on
                    # the disable_resnet path (AttributeError); the 6-channel token
                    # patchifier `input_proj_token` is the projection built for it.
                    features = self.input_proj_token(features)
                    pos = self.pos_encoder(features).to(features.dtype)
                    all_cam_features.append(features)
                    all_cam_pos.append(pos)
        # proprioception features
        proprio_input = self.input_proj_robot_state(qpos)
        # fold camera dimension into width dimension
        src = torch.cat(all_cam_features, axis=3)  # B C H W*num_view*T
        pos = torch.cat(all_cam_pos, axis=3)  # B C H W*num_view*T
        # patchify the noisy pixel targets
        noise_tokens = rearrange(
            noise_tokens,
            "b s n d (ph p1) (pw p2) -> b s n (d p1 p2) ph pw",
            p1=self.patch_size,
            p2=self.patch_size,
        )
        if is_tokens_pad is not None:
            # Expand frame-level mask to one entry per predicted pixel-patch.
            is_tokens_pad = (
                is_tokens_pad.unsqueeze(2)
                .repeat(1, 1, self.num_pred_pixel_per_timestep)
                .reshape(bs, -1)
            )
        else:
            is_tokens_pad = torch.zeros(
                bs, self.num_frames * self.num_pred_pixel_per_timestep, dtype=torch.bool
            ).to(qpos.device)
        # NOTE(review): dead store kept for parity — `is_pad` is not read below
        # (the transformer receives `is_actions_pad` / `is_tokens_pad` instead).
        is_pad = torch.zeros(bs, self.num_queries, dtype=torch.bool).to(qpos.device)

        if self.token_pe_type == "learned":
            query_embed_token = self.query_embed_token.weight
        else:
            query_embed_token = self.query_embed_token_fixed.to(qpos.device)

        hs_action, hs_token = self.transformer(
            src,
            None,
            self.query_embed.weight,
            query_embed_token,
            pos,
            latent_input,
            proprio_input,
            self.additional_pos_embed.weight,
            noisy_actions,
            noise_tokens,
            denoise_steps,
            self.denoise_step_pos_embed.weight,
            is_actions_pad,
            is_tokens_pad,
        )

        a_hat = self.action_head(hs_action)
        is_pad_a_hat = self.is_pad_head(hs_action)
        pred_images = self.pixel_head(
            hs_token
        )  # B T'*N*H'*W' 3*self.patch_size*self.patch_size
        is_pad_image_hat = self.is_pad_head(hs_token)
        is_pad_hat = torch.cat([is_pad_a_hat, is_pad_image_hat], axis=1)

        # Un-patchify into per-frame, per-camera images.
        pred_images = rearrange(
            pred_images,
            "b (t n hp wp) (c ph pw) -> b t n c (hp ph) (wp pw)",
            t=self.num_frames,
            n=len(self.camera_names),
            hp=self.image_h,
            wp=self.image_w,
            ph=self.patch_size,
            pw=self.patch_size,
        )

        return a_hat, is_pad_hat, pred_images, [mu, logvar]
+
+
class DETRVAEDino(nn.Module):
    """DETR-CVAE that, besides the action chunk, regresses DINO-style cls-token
    features of future observations through a small Transformer decoder head."""

    def __init__(
        self, backbones, transformer, encoder, state_dim, num_queries, camera_names
    ):
        """Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            encoder: transformer encoder used as the CVAE style encoder
            state_dim: robot state dimension of the environment
            num_queries: number of action queries (length of the predicted chunk)
            camera_names: list of camera identifiers
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        hidden_dim = transformer.d_model
        self.action_head = nn.Linear(hidden_dim, state_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        # Extra decoder queries reserved for cls-token feature prediction.
        self.num_cls_tokens = 50
        self.query_embed = nn.Embedding(num_queries + self.num_cls_tokens, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(
                backbones[0].num_channels, hidden_dim, kernel_size=1
            )
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32  # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim)  # extra cls token embedding
        self.encoder_action_proj = nn.Linear(
            14, hidden_dim
        )  # project action to embedding
        self.encoder_joint_proj = nn.Linear(14, hidden_dim)  # project qpos to embedding
        self.latent_proj = nn.Linear(
            hidden_dim, self.latent_dim * 2
        )  # project hidden state to latent std, var
        self.register_buffer(
            "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
        )  # [CLS], qpos, a_seq

        # decoder extra parameters
        self.latent_out_proj = nn.Linear(
            self.latent_dim, hidden_dim
        )  # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(
            2, hidden_dim
        )  # learned position embedding for proprio and latent

        # Small MAE-style decoder that refines the cls-token queries.
        decoder_depth = 2  # hardcode
        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    hidden_dim,
                    16,
                    4,
                    qkv_bias=True,
                    qk_scale=None,
                    norm_layer=nn.LayerNorm,
                )
                for i in range(decoder_depth)
            ]
        )

        self.decoder_norm = nn.LayerNorm(hidden_dim)
        # 384 matches the DINO ViT-S feature width — TODO confirm against targets.
        self.decoder_pred = nn.Linear(hidden_dim, 384, bias=True)  # decoder to patch

    def forward(self, qpos, image, env_state, actions=None, is_pad=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim

        Returns (a_hat, is_pad_hat, [mu, logvar], cls_token_hat).
        """
        is_training = actions is not None  # train or val
        bs, _ = qpos.shape
        ### Obtain latent z from action sequence
        if is_training:
            # project action sequence to embedding dim, and concat with a CLS token
            action_embed = self.encoder_action_proj(actions)  # (bs, seq, hidden_dim)
            qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
            qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, hidden_dim)
            cls_embed = self.cls_embed.weight  # (1, hidden_dim)
            cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
                bs, 1, 1
            )  # (bs, 1, hidden_dim)
            encoder_input = torch.cat(
                [cls_embed, qpos_embed, action_embed], axis=1
            )  # (bs, seq+1, hidden_dim)
            encoder_input = encoder_input.permute(1, 0, 2)  # (seq+1, bs, hidden_dim)
            # do not mask cls token
            cls_joint_is_pad = torch.full((bs, 2), False).to(
                qpos.device
            )  # False: not a padding
            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
            # obtain position embedding
            pos_embed = self.pos_table.clone().detach()
            pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
            # query model
            encoder_output = self.encoder(
                encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
            )
            encoder_output = encoder_output[0]  # take cls output only
            latent_info = self.latent_proj(encoder_output)
            mu = latent_info[:, : self.latent_dim]
            logvar = latent_info[:, self.latent_dim :]
            latent_sample = reparametrize(mu, logvar)
            latent_input = self.latent_out_proj(latent_sample)
        else:
            mu = logvar = None
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
                qpos.device
            )
            latent_input = self.latent_out_proj(latent_sample)

        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            for cam_id, cam_name in enumerate(self.camera_names):
                features, pos = self.backbones[0](image[:, cam_id])  # HARDCODED
                features = features[0]  # take the last layer feature
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            pos = torch.cat(all_cam_pos, axis=3)
            hs = self.transformer(
                src,
                None,
                self.query_embed.weight,
                pos,
                latent_input,
                proprio_input,
                self.additional_pos_embed.weight,
            )[0]
        else:
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1)  # seq length = 2
            hs = self.transformer(
                transformer_input, None, self.query_embed.weight, self.pos.weight
            )[0]
        a_hat = self.action_head(hs[:, : self.num_queries])
        is_pad_hat = self.is_pad_head(hs[:, : self.num_queries])

        # FIX(review): chain the decoder blocks sequentially. The original loop
        # re-applied each block to the raw `hs` slice, so only the LAST block's
        # output was ever used and the stack never composed.
        cls_token = hs[:, self.num_queries :]
        for blk in self.decoder_blocks:
            cls_token = blk(cls_token)
        cls_token = self.decoder_norm(cls_token)
        cls_token_hat = self.decoder_pred(cls_token)
        return a_hat, is_pad_hat, [mu, logvar], cls_token_hat
+
+
class DETRVAEjpeg(nn.Module):
    """DETR-CVAE that predicts an action chunk plus JPEG-token embeddings of
    future frames via extra decoder queries and an MLP head."""

    def __init__(
        self,
        backbones,
        transformer,
        encoder,
        state_dim,
        num_queries,
        camera_names,
        predict_frame=20,
        jpeg_token_num=4,
        jpeg_dim=400,
    ):
        """Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            encoder: transformer encoder used as the CVAE style encoder
            state_dim: robot state dimension of the environment
            num_queries: number of action queries (length of the predicted chunk)
            camera_names: list of camera identifiers
            predict_frame: number of future frames whose JPEG tokens are predicted
                (default 20, matching the previous hard-coded value)
            jpeg_token_num: decoder queries representing one JPEG (default 4)
            jpeg_dim: output dimension per JPEG token (default 400)
        These keyword defaults keep the original behavior while matching the
        parameterization of DETRVAEjpeg_diffusion.
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        hidden_dim = transformer.d_model
        self.action_head = nn.Linear(hidden_dim, state_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.predict_frame = predict_frame
        self.jpeg_token_num = jpeg_token_num
        self.jpeg_dim = jpeg_dim
        # e.g. 80 tokens for the 20 next-frame jpegs; 4 tokens represent 1 jpeg
        self.num_jpeg = predict_frame * jpeg_token_num
        self.query_embed = nn.Embedding(num_queries + self.num_jpeg, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(
                backbones[0].num_channels, hidden_dim, kernel_size=1
            )
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32  # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim)  # extra cls token embedding
        self.encoder_action_proj = nn.Linear(
            14, hidden_dim
        )  # project action to embedding
        self.encoder_joint_proj = nn.Linear(14, hidden_dim)  # project qpos to embedding
        self.latent_proj = nn.Linear(
            hidden_dim, self.latent_dim * 2
        )  # project hidden state to latent std, var
        self.register_buffer(
            "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
        )  # [CLS], qpos, a_seq

        # decoder extra parameters
        self.latent_out_proj = nn.Linear(
            self.latent_dim, hidden_dim
        )  # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(
            2, hidden_dim
        )  # learned position embedding for proprio and latent

        # MLP head mapping each jpeg query output to its token embedding.
        self.jpeg_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, jpeg_dim),
        )

    def forward(self, qpos, image, env_state, actions=None, is_pad=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim

        Returns (a_hat, is_pad_hat, [mu, logvar], jpeg_token_hat).
        """
        is_training = actions is not None  # train or val
        bs, _ = qpos.shape
        ### Obtain latent z from action sequence
        if is_training:
            # project action sequence to embedding dim, and concat with a CLS token
            action_embed = self.encoder_action_proj(actions)  # (bs, seq, hidden_dim)
            qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
            qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, hidden_dim)
            cls_embed = self.cls_embed.weight  # (1, hidden_dim)
            cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
                bs, 1, 1
            )  # (bs, 1, hidden_dim)
            encoder_input = torch.cat(
                [cls_embed, qpos_embed, action_embed], axis=1
            )  # (bs, seq+1, hidden_dim)
            encoder_input = encoder_input.permute(1, 0, 2)  # (seq+1, bs, hidden_dim)
            # do not mask cls token
            cls_joint_is_pad = torch.full((bs, 2), False).to(
                qpos.device
            )  # False: not a padding
            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
            # obtain position embedding
            pos_embed = self.pos_table.clone().detach()
            pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
            # query model
            encoder_output = self.encoder(
                encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
            )
            encoder_output = encoder_output[0]  # take cls output only
            latent_info = self.latent_proj(encoder_output)
            mu = latent_info[:, : self.latent_dim]
            logvar = latent_info[:, self.latent_dim :]
            latent_sample = reparametrize(mu, logvar)
            latent_input = self.latent_out_proj(latent_sample)
        else:
            mu = logvar = None
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
                qpos.device
            )
            latent_input = self.latent_out_proj(latent_sample)

        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            for cam_id, cam_name in enumerate(self.camera_names):
                features, pos = self.backbones[0](image[:, cam_id])  # HARDCODED
                features = features[0]  # take the last layer feature
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            pos = torch.cat(all_cam_pos, axis=3)
            hs = self.transformer(
                src,
                None,
                self.query_embed.weight,
                pos,
                latent_input,
                proprio_input,
                self.additional_pos_embed.weight,
            )[0]
        else:
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1)  # seq length = 2
            hs = self.transformer(
                transformer_input, None, self.query_embed.weight, self.pos.weight
            )[0]
        # First num_queries slots are action queries; the rest are jpeg queries.
        a_hat = self.action_head(hs[:, : self.num_queries])
        is_pad_hat = self.is_pad_head(hs[:, : self.num_queries])
        jpeg_token_hat = self.jpeg_head(hs[:, self.num_queries :])

        return a_hat, is_pad_hat, [mu, logvar], jpeg_token_hat
+
+
class DETRVAEjpeg_diffusion(nn.Module):
    # add timestep
    """DETR-style CVAE whose transformer decoder denoises an action chunk and
    additionally predicts JPEG feature tokens for future frames.

    Unlike DETRVAEjpeg, the forward pass also threads ``actions``, ``jpegs``
    and ``denoise_steps`` into the transformer, i.e. the decoder is
    conditioned on the diffusion timestep (see ``build_jpeg_diffusion``).
    """

    def __init__(
        self,
        backbones,
        transformer,
        encoder,
        state_dim,
        num_queries,
        camera_names,
        disable_vae_latent=False,
        predict_frame=20,
        jpeg_token_num=4,
        jpeg_dim=400,
    ):
        """Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            encoder: transformer encoder used as the CVAE posterior over actions
            state_dim: robot state dimension of the environment
            num_queries: number of action queries, i.e. the action chunk length
            camera_names: camera identifiers; one image stream per entry
            disable_vae_latent: if True, skip the CVAE posterior and use a zero latent
            predict_frame: number of future frames whose jpeg tokens are predicted
            jpeg_token_num: number of tokens that represent one jpeg frame
            jpeg_dim: output feature dimension of a predicted jpeg token
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        hidden_dim = transformer.d_model
        self.disable_vae_latent = disable_vae_latent
        self.action_head = nn.Linear(hidden_dim, state_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.predict_frame = predict_frame
        self.jpeg_token_num = jpeg_token_num
        self.jpeg_dim = jpeg_dim
        self.num_jpeg = (
            predict_frame * jpeg_token_num
        )  # 80 tokens for the 20 next frame jpegs, 4 token will represent 1 jpeg TODO tune
        # one shared query sequence: [action queries | jpeg queries]
        self.query_embed = nn.Embedding(num_queries + self.num_jpeg, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(
                backbones[0].num_channels, hidden_dim, kernel_size=1
            )
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32  # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim)  # extra cls token embedding
        self.encoder_action_proj = nn.Linear(
            14, hidden_dim
        )  # project action to embedding
        self.encoder_joint_proj = nn.Linear(14, hidden_dim)  # project qpos to embedding
        self.latent_proj = nn.Linear(
            hidden_dim, self.latent_dim * 2
        )  # project hidden state to latent std, var
        self.register_buffer(
            "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
        )  # [CLS], qpos, a_seq

        # decoder extra parameters
        self.latent_out_proj = nn.Linear(
            self.latent_dim, hidden_dim
        )  # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(
            2, hidden_dim
        )  # learned position embedding for proprio and latent

        # two-layer MLP head mapping jpeg query outputs to jpeg token features
        self.jpeg_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, jpeg_dim),
        )

    def forward(
        self,
        qpos,
        image,
        env_state,
        actions=None,
        is_pad=None,
        jpegs=None,
        is_jpeg_pad=None,
        denoise_steps=None,
        is_training=True,
    ):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim
        jpegs: target jpeg tokens, forwarded to the transformer (training only)
        denoise_steps: diffusion timestep(s), forwarded to the transformer
        Returns (a_hat, jpeg_token_hat, is_pad_hat, [mu, logvar]).
        """
        # is_training = actions is not None # train or val
        bs, _ = qpos.shape
        ### Obtain latent z from action sequence
        if is_training and self.disable_vae_latent == False:
            # project action sequence to embedding dim, and concat with a CLS token
            action_embed = self.encoder_action_proj(actions)  # (bs, seq, hidden_dim)
            qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
            qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, hidden_dim)
            cls_embed = self.cls_embed.weight  # (1, hidden_dim)
            cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
                bs, 1, 1
            )  # (bs, 1, hidden_dim)
            encoder_input = torch.cat(
                [cls_embed, qpos_embed, action_embed], axis=1
            )  # (bs, seq+1, hidden_dim)
            encoder_input = encoder_input.permute(1, 0, 2)  # (seq+1, bs, hidden_dim)
            # do not mask cls token
            cls_joint_is_pad = torch.full((bs, 2), False).to(
                qpos.device
            )  # False: not a padding
            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
            # obtain position embedding
            pos_embed = self.pos_table.clone().detach()
            pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
            # query model
            encoder_output = self.encoder(
                encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
            )
            encoder_output = encoder_output[0]  # take cls output only
            latent_info = self.latent_proj(encoder_output)
            mu = latent_info[:, : self.latent_dim]
            logvar = latent_info[:, self.latent_dim :]
            latent_sample = reparametrize(mu, logvar)
            latent_input = self.latent_out_proj(latent_sample)
        else:
            # inference (or CVAE disabled): deterministic zero latent
            mu = logvar = None
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
                qpos.device
            )
            latent_input = self.latent_out_proj(latent_sample)

        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            for cam_id, cam_name in enumerate(self.camera_names):
                features, pos = self.backbones[0](image[:, cam_id])  # HARDCODED
                features = features[0]  # take the last layer feature
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            # need timestep?
            pos = torch.cat(all_cam_pos, axis=3)
            hs = self.transformer(
                src,
                None,
                self.query_embed.weight,
                pos,
                latent_input,
                proprio_input,
                self.additional_pos_embed.weight,
                actions,
                jpegs,
                denoise_steps,
            )[0]
        else:
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1)  # seq length = 2
            hs = self.transformer(
                transformer_input, None, self.query_embed.weight, self.pos.weight
            )[0]
        a_hat = self.action_head(hs[:, : self.num_queries])
        # is_pad_hat = self.is_pad_head(hs[:, :self.num_queries])
        # NOTE(review): the pad head is applied to ALL tokens here (action AND
        # jpeg queries), unlike the commented-out line above — confirm intended.
        is_pad_hat = self.is_pad_head(hs)
        # for blk in self.decoder_blocks:
        # jpeg_token = blk(hs[:,self.num_queries:])
        # jpeg_token = self.decoder_norm(jpeg_token)
        # jpeg_token_hat = self.decoder_pred(jpeg_token)
        jpeg_token_hat = self.jpeg_head(hs[:, self.num_queries :])

        return a_hat, jpeg_token_hat, is_pad_hat, [mu, logvar]
+
+
class DETRVAEjpeg_diffusion_seperate(nn.Module):
    # add timestep
    """Variant of DETRVAEjpeg_diffusion with SEPARATE query embeddings for the
    action chunk and the future-frame jpeg tokens.

    The transformer is expected to return two streams, ``hs_action`` and
    ``hs_jpeg``; actions/pad logits and jpeg tokens are decoded from the
    corresponding stream.
    """

    def __init__(
        self,
        backbones,
        transformer,
        encoder,
        state_dim,
        num_queries,
        camera_names,
        disable_vae_latent=False,
        predict_frame=20,
        jpeg_token_num=4,
        jpeg_dim=400,
    ):
        """Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            encoder: transformer encoder used as the CVAE posterior over actions
            state_dim: robot state dimension of the environment
            num_queries: number of action queries, i.e. the action chunk length
            camera_names: camera identifiers; one image stream per entry
            disable_vae_latent: if True, skip the CVAE posterior and use a zero latent
            predict_frame: number of future frames whose jpeg tokens are predicted
            jpeg_token_num: number of tokens that represent one jpeg frame
            jpeg_dim: output feature dimension of a predicted jpeg token
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        hidden_dim = transformer.d_model
        self.disable_vae_latent = disable_vae_latent
        self.action_head = nn.Linear(hidden_dim, state_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.predict_frame = predict_frame
        self.jpeg_token_num = jpeg_token_num
        self.jpeg_dim = jpeg_dim
        self.num_jpeg = (
            predict_frame * jpeg_token_num
        )  # 80 tokens for the 20 next frame jpegs, 4 token will represent 1 jpeg TODO tune
        # separate query tables for the two decoded streams
        self.query_embed_action = nn.Embedding(num_queries, hidden_dim)
        self.query_embed_jpeg = nn.Embedding(self.num_jpeg, hidden_dim)
        # self.query_embed = nn.Embedding(num_queries+self.num_jpeg, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(
                backbones[0].num_channels, hidden_dim, kernel_size=1
            )
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32  # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim)  # extra cls token embedding
        self.encoder_action_proj = nn.Linear(
            14, hidden_dim
        )  # project action to embedding
        self.encoder_joint_proj = nn.Linear(14, hidden_dim)  # project qpos to embedding
        self.latent_proj = nn.Linear(
            hidden_dim, self.latent_dim * 2
        )  # project hidden state to latent std, var
        self.register_buffer(
            "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
        )  # [CLS], qpos, a_seq

        # decoder extra parameters
        self.latent_out_proj = nn.Linear(
            self.latent_dim, hidden_dim
        )  # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(
            2, hidden_dim
        )  # learned position embedding for proprio and latent

        # two-layer MLP head mapping jpeg query outputs to jpeg token features
        self.jpeg_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, jpeg_dim),  # todo
        )

    def forward(
        self,
        qpos,
        image,
        env_state,
        actions=None,
        is_pad=None,
        denoise_steps=None,
        is_training=True,
    ):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim
        denoise_steps: diffusion timestep(s), forwarded to the transformer
        Returns (a_hat, jpeg_token_hat, is_pad_hat, [mu, logvar]).
        """
        # is_training = actions is not None # train or val
        bs, _ = qpos.shape
        ### Obtain latent z from action sequence
        if is_training and self.disable_vae_latent == False:
            # project action sequence to embedding dim, and concat with a CLS token
            action_embed = self.encoder_action_proj(actions)  # (bs, seq, hidden_dim)
            qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
            qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, hidden_dim)
            cls_embed = self.cls_embed.weight  # (1, hidden_dim)
            cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
                bs, 1, 1
            )  # (bs, 1, hidden_dim)
            encoder_input = torch.cat(
                [cls_embed, qpos_embed, action_embed], axis=1
            )  # (bs, seq+1, hidden_dim)
            encoder_input = encoder_input.permute(1, 0, 2)  # (seq+1, bs, hidden_dim)
            # do not mask cls token
            cls_joint_is_pad = torch.full((bs, 2), False).to(
                qpos.device
            )  # False: not a padding
            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
            # obtain position embedding
            pos_embed = self.pos_table.clone().detach()
            pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
            # query model
            encoder_output = self.encoder(
                encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
            )
            encoder_output = encoder_output[0]  # take cls output only
            latent_info = self.latent_proj(encoder_output)
            mu = latent_info[:, : self.latent_dim]
            logvar = latent_info[:, self.latent_dim :]
            latent_sample = reparametrize(mu, logvar)
            latent_input = self.latent_out_proj(latent_sample)
        else:
            # inference (or CVAE disabled): deterministic zero latent
            mu = logvar = None
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
                qpos.device
            )
            latent_input = self.latent_out_proj(latent_sample)

        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            for cam_id, cam_name in enumerate(self.camera_names):
                features, pos = self.backbones[0](image[:, cam_id])  # HARDCODED
                features = features[0]  # take the last layer feature
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            # need timestep?
            pos = torch.cat(all_cam_pos, axis=3)
            # hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight, actions, jpegs, denoise_steps,self.predict_frame, self.jpeg_token_num)[0]
            hs_action, hs_jpeg = self.transformer(
                src,
                None,
                self.query_embed_action.weight,
                self.query_embed_jpeg.weight,
                pos,
                latent_input,
                proprio_input,
                self.additional_pos_embed.weight,
                actions,
                denoise_steps,
            )
        else:
            # NOTE(review): this branch never calls the transformer, so
            # hs_action / hs_jpeg are undefined and the code below would raise
            # NameError — the state-only path appears dead/unfinished here.
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1)  # seq length = 2
        # print(hs_action.shape,hs_jpeg.shape)
        a_hat = self.action_head(hs_action)
        is_pad_action_hat = self.is_pad_head(hs_action)
        is_pad_jpeg_hat = self.is_pad_head(hs_jpeg)
        # print(a_hat.shape)
        # print(is_pad_action_hat.shape,is_pad_jpeg_hat.shape)
        # pad logits for both streams, concatenated along the token dimension
        is_pad_hat = torch.cat([is_pad_action_hat, is_pad_jpeg_hat], axis=1)
        jpeg_token_hat = self.jpeg_head(hs_jpeg)

        return a_hat, jpeg_token_hat, is_pad_hat, [mu, logvar]
+
+
class DETRVAE_nf_diffusion(nn.Module):
    """DETR-style CVAE with diffusion action denoising plus next-frame PIXEL
    prediction: a small MAE-like decoder (mask tokens + Block stack) predicts
    patchified pixels of the next ``predict_frame`` frames.
    """

    def __init__(
        self,
        backbones,
        transformer,
        encoder,
        state_dim,
        num_queries,
        camera_names,
        predict_frame,
        disable_vae_latent=False,
    ):
        """Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            encoder: transformer encoder used as the CVAE posterior over actions
            state_dim: robot state dimension of the environment
            num_queries: number of action queries, i.e. the action chunk length
            camera_names: camera identifiers; one image stream per entry
            predict_frame: number of future frames reconstructed by the decoder
            disable_vae_latent: if True, skip the CVAE posterior and use a zero latent
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        hidden_dim = transformer.d_model
        self.action_head = nn.Linear(hidden_dim, state_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(
                backbones[0].num_channels, hidden_dim, kernel_size=1
            )
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32  # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim)  # extra cls token embedding
        self.encoder_action_proj = nn.Linear(
            14, hidden_dim
        )  # project action to embedding
        self.encoder_joint_proj = nn.Linear(14, hidden_dim)  # project qpos to embedding
        self.latent_proj = nn.Linear(
            hidden_dim, self.latent_dim * 2
        )  # project hidden state to latent std, var
        self.register_buffer(
            "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
        )  # [CLS], qpos, a_seq

        # decoder extra parameters
        self.latent_out_proj = nn.Linear(
            self.latent_dim, hidden_dim
        )  # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(
            2, hidden_dim
        )  # learned position embedding for proprio and latent

        # settings for next frame prediction
        self.patch_size = 16
        self.img_h, self.img_w = 224, 224
        # learned mask token that stands in for every future-frame patch
        self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        # self.n_patch = (self.image_size//self.patch_size)**2
        self.k = predict_frame  # number of next frames
        # total patch count across all k predicted frames
        self.n_patch = (
            (self.img_h // self.patch_size) * (self.img_w // self.patch_size) * (self.k)
        )
        # fixed (non-trainable) sinusoidal position table for the decoder
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, self.n_patch, hidden_dim), requires_grad=False
        )  # (1, n_patch, h)
        self.patch_embed = nn.Embedding(self.n_patch, hidden_dim)
        self.decoder_embed = nn.Linear(hidden_dim, hidden_dim, bias=True)

        decoder_depth = 2  # hardcode
        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    hidden_dim,
                    16,
                    4,
                    qkv_bias=True,
                    qk_scale=None,
                    norm_layer=nn.LayerNorm,
                )
                for i in range(decoder_depth)
            ]
        )

        self.decoder_norm = nn.LayerNorm(hidden_dim)
        self.decoder_pred = nn.Linear(
            hidden_dim, self.patch_size**2 * 3, bias=True
        )  # decoder to patch

        # decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], (self.image_size//self.patch_size), cls_token=False)
        # one frame's worth of 2D sin-cos positions, tiled k times below
        decoder_pos_embed = get_2d_sincos_pos_embed_v2(
            self.decoder_pos_embed.shape[-1],
            (self.img_h // self.patch_size, self.img_w // self.patch_size),
        )
        self.decoder_pos_embed.data.copy_(
            torch.from_numpy(decoder_pos_embed)
            .float()
            .unsqueeze(0)
            .repeat(1, self.k, 1)
        )
        self.disable_vae_latent = disable_vae_latent
        # fwd_params = sum(p.numel() for p in self.decoder_blocks.parameters() if p.requires_grad)

    def forward(
        self,
        qpos,
        image,
        env_state,
        actions=None,
        noisy_actions=None,
        is_pad=None,
        denoise_steps=None,
    ):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim
        noisy_actions: noised action chunk fed to the diffusion decoder
        denoise_steps: diffusion timestep(s), forwarded to the transformer
        Returns (a_hat, is_pad_hat, [mu, logvar], [obs_preds, obs_targets]).
        """
        is_training = actions is not None  # train or val
        # print('is_training',is_training)
        bs, _ = qpos.shape
        # ### Obtain latent z from action sequence
        # print('detr image shape',image.shape)
        if is_training and not self.disable_vae_latent:
            # project action sequence to embedding dim, and concat with a CLS token
            action_embed = self.encoder_action_proj(actions)  # (bs, seq, hidden_dim)
            qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
            qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, hidden_dim)
            cls_embed = self.cls_embed.weight  # (1, hidden_dim)
            cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
                bs, 1, 1
            )  # (bs, 1, hidden_dim)
            encoder_input = torch.cat(
                [cls_embed, qpos_embed, action_embed], axis=1
            )  # (bs, seq+1, hidden_dim)
            encoder_input = encoder_input.permute(1, 0, 2)  # (seq+1, bs, hidden_dim)
            # do not mask cls token
            cls_joint_is_pad = torch.full((bs, 2), False).to(
                qpos.device
            )  # False: not a padding
            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
            # obtain position embedding
            pos_embed = self.pos_table.clone().detach()
            pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
            # query model
            encoder_output = self.encoder(
                encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
            )
            encoder_output = encoder_output[0]  # take cls output only
            latent_info = self.latent_proj(encoder_output)
            mu = latent_info[:, : self.latent_dim]
            logvar = latent_info[:, self.latent_dim :]
            latent_sample = reparametrize(mu, logvar)
            latent_input = self.latent_out_proj(latent_sample)
        else:
            # inference (or CVAE disabled): deterministic zero latent
            mu = logvar = None
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
                qpos.device
            )
            latent_input = self.latent_out_proj(latent_sample)

        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            if is_training:
                # assumes dim 1 carries [current frame | k future frames];
                # TODO confirm this holds when len(camera_names) > 1
                next_frame_images = image[:, 1:]  # should resize?
                image = image[:, :1]
            for cam_id, cam_name in enumerate(self.camera_names):
                features, pos = self.backbones[0](image[:, cam_id])  # HARDCODED?
                features = features[0]  # take the last layer feature
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            pos = torch.cat(all_cam_pos, axis=3)
            # query_embed = torch.cat([self.query_embed.weight, self.patch_embed.weight], axis=0)
            # should change
            # print('src',src.shape)
            hs_action, hs_patch = self.transformer(
                src,
                None,
                self.query_embed.weight,
                self.patch_embed.weight,
                pos,
                latent_input,
                proprio_input,
                self.additional_pos_embed.weight,
                noisy_actions,
                denoise_steps,
            )
            # hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
        else:
            # NOTE(review): this branch only defines `hs`; `hs_action` /
            # `hs_patch` used below would raise NameError — path looks dead.
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1)  # seq length = 2
            hs = self.transformer(
                transformer_input, None, self.query_embed.weight, self.pos.weight
            )[0]

        a_hat = self.action_head(hs_action)
        is_pad_hat = self.is_pad_head(hs_action)

        # print('hs',hs_action.shape)
        # print('a_hat',a_hat.shape)
        # next frame prediction
        mask_token = self.mask_token
        mask_tokens = mask_token.repeat(bs, self.n_patch, 1)
        mask_tokens = mask_tokens + self.decoder_pos_embed.repeat(bs, 1, 1)

        # context tokens from the transformer followed by the mask tokens
        obs_pred = self.decoder_embed(hs_patch)
        obs_pred_ = torch.cat([obs_pred, mask_tokens], dim=1)
        # print('obs_pred_',obs_pred_.shape)
        for blk in self.decoder_blocks:
            obs_pred_ = blk(obs_pred_)
        obs_pred_ = self.decoder_norm(obs_pred_)
        obs_preds = self.decoder_pred(obs_pred_)
        # print('obs_preds',obs_preds.shape)
        # print(self.n_patch)
        # keep only the mask-token half, i.e. the predicted future patches
        obs_preds = obs_preds[:, self.n_patch :]
        # print('obs_preds',obs_preds.shape)
        if is_training:
            # next_frame_images = image[:,1:]
            # print('next_frame_images',next_frame_images.shape)
            next_frame_images = nn.functional.interpolate(
                next_frame_images.reshape(bs, -1, *next_frame_images.shape[-2:]),
                size=(self.img_h, self.img_w),
            )
            # print('next_frame_images',next_frame_images.shape)
            # patchify the ground-truth future frames to match obs_preds layout
            p = self.patch_size
            h_p = self.img_h // p
            w_p = self.img_w // p
            obs_targets = next_frame_images.reshape(
                shape=(bs, self.k, 3, h_p, p, w_p, p)
            )
            obs_targets = obs_targets.permute(0, 1, 3, 5, 4, 6, 2)
            obs_targets = obs_targets.reshape(
                shape=(bs, h_p * w_p * self.k, (p**2) * 3)
            )
        else:
            obs_targets = torch.zeros_like(obs_preds)

        return a_hat, is_pad_hat, [mu, logvar], [obs_preds, obs_targets]
+
+
class CNNMLP(nn.Module):
    """CNN + MLP baseline: per-camera backbone features are squeezed by small
    conv stacks, flattened, concatenated with qpos, and mapped to one action."""

    def __init__(self, backbones, state_dim, camera_names):
        """Initializes the model.
        Parameters:
            backbones: one visual backbone per camera. See backbone.py
            state_dim: robot state dimension of the environment
            camera_names: camera identifiers; one image stream per entry
        """
        super().__init__()
        self.camera_names = camera_names
        self.action_head = nn.Linear(1000, state_dim)  # TODO add more
        if backbones is None:
            raise NotImplementedError
        self.backbones = nn.ModuleList(backbones)
        # squeeze each camera's feature map down to 32 channels before flattening
        self.backbone_down_projs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(bb.num_channels, 128, kernel_size=5),
                    nn.Conv2d(128, 64, kernel_size=5),
                    nn.Conv2d(64, 32, kernel_size=5),
                )
                for bb in backbones
            ]
        )
        mlp_in_dim = 768 * len(backbones) + 14
        self.mlp = mlp(
            input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2
        )

    def forward(self, qpos, image, env_state, actions=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: unused here; kept for interface parity with the other models
        """
        bs = qpos.shape[0]
        # per-camera features, each flattened to (bs, 768)
        flat_feats = []
        for cam_id in range(len(self.camera_names)):
            cam_feats, _cam_pos = self.backbones[cam_id](image[:, cam_id])
            last_layer = cam_feats[0]  # take the last layer feature
            squeezed = self.backbone_down_projs[cam_id](last_layer)
            flat_feats.append(squeezed.reshape([bs, -1]))
        # fuse all cameras with the 14-dim proprioception vector
        fused = torch.cat(flat_feats + [qpos], axis=1)
        return self.mlp(fused)
+
+
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
    """Build a ReLU MLP as an ``nn.Sequential``.

    hidden_depth == 0 yields a single Linear layer; otherwise the network is
    Linear/ReLU pairs for every hidden layer followed by a final Linear.
    """
    if hidden_depth == 0:
        return nn.Sequential(nn.Linear(input_dim, output_dim))
    layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
    for _ in range(hidden_depth - 1):
        layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)])
    layers.append(nn.Linear(hidden_dim, output_dim))
    return nn.Sequential(*layers)
+
+
def build_encoder(args):
    """Construct the CVAE-style transformer encoder from parsed CLI args."""
    activation = "relu"  # fixed; not exposed on the CLI
    layer = TransformerEncoderLayer(
        args.hidden_dim,       # d_model, e.g. 256
        args.nheads,           # e.g. 8
        args.dim_feedforward,  # e.g. 2048
        args.dropout,          # e.g. 0.1
        activation,
        args.pre_norm,         # False by default
    )
    final_norm = nn.LayerNorm(args.hidden_dim) if args.pre_norm else None
    # enc_layers is shared with the VAE decoder (see original TODO)
    return TransformerEncoder(layer, args.enc_layers, final_norm)
+
+
def build_tactile_encoder(args):
    """Construct the tactile-signal encoder from parsed CLI args."""
    return Tactile_Encoder(args.hidden_dim, args.dropout)
+
def build(args):
    """Build the vanilla ACT model (DETRVAE) from parsed CLI args."""
    state_dim = 14  # TODO hardcode
    # a single image backbone, shared across all cameras
    backbones = [build_backbone(args)]
    model = DETRVAE(
        backbones,
        build_transformer(args),
        build_encoder(args),
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
    )
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))
    return model
+
+
def build_diffusion(args):
    """Build the diffusion-denoising ACT variant (DETRVAE_Denoise).

    One backbone is built per camera; FiLM-conditioned backbones are used
    when ``args.use_film`` is set.
    """
    state_dim = 14  # TODO hardcode
    backbones = []
    for _ in args.camera_names:
        maker = build_film_backbone if args.use_film else build_backbone
        backbones.append(maker(args))
    # decoder consumes noisy actions + positional + timestep embeddings
    transformer = build_transformer_denoise(args)
    encoder = build_encoder(args)
    model = DETRVAE_Denoise(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,  # add additional denoise step
        history_step=args.history_step,
        disable_vae_latent=args.disable_vae_latent,
        is_multi_task=args.multitask,
        use_film=args.use_film,
    )
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))
    return model
+
+
def build_diffusion_tactile(args):
    """Build the diffusion ACT variant that additionally encodes tactile input."""
    state_dim = 14  # TODO hardcode
    # one backbone per camera
    backbones = [build_backbone(args) for _ in args.camera_names]
    # decoder consumes noisy actions + positional + timestep embeddings
    transformer = build_transformer_denoise_tactile(args)
    tactile_encoder = build_tactile_encoder(args)
    model = DETRVAE_Denoise_Tactile(
        backbones,
        transformer,
        tactile_encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
    )
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))
    return model
+
def build_diffusion_tp(args):  # token prediction
    """Build the diffusion ACT variant that also predicts future visual tokens."""
    state_dim = 14  # TODO hardcode
    # one backbone per camera
    backbones = [build_backbone(args) for _ in args.camera_names]
    transformer = build_transformer_diffusion_prediction(args)
    encoder = build_encoder(args)
    # token-prediction options forwarded verbatim from the CLI
    token_cfg = dict(
        history_step=args.history_step,
        predict_frame=args.predict_frame,
        image_downsample_rate=args.image_downsample_rate,
        temporal_downsample_rate=args.temporal_downsample_rate,
        disable_vae_latent=args.disable_vae_latent,
        disable_resnet=args.disable_resnet,
        patch_size=args.patch_size,
        token_pe_type=args.token_pe_type,
        image_width=args.image_width,
        image_height=args.image_height,
        tokenizer_model_spatial_rate=args.tokenizer_model_spatial_rate,
        tokenizer_model_temporal_rate=args.tokenizer_model_temporal_rate,
        resize_rate=args.resize_rate,
    )
    model = DETRVAE_Denoise_Token_Prediction(  # TODO design for token pred
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
        **token_cfg,
    )
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))
    return model
+
+
def build_diffusion_tp_with_dual_visual_token(args):  # token prediction
    """Build the token-prediction diffusion ACT variant that uses dual visual
    token streams (see the *_Dual_Visual_Token model/transformer)."""
    state_dim = 14  # TODO hardcode
    # one backbone per camera
    backbones = [build_backbone(args) for _ in range(len(args.camera_names))]
    # decoder consumes noisy actions + positional + timestep embeddings
    transformer = build_transformer_diffusion_prediction_with_dual_visual_token(args)
    encoder = build_encoder(args)
    # token-prediction options forwarded verbatim from the CLI
    token_cfg = dict(
        history_step=args.history_step,
        predict_frame=args.predict_frame,
        image_downsample_rate=args.image_downsample_rate,
        temporal_downsample_rate=args.temporal_downsample_rate,
        disable_vae_latent=args.disable_vae_latent,
        disable_resnet=args.disable_resnet,
        patch_size=args.patch_size,
        token_pe_type=args.token_pe_type,
        image_width=args.image_width,
        image_height=args.image_height,
        tokenizer_model_spatial_rate=args.tokenizer_model_spatial_rate,
        tokenizer_model_temporal_rate=args.tokenizer_model_temporal_rate,
    )
    model = DETRVAE_Denoise_Token_Prediction_Dual_Visual_Token(  # TODO design for token pred
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
        **token_cfg,
    )
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))
    return model
+
+
def build_diffusion_pp(args):  # pixel prediction
    """Build the diffusion ACT variant that predicts future pixels directly.

    Note: unlike the token-prediction builders, only ONE backbone is created
    (shared across cameras).
    """
    state_dim = 14  # TODO hardcode
    backbone = build_backbone(args)
    if args.disable_resnet:  # HARDCODE: freeze the visual backbone entirely
        for param in backbone.parameters():
            param.requires_grad = False
    backbones = [backbone]  # fixed
    # refer to Transformer_diffusion build_transformer_diffusion
    # TODO unified prediction
    transformer = build_transformer_diffusion_pixel_prediction(args)
    encoder = build_encoder(args)
    # TODO fix
    # pixel-prediction options forwarded verbatim from the CLI
    pixel_cfg = dict(
        history_step=args.history_step,
        predict_frame=args.predict_frame,
        image_downsample_rate=args.image_downsample_rate,
        temporal_downsample_rate=args.temporal_downsample_rate,
        disable_vae_latent=args.disable_vae_latent,
        disable_resnet=args.disable_resnet,
        patch_size=args.patch_size,
        token_pe_type=args.token_pe_type,
        image_width=args.image_width,
        image_height=args.image_height,
        resize_rate=args.resize_rate,
    )
    model = DETRVAE_Denoise_Pixel_Prediction(  # TODO design for token pred
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
        **pixel_cfg,
    )
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))
    return model
+
+
# discard
def build_seg(args):
    """Discarded builder for a segmentation variant (DETRVAESeg).

    The entire construction body is commented out, so calling this function
    currently does nothing and implicitly returns None.
    """
    state_dim = 14  # TODO hardcode

    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image
    # backbones = []
    # backbone = build_backbone(args)
    # backbones.append(backbone)

    # transformer = build_transformer(args)

    # encoder = build_encoder(args)

    # model = DETRVAESeg(
    #     backbones,
    #     transformer,
    #     encoder,
    #     state_dim=state_dim,
    #     num_queries=args.num_queries,
    #     camera_names=args.camera_names,
    # )

    # n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # print("number of parameters: %.2fM" % (n_parameters/1e6,))

    # return model
+
+
def build_dino(args):
    """Assemble a DETRVAEDino policy from parsed CLI args.

    Builds one shared image backbone, the main transformer and the CVAE-style
    encoder, then wires them into the DINO model variant.
    """
    state_dim = 14  # TODO hardcode

    # Single shared image backbone; state-only variant is unused here.
    backbones = [build_backbone(args)]

    transformer = build_transformer(args)
    encoder = build_encoder(args)

    model = DETRVAEDino(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))

    return model
+
+
def build_jpeg(args):
    """Assemble a DETRVAEjpeg policy from parsed CLI args.

    Same wiring as the other builders: one shared image backbone, the main
    transformer, and the CVAE-style encoder, passed to the JPEG model variant.
    """
    state_dim = 14  # TODO hardcode

    # Single shared image backbone; state-only variant is unused here.
    backbones = [build_backbone(args)]

    transformer = build_transformer(args)
    encoder = build_encoder(args)

    model = DETRVAEjpeg(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))

    return model
+
+
def build_jpeg_diffusion(args):
    """Assemble a DETRVAEjpeg_diffusion policy from parsed CLI args.

    Uses the diffusion transformer instead of the plain one, and forwards the
    diffusion-specific options (VAE-latent toggle, predicted frame count).
    """
    state_dim = 14  # TODO hardcode

    # Single shared image backbone; state-only variant is unused here.
    backbones = [build_backbone(args)]

    transformer = build_transformer_diffusion(args)
    encoder = build_encoder(args)

    model = DETRVAEjpeg_diffusion(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
        disable_vae_latent=args.disable_vae_latent,
        predict_frame=args.predict_frame,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))

    return model
+
+
def build_jpeg_diffusion_seperate(args):
    """Assemble a DETRVAEjpeg_diffusion_seperate policy from parsed CLI args.

    Variant with a separate diffusion transformer; additionally forwards the
    JPEG token count and token dimension. (Function name keeps the historic
    "seperate" spelling because callers reference it by name.)
    """
    state_dim = 14  # TODO hardcode

    # Echo the diffusion-specific options for easier experiment bookkeeping.
    print("predict_frame", args.predict_frame)
    print("jpeg_token_num", args.jpeg_token_num)
    print("jpeg_dim", args.jpeg_dim)

    # Single shared image backbone; state-only variant is unused here.
    backbones = [build_backbone(args)]

    transformer = build_transformer_diffusion_seperate(args)
    encoder = build_encoder(args)

    model = DETRVAEjpeg_diffusion_seperate(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
        disable_vae_latent=args.disable_vae_latent,
        predict_frame=args.predict_frame,
        jpeg_token_num=args.jpeg_token_num,
        jpeg_dim=args.jpeg_dim,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))

    return model
+
+
def build_nf_diffusion_seperate(args):
    """Assemble a DETRVAE_nf_diffusion policy from parsed CLI args.

    Next-frame diffusion variant; shares the separate diffusion transformer
    with the JPEG variant but omits the JPEG token options.
    """
    state_dim = 14  # TODO hardcode

    # Echo the frame-prediction horizon for experiment bookkeeping.
    print("predict_frame", args.predict_frame)

    # Single shared image backbone; state-only variant is unused here.
    backbones = [build_backbone(args)]

    transformer = build_transformer_diffusion_seperate(args)
    encoder = build_encoder(args)

    model = DETRVAE_nf_diffusion(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
        disable_vae_latent=args.disable_vae_latent,
        predict_frame=args.predict_frame,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))

    return model
+
+
def build_cnnmlp(args):
    """Assemble the CNNMLP baseline from parsed CLI args.

    Unlike the transformer builders, this creates one backbone per camera.
    """
    state_dim = 14  # TODO hardcode

    # One independent image backbone per camera.
    backbones = [build_backbone(args) for _ in args.camera_names]

    model = CNNMLP(
        backbones,
        state_dim=state_dim,
        camera_names=args.camera_names,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))

    return model
diff --git a/ACT_DP_multitask/detr/models/detr_vae_nfp.py b/ACT_DP_multitask/detr/models/detr_vae_nfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..4557cf2da298f11eb57b736d7a816aeeb5902b64
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/detr_vae_nfp.py
@@ -0,0 +1,523 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+DETR model and criterion classes.
+"""
+import torch
+from torch import nn
+from torch.autograd import Variable
+from .backbone import build_backbone
+from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer
+from .vision_transformer import Block, get_2d_sincos_pos_embed, get_2d_sincos_pos_embed_v2
+from .mr_mg.policy.model.vision_transformer import vit_base
+
+import numpy as np
+
+import IPython
+e = IPython.embed
+
+
def reparametrize(mu, logvar):
    """Sample z ~ N(mu, exp(logvar)) via the reparameterization trick.

    Args:
        mu: tensor of latent means.
        logvar: tensor of latent log-variances, same shape as ``mu``.

    Returns:
        A tensor of the same shape as ``mu``: ``mu + std * eps`` with
        ``std = exp(logvar / 2)`` and ``eps ~ N(0, I)``, so gradients flow
        through ``mu`` and ``logvar`` but not the noise.
    """
    std = logvar.div(2).exp()
    # torch.randn_like replaces the deprecated torch.autograd.Variable
    # wrapper (Variables merged into Tensors in PyTorch 0.4) and the manual
    # std.data.new(...).normal_() idiom; it also inherits dtype/device.
    eps = torch.randn_like(std)
    return mu + std * eps
+
+
def get_sinusoid_encoding_table(n_position, d_hid):
    """Return a (1, n_position, d_hid) float tensor of sinusoidal positions.

    Even feature indices hold sin, odd indices hold cos, with the standard
    10000^(2i/d) frequency schedule from "Attention Is All You Need".
    """
    positions = np.arange(n_position)[:, None]           # (n_position, 1)
    dim_idx = np.arange(d_hid)[None, :]                  # (1, d_hid)
    # Paired dims (2i, 2i+1) share a frequency: 10000^(2*(j//2)/d_hid).
    angle_rates = np.power(10000, 2 * (dim_idx // 2) / d_hid)
    table = positions / angle_rates                      # (n_position, d_hid)
    table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
    table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1
    return torch.FloatTensor(table).unsqueeze(0)
+
+
class DETRVAE(nn.Module):
    """CVAE-style action-chunking transformer (ACT) with an auxiliary
    next-frame patch-prediction head.

    Training: an encoder over (CLS, qpos, action sequence) produces a latent
    z; the main transformer decodes an action chunk plus image-patch tokens
    from camera features, proprioception and z. Inference: z is zeroed.
    """
    def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
        """ Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            encoder: transformer encoder used as the CVAE posterior over action sequences
            state_dim: robot state dimension of the environment
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
            DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        hidden_dim = transformer.d_model
        # Heads over decoder outputs: per-query action and padding logits.
        self.action_head = nn.Linear(hidden_dim, state_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
            self.backbones = nn.ModuleList(backbones)
            # NOTE(review): 14 hardcodes the qpos dimension — keep in sync with state_dim.
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32 # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
        self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
        self.encoder_joint_proj = nn.Linear(14, hidden_dim)  # project qpos to embedding
        self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
        self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq

        # decoder extra parameters
        self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent

        # settings for next frame prediction
        self.patch_size = 16
        # self.image_size = 224
        # self.img_h, self.img_w = 128, 160
        self.img_h, self.img_w = 224, 224
        self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        # self.n_patch = (self.image_size//self.patch_size)**2
        self.k = 1 # number of next frames
        self.n_patch = (self.img_h//self.patch_size)*(self.img_w//self.patch_size)*(self.k)
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.n_patch, hidden_dim), requires_grad=False) # (1, n_patch, h)
        # Learned queries for the patch tokens appended after the action queries.
        self.patch_embed = nn.Embedding(self.n_patch, hidden_dim)
        self.decoder_embed = nn.Linear(hidden_dim, hidden_dim, bias=True)

        decoder_depth = 2 # hardcode
        self.decoder_blocks = nn.ModuleList([
            Block(hidden_dim, 16, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm)
            for i in range(decoder_depth)])

        self.decoder_norm = nn.LayerNorm(hidden_dim)
        self.decoder_pred = nn.Linear(hidden_dim, self.patch_size**2 * 3, bias=True) # decoder to patch

        # decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], (self.image_size//self.patch_size), cls_token=False)
        # Fixed (non-trainable) 2D sin-cos positions, tiled over the k future frames.
        decoder_pos_embed = get_2d_sincos_pos_embed_v2(self.decoder_pos_embed.shape[-1], (self.img_h//self.patch_size, self.img_w//self.patch_size))
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0).repeat(1,self.k,1))

        # fwd_params = sum(p.numel() for p in self.decoder_blocks.parameters() if p.requires_grad)


    def forward(self, qpos, image, env_state, actions=None, is_pad=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim

        Returns (a_hat, is_pad_hat, [mu, logvar], [obs_preds, obs_targets]);
        mu/logvar are None at inference, obs_targets are zeros at inference.
        """
        is_training = actions is not None # train or val
        bs, _ = qpos.shape
        ### Obtain latent z from action sequence
        if is_training:
            # project action sequence to embedding dim, and concat with a CLS token
            action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
            qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
            qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
            cls_embed = self.cls_embed.weight # (1, hidden_dim)
            cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
            encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
            encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
            # do not mask cls token
            cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
            # obtain position embedding
            pos_embed = self.pos_table.clone().detach()
            pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
            # query model
            encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
            encoder_output = encoder_output[0] # take cls output only
            latent_info = self.latent_proj(encoder_output)
            mu = latent_info[:, :self.latent_dim]
            logvar = latent_info[:, self.latent_dim:]
            latent_sample = reparametrize(mu, logvar)
            latent_input = self.latent_out_proj(latent_sample)
        else:
            # Inference: no posterior available, decode from a zero latent.
            mu = logvar = None
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
            latent_input = self.latent_out_proj(latent_sample)

        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            if is_training:
                # NOTE(review): during training the camera axis doubles as a time axis:
                # index 0 is the current frame, 1: are the next-frame targets — confirm
                # this matches the dataloader when num_cam > 1.
                next_frame_images = image[:,1:]
                image = image[:,:1]
            for cam_id, cam_name in enumerate(self.camera_names):
                features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED?
                features = features[0] # take the last layer feature
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            pos = torch.cat(all_cam_pos, axis=3)
            # Action queries followed by patch queries for next-frame prediction.
            query_embed = torch.cat([self.query_embed.weight, self.patch_embed.weight], axis=0)
            hs = self.transformer(src, None, query_embed, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
            # hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
        else:
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
            hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
        # First num_queries decoder outputs are the action chunk.
        a_hat = self.action_head(hs[:,:self.num_queries])
        is_pad_hat = self.is_pad_head(hs[:,:self.num_queries])

        # next frame prediction
        mask_token = self.mask_token
        mask_tokens = mask_token.repeat(bs, self.n_patch, 1)
        mask_tokens = mask_tokens + self.decoder_pos_embed.repeat(bs, 1, 1)

        # Patch-query outputs + positional mask tokens feed a small ViT decoder.
        obs_pred = self.decoder_embed(hs[:,self.num_queries:])
        obs_pred_ = torch.cat([obs_pred, mask_tokens], dim=1)
        for blk in self.decoder_blocks:
            obs_pred_ = blk(obs_pred_)
        obs_pred_ = self.decoder_norm(obs_pred_)
        obs_preds = self.decoder_pred(obs_pred_)
        # Keep only the mask-token half as the pixel predictions.
        obs_preds = obs_preds[:,self.n_patch:]

        if is_training:
            # next_frame_images = image[:,1:]
            # NOTE(review): assumes raw frames are 224x224 before resize — confirm.
            next_frame_images = nn.functional.interpolate(next_frame_images.reshape(bs, self.k*3, 224, 224), size=(self.img_h, self.img_w))
            p = self.patch_size
            h_p = self.img_h // p
            w_p = self.img_w // p
            # Patchify targets to match decoder_pred layout: (bs, k*h_p*w_p, p*p*3).
            obs_targets = next_frame_images.reshape(shape=(bs, self.k, 3, h_p, p, w_p, p))
            obs_targets = obs_targets.permute(0,1,3,5,4,6,2)
            obs_targets = obs_targets.reshape(shape=(bs, h_p*w_p*self.k, (p**2)*3))
        else:
            obs_targets = torch.zeros_like(obs_preds)

        return a_hat, is_pad_hat, [mu, logvar], [obs_preds, obs_targets]
+
+
class DETRVAE_MAE(nn.Module):
    """ACT variant that replaces the CNN backbone features with patch
    embeddings from a pretrained MAE ViT-Base encoder; otherwise mirrors
    DETRVAE's CVAE action decoding plus next-frame patch prediction."""
    def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
        """ Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            encoder: transformer encoder used as the CVAE posterior over action sequences
            state_dim: robot state dimension of the environment
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
            DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        hidden_dim = transformer.d_model
        self.action_head = nn.Linear(hidden_dim, state_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # self.model_mae = vits.__dict__['vit_base'](patch_size=16, num_classes=0)
        self.model_mae = vit_base(patch_size=16, num_classes=0)
        # NOTE(review): hardcoded relative checkpoint path — construction fails
        # if the working directory doesn't contain it; consider making it an arg.
        mae_ckpt = 'checkpoints/pretrained/mae_pretrain_vit_base.pth'
        checkpoint = torch.load(mae_ckpt, map_location='cpu')
        self.model_mae.load_state_dict(checkpoint['model'], strict=True)
        print('Load MAE pretrained model')
        # for name, p in self.model_mae.named_parameters():
        #     p.requires_grad = False

        if backbones is not None:
            # NOTE(review): CNN backbones are registered but unused in forward()
            # (the MAE encoder replaces them); they still add parameters.
            self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32 # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
        self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
        self.encoder_joint_proj = nn.Linear(14, hidden_dim)  # project qpos to embedding
        self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
        self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq

        # decoder extra parameters
        self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent

        # settings for next frame prediction
        self.patch_size = 16
        self.img_h, self.img_w = 224, 224
        self.n_patch = (self.img_h//self.patch_size)*(self.img_w//self.patch_size)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.n_patch, hidden_dim), requires_grad=False) # (1, n_patch, h)
        self.patch_embed = nn.Embedding(self.n_patch, hidden_dim)
        self.decoder_embed = nn.Linear(hidden_dim, hidden_dim, bias=True)

        decoder_depth = 2 # hardcode
        self.decoder_blocks = nn.ModuleList([
            Block(hidden_dim, 16, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm)
            for i in range(decoder_depth)])

        self.decoder_norm = nn.LayerNorm(hidden_dim)
        self.decoder_pred = nn.Linear(hidden_dim, self.patch_size**2 * 3, bias=True) # decoder to patch

        # Fixed (non-trainable) 2D sin-cos positions for the patch grid.
        decoder_pos_embed = get_2d_sincos_pos_embed_v2(self.decoder_pos_embed.shape[-1], (self.img_h//self.patch_size, self.img_w//self.patch_size))
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))


    def forward(self, qpos, image, env_state, actions=None, is_pad=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim

        Returns (a_hat, is_pad_hat, [mu, logvar], [obs_preds, obs_targets]).
        """
        is_training = actions is not None # train or val
        bs, _ = qpos.shape
        ### Obtain latent z from action sequence
        if is_training:
            # project action sequence to embedding dim, and concat with a CLS token
            action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
            qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
            qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
            cls_embed = self.cls_embed.weight # (1, hidden_dim)
            cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
            encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
            encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
            # do not mask cls token
            cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
            # obtain position embedding
            pos_embed = self.pos_table.clone().detach()
            pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
            # query model
            encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
            encoder_output = encoder_output[0] # take cls output only
            latent_info = self.latent_proj(encoder_output)
            mu = latent_info[:, :self.latent_dim]
            logvar = latent_info[:, self.latent_dim:]
            latent_sample = reparametrize(mu, logvar)
            latent_input = self.latent_out_proj(latent_sample)
        else:
            # Inference: no posterior available, decode from a zero latent.
            mu = logvar = None
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
            latent_input = self.latent_out_proj(latent_sample)

        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            if is_training:
                # Index 0 is the current frame; 1: are next-frame targets.
                next_frame_images = image[:,1:]
                image = image[:,:1]
            for cam_id, cam_name in enumerate(self.camera_names):
                # features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
                # features = features[0] # take the last layer feature
                # pos = pos[0]
                # all_cam_features.append(self.input_proj(features))
                # all_cam_pos.append(pos)

                # NOTE(review): embeddings are overwritten each iteration, so only
                # the LAST camera's MAE features reach the transformer — confirm
                # this is intended for multi-camera setups.
                obs_embedings, patch_embedings, pos_mae = self.model_mae(image[:,cam_id])

            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            # src = torch.cat(all_cam_features, axis=3)
            # pos = torch.cat(all_cam_pos, axis=3)
            # Action queries followed by patch queries for next-frame prediction.
            query_embed = torch.cat([self.query_embed.weight, self.patch_embed.weight], axis=0)
            # pos_mae[0,1:] drops the MAE CLS-token position.
            hs = self.transformer(patch_embedings, None, query_embed, pos_mae[0,1:], latent_input, proprio_input, self.additional_pos_embed.weight)[0]
            # hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
        else:
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
            hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
        # First num_queries decoder outputs are the action chunk.
        a_hat = self.action_head(hs[:,:self.num_queries])
        is_pad_hat = self.is_pad_head(hs[:,:self.num_queries])

        # next frame prediction
        mask_token = self.mask_token
        mask_tokens = mask_token.repeat(bs, self.n_patch, 1)
        mask_tokens = mask_tokens + self.decoder_pos_embed.repeat(bs, 1, 1)

        # Patch-query outputs + positional mask tokens feed a small ViT decoder.
        obs_pred = self.decoder_embed(hs[:,self.num_queries:])
        obs_pred_ = torch.cat([obs_pred, mask_tokens], dim=1)
        for blk in self.decoder_blocks:
            obs_pred_ = blk(obs_pred_)
        obs_pred_ = self.decoder_norm(obs_pred_)
        obs_preds = self.decoder_pred(obs_pred_)
        # Keep only the mask-token half as the pixel predictions.
        obs_preds = obs_preds[:,self.n_patch:]

        if is_training:
            # next_frame_images = image[:,1:]
            # next_frame_images = nn.functional.interpolate(next_frame_images[:,0], size=(self.img_h, self.img_w))
            # NOTE(review): assumes the target frame is already img_h x img_w — confirm.
            next_frame_images = next_frame_images[:,0]
            p = self.patch_size
            h_p = self.img_h // p
            w_p = self.img_w // p
            # Patchify targets to match decoder_pred layout: (bs, h_p*w_p, p*p*3).
            obs_targets = next_frame_images.reshape(shape=(bs, 3, h_p, p, w_p, p))
            obs_targets = obs_targets.permute(0,2,4,3,5,1)
            obs_targets = obs_targets.reshape(shape=(bs, h_p*w_p, (p**2)*3))
        else:
            obs_targets = torch.zeros_like(obs_preds)


        return a_hat, is_pad_hat, [mu, logvar], [obs_preds, obs_targets]
+
+
class CNNMLP(nn.Module):
    """Simple CNN + MLP baseline: per-camera backbone features are projected,
    flattened, concatenated with qpos, and mapped to a 14-dim action."""

    def __init__(self, backbones, state_dim, camera_names):
        """ Initializes the model.
        Parameters:
            backbones: one backbone module per camera. See backbone.py
            state_dim: robot state dimension of the environment
            camera_names: list of camera identifiers (one per backbone)
        """
        super().__init__()
        self.camera_names = camera_names
        self.action_head = nn.Linear(1000, state_dim)  # TODO add more
        if backbones is None:
            raise NotImplementedError
        self.backbones = nn.ModuleList(backbones)
        # Shrink each backbone's channel dim 128 -> 64 -> 32 with 5x5 convs.
        self.backbone_down_projs = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(bb.num_channels, 128, kernel_size=5),
                nn.Conv2d(128, 64, kernel_size=5),
                nn.Conv2d(64, 32, kernel_size=5),
            )
            for bb in backbones
        )
        # Each camera contributes 768 flattened features; qpos adds 14.
        mlp_in_dim = 768 * len(backbones) + 14
        self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)

    def forward(self, qpos, image, env_state, actions=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim (unused at inference)
        """
        bs, _ = qpos.shape
        # Flatten each camera's down-projected feature map to (bs, 768).
        flattened = []
        for cam_id in range(len(self.camera_names)):
            feats, _pos = self.backbones[cam_id](image[:, cam_id])
            feats = feats[0]  # last layer feature map
            flattened.append(self.backbone_down_projs[cam_id](feats).reshape([bs, -1]))
        # Concatenate all camera features with the 14-dim proprioception.
        combined = torch.cat(flattened + [qpos], axis=1)
        return self.mlp(combined)
+
+
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
    """Build a plain ReLU MLP as an nn.Sequential.

    hidden_depth == 0 yields a single Linear layer; otherwise the net is
    Linear+ReLU, (hidden_depth - 1) hidden Linear+ReLU pairs, then a final
    Linear to output_dim.
    """
    if hidden_depth == 0:
        layers = [nn.Linear(input_dim, output_dim)]
    else:
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        for _ in range(hidden_depth - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)])
        layers.append(nn.Linear(hidden_dim, output_dim))
    return nn.Sequential(*layers)
+
+
def build_encoder(args):
    """Build the CVAE-style transformer encoder from parsed CLI args.

    A final LayerNorm is attached only in pre-norm mode, matching the
    standard DETR encoder construction.
    """
    activation = "relu"

    layer = TransformerEncoderLayer(
        args.hidden_dim,       # d_model, e.g. 256
        args.nheads,           # e.g. 8
        args.dim_feedforward,  # e.g. 2048
        args.dropout,          # e.g. 0.1
        activation,
        args.pre_norm,         # normalize_before, e.g. False
    )
    final_norm = nn.LayerNorm(args.hidden_dim) if args.pre_norm else None
    # enc_layers: e.g. 4 # TODO shared with VAE decoder
    return TransformerEncoder(layer, args.enc_layers, final_norm)
+
+
def build(args):
    """Assemble the ACT policy: DETRVAE, or DETRVAE_MAE when args.mae is set.

    Both variants share one image backbone, the main transformer, and the
    CVAE-style encoder, and take identical constructor arguments.
    """
    state_dim = 14  # TODO hardcode

    # Single shared image backbone; state-only variant is unused here.
    backbones = [build_backbone(args)]

    transformer = build_transformer(args)
    encoder = build_encoder(args)

    model_cls = DETRVAE_MAE if args.mae else DETRVAE
    model = model_cls(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable/1e6,))

    return model
+
def build_cnnmlp(args):
    """Assemble the CNNMLP baseline from parsed CLI args.

    Creates one image backbone per camera, unlike the transformer builders.
    """
    state_dim = 14  # TODO hardcode

    # One independent image backbone per camera.
    backbones = [build_backbone(args) for _ in args.camera_names]

    model = CNNMLP(
        backbones,
        state_dim=state_dim,
        camera_names=args.camera_names,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable/1e6,))

    return model
+
diff --git a/ACT_DP_multitask/detr/models/mask_former/__init__.py b/ACT_DP_multitask/detr/models/mask_former/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ef4d245cfcf49cf13a5195b60d57d92e0af016e
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from . import data # register all new datasets
+from . import modeling
+
+# config
+from .config import add_mask_former_config
+
+# dataset loading
+from .data.dataset_mappers.detr_panoptic_dataset_mapper import DETRPanopticDatasetMapper
+from .data.dataset_mappers.mask_former_panoptic_dataset_mapper import (
+ MaskFormerPanopticDatasetMapper,
+)
+from .data.dataset_mappers.mask_former_semantic_dataset_mapper import (
+ MaskFormerSemanticDatasetMapper,
+)
+
+# models
+from .mask_former_model import MaskFormer
+from .test_time_augmentation import SemanticSegmentorWithTTA
diff --git a/ACT_DP_multitask/detr/models/mask_former/__pycache__/__init__.cpython-38.pyc b/ACT_DP_multitask/detr/models/mask_former/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc4c5140e564d8b83ff0b8cd8605afbdf30fa5ca
Binary files /dev/null and b/ACT_DP_multitask/detr/models/mask_former/__pycache__/__init__.cpython-38.pyc differ
diff --git a/ACT_DP_multitask/detr/models/mask_former/config.py b/ACT_DP_multitask/detr/models/mask_former/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..90ba988efd4bad35106b44f2c17de4bf71382a04
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/config.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+from detectron2.config import CfgNode as CN
+
+
def add_mask_former_config(cfg):
    """
    Add config for MASK_FORMER.

    Mutates the given detectron2 CfgNode in place, registering all
    MaskFormer-specific keys (input/solver options, MODEL.MASK_FORMER,
    pixel-decoder keys under MODEL.SEM_SEG_HEAD, and the Swin backbone
    defaults under MODEL.SWIN). Returns nothing.
    """
    # data config
    # select the dataset mapper
    cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
    # Color augmentation
    cfg.INPUT.COLOR_AUG_SSD = False
    # We retry random cropping until no single category in semantic segmentation GT occupies more
    # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
    cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
    # Pad image and segmentation GT in dataset mapper.
    cfg.INPUT.SIZE_DIVISIBILITY = -1

    # solver config
    # weight decay on embedding
    cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
    # optimizer
    cfg.SOLVER.OPTIMIZER = "ADAMW"
    # backbone LR = base LR * this multiplier
    cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1

    # mask_former model config
    cfg.MODEL.MASK_FORMER = CN()

    # loss
    cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
    cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1
    cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0
    cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0

    # transformer config
    cfg.MODEL.MASK_FORMER.NHEADS = 8
    cfg.MODEL.MASK_FORMER.DROPOUT = 0.1
    cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
    cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0
    cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6
    cfg.MODEL.MASK_FORMER.PRE_NORM = False

    cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
    cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100

    cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5"
    cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False

    # mask_former inference config
    cfg.MODEL.MASK_FORMER.TEST = CN()
    cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False
    cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0
    cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0
    cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False

    # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
    # you can use this config to override
    cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32

    # pixel decoder config
    cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
    # adding transformer in pixel decoder
    cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
    # pixel decoder
    cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder"

    # swin transformer backbone
    cfg.MODEL.SWIN = CN()
    cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
    cfg.MODEL.SWIN.PATCH_SIZE = 4
    cfg.MODEL.SWIN.EMBED_DIM = 96
    cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
    cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
    cfg.MODEL.SWIN.WINDOW_SIZE = 7
    cfg.MODEL.SWIN.MLP_RATIO = 4.0
    cfg.MODEL.SWIN.QKV_BIAS = True
    cfg.MODEL.SWIN.QK_SCALE = None
    cfg.MODEL.SWIN.DROP_RATE = 0.0
    cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
    cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
    cfg.MODEL.SWIN.APE = False
    cfg.MODEL.SWIN.PATCH_NORM = True
    cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
diff --git a/ACT_DP_multitask/detr/models/mask_former/mask_former_model.py b/ACT_DP_multitask/detr/models/mask_former/mask_former_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..02a46b97ef9066eb7cf433821c0366cffc9bb57e
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/mask_former_model.py
@@ -0,0 +1,304 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from typing import Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.data import MetadataCatalog
+from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
+from detectron2.modeling.backbone import Backbone
+from detectron2.modeling.postprocessing import sem_seg_postprocess
+from detectron2.structures import ImageList
+
+from .modeling.criterion import SetCriterion
+from .modeling.matcher import HungarianMatcher
+
+
@META_ARCH_REGISTRY.register()
class MaskFormer(nn.Module):
    """
    Main class for mask classification semantic segmentation architectures.
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        sem_seg_head: nn.Module,
        criterion: nn.Module,
        num_queries: int,
        panoptic_on: bool,
        object_mask_threshold: float,
        overlap_threshold: float,
        metadata,
        size_divisibility: int,
        sem_seg_postprocess_before_inference: bool,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            sem_seg_head: a module that predicts semantic segmentation from backbone features
            criterion: a module that defines the loss
            num_queries: int, number of queries
            panoptic_on: bool, whether to output panoptic segmentation prediction
            object_mask_threshold: float, threshold to filter query based on classification score
                for panoptic segmentation inference
            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
            metadata: dataset meta, get `thing` and `stuff` category names for panoptic
                segmentation inference
            size_divisibility: Some backbones require the input height and width to be divisible by a
                specific integer. We can use this to override such requirement.
            sem_seg_postprocess_before_inference: whether to resize the prediction back
                to original input size before semantic segmentation inference or after.
                For high-resolution dataset like Mapillary, resizing predictions before
                inference will cause OOM error.
            pixel_mean, pixel_std: list or tuple with #channels element, representing
                the per-channel mean and std to be used to normalize the input image
        """
        super().__init__()
        self.backbone = backbone
        self.sem_seg_head = sem_seg_head
        self.criterion = criterion
        self.num_queries = num_queries
        self.overlap_threshold = overlap_threshold
        self.panoptic_on = panoptic_on
        self.object_mask_threshold = object_mask_threshold
        self.metadata = metadata
        if size_divisibility < 0:
            # use backbone size_divisibility if not set
            size_divisibility = self.backbone.size_divisibility
        self.size_divisibility = size_divisibility
        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
        # persistent=False (third positional arg): normalization constants are
        # configuration, not learned state, so keep them out of checkpoints.
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @classmethod
    def from_config(cls, cfg):
        # Build the submodules from the detectron2 config node.
        backbone = build_backbone(cfg)
        sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())

        # Loss parameters:
        deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
        no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
        dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
        mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT

        # building criterion
        matcher = HungarianMatcher(
            cost_class=1,
            cost_mask=mask_weight,
            cost_dice=dice_weight,
        )

        weight_dict = {"loss_ce": 1, "loss_mask": mask_weight, "loss_dice": dice_weight}
        if deep_supervision:
            # Replicate the loss weights for every intermediate decoder layer
            # (keys get an "_{i}" suffix matching the criterion's aux outputs).
            dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
            aux_weight_dict = {}
            for i in range(dec_layers - 1):
                aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
            weight_dict.update(aux_weight_dict)

        losses = ["labels", "masks"]

        criterion = SetCriterion(
            sem_seg_head.num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            eos_coef=no_object_weight,
            losses=losses,
        )

        return {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "criterion": criterion,
            "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
            "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON,
            "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
            "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
            "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
            "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
            # Panoptic inference needs full-resolution masks, so it forces
            # postprocessing before inference.
            "sem_seg_postprocess_before_inference": (
                cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
                or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON
            ),
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
        }

    @property
    def device(self):
        # Follow the device of the registered buffers/parameters.
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:
                   * "image": Tensor, image in (C, H, W) format.
                   * "instances": per-region ground truth
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be different
                     from input resolution), used in inference.
        Returns:
            list[dict]:
                each dict has the results for one image. The dict contains the following keys:

                * "sem_seg":
                    A Tensor that represents the
                    per-pixel segmentation predicted by the head.
                    The prediction has shape KxHxW that represents the logits of
                    each class for each pixel.
                * "panoptic_seg":
                    A tuple that represent panoptic output
                    panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
                    segments_info (list[dict]): Describe each segment in `panoptic_seg`.
                        Each dict contains keys "id", "category_id", "isthing".
        """
        # Normalize each image and pad the batch to a common, divisible size.
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.size_divisibility)

        features = self.backbone(images.tensor)
        outputs = self.sem_seg_head(features)

        if self.training:
            # mask classification target
            if "instances" in batched_inputs[0]:
                gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
                targets = self.prepare_targets(gt_instances, images)
            else:
                targets = None

            # bipartite matching-based loss
            losses = self.criterion(outputs, targets)

            for k in list(losses.keys()):
                if k in self.criterion.weight_dict:
                    losses[k] *= self.criterion.weight_dict[k]
                else:
                    # remove this loss if not specified in `weight_dict`
                    losses.pop(k)

            return losses
        else:
            mask_cls_results = outputs["pred_logits"]
            mask_pred_results = outputs["pred_masks"]
            # upsample masks to the padded input resolution
            mask_pred_results = F.interpolate(
                mask_pred_results,
                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
                mode="bilinear",
                align_corners=False,
            )

            processed_results = []
            for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
                mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
            ):
                # Target output resolution: the original image size if known,
                # otherwise the unpadded size inside the batch tensor.
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])

                if self.sem_seg_postprocess_before_inference:
                    mask_pred_result = sem_seg_postprocess(
                        mask_pred_result, image_size, height, width
                    )

                # semantic segmentation inference
                r = self.semantic_inference(mask_cls_result, mask_pred_result)
                if not self.sem_seg_postprocess_before_inference:
                    r = sem_seg_postprocess(r, image_size, height, width)
                processed_results.append({"sem_seg": r})

                # panoptic segmentation inference
                if self.panoptic_on:
                    panoptic_r = self.panoptic_inference(mask_cls_result, mask_pred_result)
                    processed_results[-1]["panoptic_seg"] = panoptic_r

            return processed_results

    def prepare_targets(self, targets, images):
        """Convert per-image ground-truth `Instances` into the list of
        {"labels", "masks"} dicts the criterion expects, padding every mask
        to the padded batch resolution."""
        h, w = images.tensor.shape[-2:]
        new_targets = []
        for targets_per_image in targets:
            # pad gt masks to the common (h, w) of the batch
            gt_masks = targets_per_image.gt_masks
            padded_masks = torch.zeros((gt_masks.shape[0], h, w), dtype=gt_masks.dtype, device=gt_masks.device)
            padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
            new_targets.append(
                {
                    "labels": targets_per_image.gt_classes,
                    "masks": padded_masks,
                }
            )
        return new_targets

    def semantic_inference(self, mask_cls, mask_pred):
        """Combine per-query class scores (dropping the trailing no-object
        class) with per-query mask probabilities into per-pixel class scores
        of shape (K, H, W)."""
        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
        return semseg

    def panoptic_inference(self, mask_cls, mask_pred):
        """Greedy panoptic inference: keep confident non-background queries,
        assign each pixel to its highest-scoring query, then merge stuff
        segments of the same class. Returns (panoptic_seg, segments_info)."""
        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        mask_pred = mask_pred.sigmoid()

        # Keep queries that are not predicted as the no-object class
        # (index == num_classes) and clear the score threshold.
        keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]
        # NOTE(review): cur_mask_cls is computed here but never used below.
        cur_mask_cls = mask_cls[keep]
        cur_mask_cls = cur_mask_cls[:, :-1]

        # Per-pixel score of each kept query: class score * mask probability.
        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks

        h, w = cur_masks.shape[-2:]
        panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
        segments_info = []

        current_segment_id = 0

        if cur_masks.shape[0] == 0:
            # We didn't detect any mask :(
            return panoptic_seg, segments_info
        else:
            # take argmax: every pixel goes to its best-scoring query
            cur_mask_ids = cur_prob_masks.argmax(0)
            stuff_memory_list = {}
            for k in range(cur_classes.shape[0]):
                pred_class = cur_classes[k].item()
                isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
                mask = cur_mask_ids == k
                mask_area = mask.sum().item()
                original_area = (cur_masks[k] >= 0.5).sum().item()

                if mask_area > 0 and original_area > 0:
                    # Drop fragments whose surviving area is too small a
                    # fraction of the query's original mask.
                    if mask_area / original_area < self.overlap_threshold:
                        continue

                    # merge stuff regions of the same class into one segment
                    if not isthing:
                        if int(pred_class) in stuff_memory_list.keys():
                            panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
                            continue
                        else:
                            stuff_memory_list[int(pred_class)] = current_segment_id + 1

                    current_segment_id += 1
                    panoptic_seg[mask] = current_segment_id

                    segments_info.append(
                        {
                            "id": current_segment_id,
                            "isthing": bool(isthing),
                            "category_id": int(pred_class),
                        }
                    )

            return panoptic_seg, segments_info
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/__init__.py b/ACT_DP_multitask/detr/models/mask_former/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e7eb4843fdcb71329842de6750fcb721a4a4e1f
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .backbone.swin import D2SwinTransformer
+from .heads.mask_former_head import MaskFormerHead
+from .heads.per_pixel_baseline import PerPixelBaselineHead, PerPixelBaselinePlusHead
+from .heads.pixel_decoder import BasePixelDecoder
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/backbone/__init__.py b/ACT_DP_multitask/detr/models/mask_former/modeling/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/backbone/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/backbone/swin.py b/ACT_DP_multitask/detr/models/mask_former/modeling/backbone/swin.py
new file mode 100644
index 0000000000000000000000000000000000000000..86a7762f025e7e702fc13d03d6a219a2e6c9acb8
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/backbone/swin.py
@@ -0,0 +1,768 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu, Yutong Lin, Yixuan Wei
+# --------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
+
+
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout."""

    def __init__(
        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
    ):
        super().__init__()
        # Unspecified (falsy) widths fall back to the input width.
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the MLP to the last dimension of `x`."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
+
+
def window_partition(x, window_size):
    """
    Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C) tensor; H and W must be divisible by window_size.
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    grid = x.view(
        batch,
        height // window_size,
        window_size,
        width // window_size,
        window_size,
        channels,
    )
    # Move the two window-grid axes next to the batch axis, then flatten them
    # so each window becomes an independent "batch" element.
    return (
        grid.permute(0, 1, 3, 2, 4, 5)
        .contiguous()
        .view(-1, window_size, window_size, channels)
    )
+
+
def window_reverse(windows, window_size, H, W):
    """
    Inverse of `window_partition`: reassemble windows into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    windows_per_image = (H // window_size) * (W // window_size)
    batch = windows.shape[0] // windows_per_image
    grid = windows.view(
        batch, H // window_size, W // window_size, window_size, window_size, -1
    )
    # Interleave the window-grid axes with the intra-window axes to restore
    # the original (B, H, W, C) layout.
    return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, H, W, -1)
+
+
class WindowAttention(nn.Module):
    """Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias: one learned bias
        # per (relative offset, head) pair
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # scale the row offset so that every (row, col) offset pair maps to a
        # unique flat index into the bias table
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # registered as a buffer: fixed lookup indices, not a learned parameter
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Forward function.
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        # single linear projection produces q, k, v; then split per head
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        # look up and add the learned relative position bias per head
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1
        ).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # broadcast the per-window shift mask over the batch and head axes
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
+
+
class SwinTransformerBlock(nn.Module):
    """Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
        )

        # Spatial resolution of the current input; set externally (by the
        # enclosing layer) before each forward call.
        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            mask_matrix: Attention mask for cyclic shift.

        Spatial resolution is read from self.H / self.W, which the caller
        must set before invoking this block.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size (right/bottom only)
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift: roll the map so shifted windows become regular ones
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(
            shifted_x, self.window_size
        )  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(
            -1, self.window_size * self.window_size, C
        )  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        # crop away the padding added above
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN with residual connections and (optional) stochastic depth
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
+
+
class PatchMerging(nn.Module):
    """Patch Merging Layer: halves the spatial resolution and doubles the
    channel count by folding every 2x2 neighborhood into the channel axis.
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """Merge 2x2 patches.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        Returns:
            Tensor of size (B, ceil(H/2)*ceil(W/2), 2*C).
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # Pad one row/column on the bottom/right when the spatial size is odd.
        if H % 2 or W % 2:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        # The four pixels of each 2x2 neighborhood, in the order
        # (0,0), (1,0), (0,1), (1,1) of (row offset, col offset).
        quads = [x[:, dh::2, dw::2, :] for dw in (0, 1) for dh in (0, 1)]
        merged = torch.cat(quads, -1).view(B, -1, 4 * C)  # B H/2*W/2 4*C

        return self.reduction(self.norm(merged))
+
+
class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads,
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
    ):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks: alternate regular (shift 0) and shifted (shift ws//2)
        # window attention blocks
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """Forward function.
        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        Returns:
            (x_out, H, W, x_down, Wh, Ww): the stage output at the input
            resolution and the (possibly downsampled) input to the next stage.
        """

        # calculate attention mask for SW-MSA: label each of the 9 shifted
        # regions with a distinct id, then forbid attention across regions
        # (pairs with different ids get -100 added to their logits).
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        w_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(
            img_mask, self.window_size
        )  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
            attn_mask == 0, float(0.0)
        )

        for blk in self.blocks:
            # blocks read the current resolution from attributes, not args
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            # ceil-divide: patch merging pads odd sizes before downsampling
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W
+
+
class PatchEmbed(nn.Module):
    """Image to Patch Embedding via a strided convolution.
    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # kernel == stride == patch size: each patch maps to one output token
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Embed image `x` (B, C, H, W) into (B, embed_dim, Wh, Ww) tokens."""
        _, _, H, W = x.size()
        # pad right/bottom so H and W become multiples of the patch size
        ph, pw = self.patch_size
        if W % pw:
            x = F.pad(x, (0, pw - W % pw))
        if H % ph:
            x = F.pad(x, (0, 0, 0, ph - H % ph))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            # LayerNorm expects channels last: flatten, normalize, restore
            Wh, Ww = x.size(2), x.size(3)
            tokens = x.flatten(2).transpose(1, 2)
            x = self.norm(tokens).transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x
+
+
class SwinTransformer(nn.Module):
    """Swin Transformer backbone.
    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/pdf/2103.14030
    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention head of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        pretrain_img_size=224,
        patch_size=4,
        in_chans=3,
        embed_dim=96,
        # tuples instead of lists: immutable defaults cannot be accidentally
        # mutated and shared between instances
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.2,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        use_checkpoint=False,
    ):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )

        # absolute position embedding (sized for the pretraining resolution;
        # interpolated at forward time if the input differs)
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [
                pretrain_img_size[0] // patch_size[0],
                pretrain_img_size[1] // patch_size[1],
            ]

            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
            )
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth: drop rate increases linearly block-by-block
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # build layers; each stage doubles channels and (except the last)
        # ends with a PatchMerging downsample
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f"norm{i_layer}"
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        """Set frozen modules to eval mode and stop their gradients."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            # freeze stages 0 .. frozen_stages-2
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.
        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None. Checkpoint loading is NOT performed here;
                it is expected to be handled by the surrounding framework
                (e.g. the detectron2 checkpointer).
        """

        def _init_weights(m):
            # truncated-normal Linear weights, zero biases, identity LayerNorm
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        # BUGFIX: `_init_weights` was previously defined but never applied,
        # making this method a silent no-op.
        self.apply(_init_weights)

    def forward(self, x):
        """Forward function. Returns a dict mapping "res{i+2}" to the
        (normalized, NCHW) feature map of each stage in `out_indices`."""
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(
                self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
            )
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = {}
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f"norm{i}")
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs["res{}".format(i + 2)] = out

        return outs

    def train(self, mode=True):
        """Convert the model into training mode while keep layers freezed."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()
+
+
@BACKBONE_REGISTRY.register()
class D2SwinTransformer(SwinTransformer, Backbone):
    """detectron2-compatible wrapper around `SwinTransformer`.

    Hyper-parameters are read from cfg.MODEL.SWIN, and the stages selected
    by cfg.MODEL.SWIN.OUT_FEATURES are exposed through the d2 Backbone API.
    """

    def __init__(self, cfg, input_shape):
        swin_cfg = cfg.MODEL.SWIN
        super().__init__(
            pretrain_img_size=swin_cfg.PRETRAIN_IMG_SIZE,
            patch_size=swin_cfg.PATCH_SIZE,
            in_chans=3,  # RGB input
            embed_dim=swin_cfg.EMBED_DIM,
            depths=swin_cfg.DEPTHS,
            num_heads=swin_cfg.NUM_HEADS,
            window_size=swin_cfg.WINDOW_SIZE,
            mlp_ratio=swin_cfg.MLP_RATIO,
            qkv_bias=swin_cfg.QKV_BIAS,
            qk_scale=swin_cfg.QK_SCALE,
            drop_rate=swin_cfg.DROP_RATE,
            attn_drop_rate=swin_cfg.ATTN_DROP_RATE,
            drop_path_rate=swin_cfg.DROP_PATH_RATE,
            norm_layer=nn.LayerNorm,
            ape=swin_cfg.APE,
            patch_norm=swin_cfg.PATCH_NORM,
        )

        self._out_features = swin_cfg.OUT_FEATURES

        # Stage i ("res{i+2}") downsamples the input by 4 * 2**i and emits
        # num_features[i] channels.
        stage_names = ("res2", "res3", "res4", "res5")
        self._out_feature_strides = {
            name: 4 * 2 ** i for i, name in enumerate(stage_names)
        }
        self._out_feature_channels = {
            name: self.num_features[i] for i, name in enumerate(stage_names)
        }

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        assert (
            x.dim() == 4
        ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        features = super().forward(x)
        # Keep only the feature maps requested in the config.
        return {name: feat for name, feat in features.items() if name in self._out_features}

    def output_shape(self):
        """Return a ShapeSpec (channels, stride) for every requested feature."""
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }

    @property
    def size_divisibility(self):
        # Patch stride 4 followed by three 2x downsamples -> 32 overall.
        return 32
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/criterion.py b/ACT_DP_multitask/detr/models/mask_former/modeling/criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..7631ee3766bc0920f3a0e8ca648782e6c5c2e0bf
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/criterion.py
@@ -0,0 +1,187 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+"""
+MaskFormer criterion.
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from detectron2.utils.comm import get_world_size
+
+from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list
+
+
def dice_loss(inputs, targets, num_masks):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions (logits) for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        num_masks: normalization constant — the (average) number of masks.
    """
    probs = inputs.sigmoid().flatten(1)
    # Soft intersection and total area per mask.
    overlap = (probs * targets).sum(-1)
    totals = probs.sum(-1) + targets.sum(-1)
    # Smoothed (by +1) complement of the dice coefficient, one value per mask.
    per_mask_loss = 1 - (2 * overlap + 1) / (totals + 1)
    return per_mask_loss.sum() / num_masks
+
+
def sigmoid_focal_loss(inputs, targets, num_masks, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Defaults to 0.25; pass a
                negative value to disable the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t: probability the model assigns to the ground-truth class at each element.
    p_t = prob * targets + (1 - prob) * (1 - targets)
    # Modulating factor (1 - p_t)^gamma down-weights easy examples.
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        # Class-balancing factor: alpha for positives, (1 - alpha) for negatives.
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    # Mean over the per-mask elements, then normalize by the number of masks.
    return loss.mean(1).sum() / num_masks
+
+
class SetCriterion(nn.Module):
    """This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        # Per-class cross-entropy weights: the trailing "no-object" class is
        # down-weighted by eos_coef. Registered as a buffer so it moves with
        # the module across devices without being a learnable parameter.
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

    def loss_labels(self, outputs, targets, indices, num_masks):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]

        # (batch indices, query indices) of predictions matched to a target.
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        # Default every query to the "no-object" class, then overwrite the
        # matched queries with their assigned ground-truth labels.
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = target_classes_o

        # empty_weight down-weights the no-object class by eos_coef.
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {"loss_ce": loss_ce}
        return losses

    def loss_masks(self, outputs, targets, indices, num_masks):
        """Compute the losses related to the masks: the focal loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs["pred_masks"]
        src_masks = src_masks[src_idx]
        masks = [t["masks"] for t in targets]
        # TODO use valid to mask invalid areas due to padding in loss
        # Pad all target masks in the batch to a common (h, w).
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]

        # upsample predictions to the target size
        src_masks = F.interpolate(
            src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
        )
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
        # NOTE(review): the view() after flatten(1) looks redundant — shapes
        # should already agree after the interpolation above; verify.
        target_masks = target_masks.view(src_masks.shape)
        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks),
            "loss_dice": dice_loss(src_masks, target_masks, num_masks),
        }
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_masks):
        # Dispatch by loss name; see self.losses for the configured set.
        loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_masks)

    def forward(self, outputs, targets):
        """This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target masks across all nodes, for normalization purposes
        num_masks = sum(len(t["labels"]) for t in targets)
        num_masks = torch.as_tensor(
            [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_masks)
        # Clamp to at least 1 so batches with no targets don't divide by zero.
        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                # Each decoder layer gets its own matching and a suffixed loss name.
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/heads/__init__.py b/ACT_DP_multitask/detr/models/mask_former/modeling/heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/heads/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/heads/mask_former_head.py b/ACT_DP_multitask/detr/models/mask_former/modeling/heads/mask_former_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..16f9d990953b33cecf6fd1e9a25131e2c9127ffe
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/heads/mask_former_head.py
@@ -0,0 +1,119 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+from copy import deepcopy
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import fvcore.nn.weight_init as weight_init
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d, ShapeSpec, get_norm
+from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+from ..transformer.transformer_predictor import TransformerPredictor
+from .pixel_decoder import build_pixel_decoder
+
+
@SEM_SEG_HEADS_REGISTRY.register()
class MaskFormerHead(nn.Module):
    """MaskFormer segmentation head: a pixel decoder that produces per-pixel
    mask features plus a transformer predictor that produces query-based
    class/mask predictions."""

    # Checkpoint format version; bumped when pixel-decoder weights moved under
    # the "pixel_decoder." prefix (see _load_from_state_dict).
    _version = 2

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Backward compatibility: rewrite version-1 checkpoint keys in place so
        # old pixel-decoder weights load under the new "pixel_decoder." prefix.
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Do not warn if train from scratch
            scratch = True
            logger = logging.getLogger(__name__)
            for k in list(state_dict.keys()):
                newk = k
                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
                    newk = k.replace(prefix, prefix + "pixel_decoder.")
                    # logger.debug(f"{k} ==> {newk}")
                if newk != k:
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]
                    scratch = False

            if not scratch:
                logger.warning(
                    f"Weight format of {self.__class__.__name__} have changed! "
                    "Please upgrade your models. Applying automatic conversion now ..."
                )

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
        # extra parameters
        transformer_predictor: nn.Module,
        transformer_in_feature: str,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
            transformer_predictor: the transformer decoder that makes prediction
            transformer_in_feature: input feature name to the transformer_predictor
        """
        super().__init__()
        # Sort features shallow-to-deep by stride (e.g. res2 ... res5).
        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]
        # NOTE(review): feature_strides/feature_channels are unused here —
        # kept for parity with the other heads in this package.
        feature_strides = [v.stride for k, v in input_shape]
        feature_channels = [v.channels for k, v in input_shape]

        self.ignore_value = ignore_value
        self.common_stride = 4
        self.loss_weight = loss_weight

        self.pixel_decoder = pixel_decoder
        self.predictor = transformer_predictor
        self.transformer_in_feature = transformer_in_feature

        self.num_classes = num_classes

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        # Assemble constructor kwargs from the detectron2 config node.
        return {
            "input_shape": {
                k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
            },
            "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
            "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
            "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
            # Predictor input channels: the pixel-decoder conv dim when reading
            # from the transformer encoder, otherwise the chosen backbone feature.
            "transformer_predictor": TransformerPredictor(
                cfg,
                cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
                if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder"
                else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels,
                mask_classification=True,
            ),
        }

    def forward(self, features):
        # Thin wrapper so subclasses / scripting can override layers() alone.
        return self.layers(features)

    def layers(self, features):
        # Pixel decoder yields per-pixel mask features and (optionally) the
        # transformer-encoder feature map used as predictor input.
        mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features)
        if self.transformer_in_feature == "transformer_encoder":
            assert (
                transformer_encoder_features is not None
            ), "Please use the TransformerEncoderPixelDecoder."
            predictions = self.predictor(transformer_encoder_features, mask_features)
        else:
            predictions = self.predictor(features[self.transformer_in_feature], mask_features)
        return predictions
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/heads/per_pixel_baseline.py b/ACT_DP_multitask/detr/models/mask_former/modeling/heads/per_pixel_baseline.py
new file mode 100644
index 0000000000000000000000000000000000000000..a99f508e7b4a87ada0af6f10209f10edefa7e412
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/heads/per_pixel_baseline.py
@@ -0,0 +1,243 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import fvcore.nn.weight_init as weight_init
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d, ShapeSpec, get_norm
+from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+from ..transformer.transformer_predictor import TransformerPredictor
+from .pixel_decoder import build_pixel_decoder
+
+
@SEM_SEG_HEADS_REGISTRY.register()
class PerPixelBaselineHead(nn.Module):
    """Per-pixel classification baseline: a pixel decoder followed by a single
    1x1 conv that predicts class logits at every pixel."""

    # Checkpoint format version; bumped when pixel-decoder weights moved under
    # the "pixel_decoder." prefix (see _load_from_state_dict).
    _version = 2

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Backward compatibility: rewrite version-1 checkpoint keys in place so
        # old pixel-decoder weights load under the new "pixel_decoder." prefix.
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Do not warn if train from scratch
            scratch = True
            # Fix: the original assigned `logger` twice; once is enough.
            logger = logging.getLogger(__name__)
            for k in list(state_dict.keys()):
                newk = k
                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
                    newk = k.replace(prefix, prefix + "pixel_decoder.")
                    # logger.warning(f"{k} ==> {newk}")
                if newk != k:
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]
                    scratch = False

            if not scratch:
                logger.warning(
                    f"Weight format of {self.__class__.__name__} have changed! "
                    "Please upgrade your models. Applying automatic conversion now ..."
                )

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
        """
        super().__init__()
        # Sort features shallow-to-deep by stride (e.g. res2 ... res5).
        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]

        self.ignore_value = ignore_value
        self.common_stride = 4
        self.loss_weight = loss_weight

        self.pixel_decoder = pixel_decoder
        # A single 1x1 conv maps mask features to per-pixel class logits.
        self.predictor = Conv2d(
            self.pixel_decoder.mask_dim, num_classes, kernel_size=1, stride=1, padding=0
        )
        weight_init.c2_msra_fill(self.predictor)

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        # Assemble constructor kwargs from the detectron2 config node.
        return {
            "input_shape": {
                k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
            },
            "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
            "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
        }

    def forward(self, features, targets=None):
        """
        Returns:
            In training, returns (None, dict of losses)
            In inference, returns (CxHxW logits, {})
        """
        x = self.layers(features)
        if self.training:
            return None, self.losses(x, targets)
        else:
            # Upsample logits from common_stride (1/4) back to input resolution.
            x = F.interpolate(
                x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            return x, {}

    def layers(self, features):
        # Decode pixel features, then classify each pixel with the 1x1 conv.
        x, _ = self.pixel_decoder.forward_features(features)
        x = self.predictor(x)
        return x

    def losses(self, predictions, targets):
        predictions = predictions.float()  # https://github.com/pytorch/pytorch/issues/48163
        predictions = F.interpolate(
            predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
        )
        loss = F.cross_entropy(
            predictions, targets, reduction="mean", ignore_index=self.ignore_value
        )
        losses = {"loss_sem_seg": loss * self.loss_weight}
        return losses
+
+
@SEM_SEG_HEADS_REGISTRY.register()
class PerPixelBaselinePlusHead(PerPixelBaselineHead):
    """Per-pixel baseline whose predictor is a transformer (instead of a 1x1
    conv), with optional deep supervision of intermediate decoder layers."""

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Backward compatibility: rewrite version-1 checkpoint keys in place so
        # old pixel-decoder weights load under the new "pixel_decoder." prefix.
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Do not warn if train from scratch
            scratch = True
            logger = logging.getLogger(__name__)
            for k in list(state_dict.keys()):
                newk = k
                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
                    newk = k.replace(prefix, prefix + "pixel_decoder.")
                    logger.debug(f"{k} ==> {newk}")
                if newk != k:
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]
                    scratch = False

            if not scratch:
                logger.warning(
                    f"Weight format of {self.__class__.__name__} have changed! "
                    "Please upgrade your models. Applying automatic conversion now ..."
                )

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        # extra parameters
        transformer_predictor: nn.Module,
        transformer_in_feature: str,
        deep_supervision: bool,
        # inherit parameters
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            transformer_predictor: the transformer decoder that makes prediction
            transformer_in_feature: input feature name to the transformer_predictor
            deep_supervision: whether or not to add supervision to the output of
                every transformer decoder layer
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
        """
        super().__init__(
            input_shape,
            num_classes=num_classes,
            pixel_decoder=pixel_decoder,
            loss_weight=loss_weight,
            ignore_value=ignore_value,
        )

        # Replace the parent's 1x1-conv predictor with the transformer predictor.
        del self.predictor

        self.predictor = transformer_predictor
        self.transformer_in_feature = transformer_in_feature
        self.deep_supervision = deep_supervision

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        # Extend the parent's config with the transformer-predictor options.
        ret = super().from_config(cfg, input_shape)
        ret["transformer_in_feature"] = cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE
        if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder":
            in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
        else:
            in_channels = input_shape[ret["transformer_in_feature"]].channels
        ret["transformer_predictor"] = TransformerPredictor(
            cfg, in_channels, mask_classification=False
        )
        ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
        return ret

    def forward(self, features, targets=None):
        """
        Returns:
            In training, returns (None, dict of losses)
            In inference, returns (CxHxW logits, {})
        """
        x, aux_outputs = self.layers(features)
        if self.training:
            if self.deep_supervision:
                # Add one suffixed loss term per intermediate decoder layer.
                losses = self.losses(x, targets)
                for i, aux_output in enumerate(aux_outputs):
                    losses["loss_sem_seg" + f"_{i}"] = self.losses(
                        aux_output["pred_masks"], targets
                    )["loss_sem_seg"]
                return None, losses
            else:
                return None, self.losses(x, targets)
        else:
            # Upsample logits from common_stride (1/4) back to input resolution.
            x = F.interpolate(
                x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            return x, {}

    def layers(self, features):
        mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features)
        if self.transformer_in_feature == "transformer_encoder":
            assert (
                transformer_encoder_features is not None
            ), "Please use the TransformerEncoderPixelDecoder."
            predictions = self.predictor(transformer_encoder_features, mask_features)
        else:
            predictions = self.predictor(features[self.transformer_in_feature], mask_features)
        if self.deep_supervision:
            return predictions["pred_masks"], predictions["aux_outputs"]
        else:
            return predictions["pred_masks"], None
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/heads/pixel_decoder.py b/ACT_DP_multitask/detr/models/mask_former/modeling/heads/pixel_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ae74c8ce6000aa976d19c7f9223519f46ce7e6a
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/heads/pixel_decoder.py
@@ -0,0 +1,294 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import fvcore.nn.weight_init as weight_init
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d, ShapeSpec, get_norm
+from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+from ..transformer.position_encoding import PositionEmbeddingSine
+from ..transformer.transformer import TransformerEncoder, TransformerEncoderLayer
+
+
def build_pixel_decoder(cfg, input_shape):
    """
    Build a pixel decoder from `cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME`.

    The named module is looked up in SEM_SEG_HEADS_REGISTRY and must expose a
    callable ``forward_features`` so it can be used as a pixel decoder.
    """
    name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME
    model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
    forward_features = getattr(model, "forward_features", None)
    if not callable(forward_features):
        raise ValueError(
            "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
            f"Please implement forward_features for {name} to only return mask features."
        )
    return model
+
+
@SEM_SEG_HEADS_REGISTRY.register()
class BasePixelDecoder(nn.Module):
    """FPN-style pixel decoder: lateral 1x1 convs plus top-down nearest
    upsampling over the backbone features, ending in a mask-feature conv."""

    # @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        # *,
        conv_dim: int,
        mask_dim: int,
        norm: Optional[Union[str, Callable]] = None,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            conv_dim: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            norm (str or callable): normalization for all conv layers
        """
        super().__init__()

        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
        feature_channels = [v.channels for k, v in input_shape]

        lateral_convs = []
        output_convs = []

        # Convs get a bias only when no normalization layer follows them.
        use_bias = norm == ""
        for idx, in_channels in enumerate(feature_channels):
            if idx == len(self.in_features) - 1:
                # Deepest level: no lateral conv, just a 3x3 output conv.
                output_norm = get_norm(norm, conv_dim)
                output_conv = Conv2d(
                    in_channels,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                weight_init.c2_xavier_fill(output_conv)
                self.add_module("layer_{}".format(idx + 1), output_conv)

                lateral_convs.append(None)
                output_convs.append(output_conv)
            else:
                # Shallower levels: 1x1 lateral conv + 3x3 output conv.
                lateral_norm = get_norm(norm, conv_dim)
                output_norm = get_norm(norm, conv_dim)

                lateral_conv = Conv2d(
                    in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
                )
                output_conv = Conv2d(
                    conv_dim,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                weight_init.c2_xavier_fill(lateral_conv)
                weight_init.c2_xavier_fill(output_conv)
                self.add_module("adapter_{}".format(idx + 1), lateral_conv)
                self.add_module("layer_{}".format(idx + 1), output_conv)

                lateral_convs.append(lateral_conv)
                output_convs.append(output_conv)
        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]

        self.mask_dim = mask_dim
        self.mask_features = Conv2d(
            conv_dim,
            mask_dim,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        weight_init.c2_xavier_fill(self.mask_features)

    # @classmethod
    # def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
    #     ret = {}
    #     ret["input_shape"] = {
    #         k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
    #     }
    #     ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
    #     ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
    #     ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
    #     return ret

    def forward_features(self, features):
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.in_features[::-1]):
            x = features[f]
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            if lateral_conv is None:
                # First (deepest) iteration: seed the top-down pathway.
                y = output_conv(x)
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
                y = output_conv(y)
        # Second element is the transformer-encoder feature map; this base
        # decoder has none, so it returns None.
        return self.mask_features(y), None

    def forward(self, features, targets=None):
        # Kept for interface compatibility; use forward_features directly.
        logger = logging.getLogger(__name__)
        logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
        return self.forward_features(features)
+
+
class TransformerEncoderOnly(nn.Module):
    """A stack of transformer encoder layers (no decoder) operating on
    flattened 2-D feature maps."""

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()

        layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        # Pre-norm variants apply a final LayerNorm after the stack.
        final_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(layer, num_encoder_layers, final_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        # Xavier-initialize every weight matrix (skip biases / 1-D params).
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, src, mask, pos_embed):
        # flatten NxCxHxW to HWxNxC for the sequence-first encoder
        bs, c, h, w = src.shape
        seq_src = src.flatten(2).permute(2, 0, 1)
        seq_pos = pos_embed.flatten(2).permute(2, 0, 1)
        seq_mask = mask.flatten(1) if mask is not None else None

        memory = self.encoder(seq_src, src_key_padding_mask=seq_mask, pos=seq_pos)
        # Restore the NxCxHxW layout.
        return memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
@SEM_SEG_HEADS_REGISTRY.register()
class TransformerEncoderPixelDecoder(BasePixelDecoder):
    """FPN-style pixel decoder whose deepest level is refined by a transformer
    encoder before the top-down pathway; also exposes the encoder output."""

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        transformer_dropout: float,
        transformer_nheads: int,
        transformer_dim_feedforward: int,
        transformer_enc_layers: int,
        transformer_pre_norm: bool,
        conv_dim: int,
        mask_dim: int,
        norm: Optional[Union[str, Callable]] = None,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            transformer_dropout: dropout probability in transformer
            transformer_nheads: number of heads in transformer
            transformer_dim_feedforward: dimension of feedforward network
            transformer_enc_layers: number of transformer encoder layers
            transformer_pre_norm: whether to use pre-layernorm or not
            conv_dim: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            norm (str or callable): normalization for all conv layers
        """
        super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm)

        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
        # NOTE(review): feature_strides is unused here — kept for parity with
        # the sibling decoders.
        feature_strides = [v.stride for k, v in input_shape]
        feature_channels = [v.channels for k, v in input_shape]

        # The transformer refines the deepest (highest-stride) feature map.
        in_channels = feature_channels[len(self.in_features) - 1]
        self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
        weight_init.c2_xavier_fill(self.input_proj)
        self.transformer = TransformerEncoderOnly(
            d_model=conv_dim,
            dropout=transformer_dropout,
            nhead=transformer_nheads,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_enc_layers,
            normalize_before=transformer_pre_norm,
        )
        # Sine position embedding splits conv_dim between x and y halves.
        N_steps = conv_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        # update layer: replace the deepest output conv built by the parent so
        # it consumes the transformer output (conv_dim channels) instead of
        # the raw backbone feature.
        use_bias = norm == ""
        output_norm = get_norm(norm, conv_dim)
        output_conv = Conv2d(
            conv_dim,
            conv_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=use_bias,
            norm=output_norm,
            activation=F.relu,
        )
        weight_init.c2_xavier_fill(output_conv)
        delattr(self, "layer_{}".format(len(self.in_features)))
        self.add_module("layer_{}".format(len(self.in_features)), output_conv)
        self.output_convs[0] = output_conv

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        ret = super().from_config(cfg, input_shape)
        ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
        ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
        ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
        ret[
            "transformer_enc_layers"
        ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS  # a separate config
        ret["transformer_pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
        return ret

    def forward_features(self, features):
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.in_features[::-1]):
            x = features[f]
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            if lateral_conv is None:
                # First (deepest) iteration: project, add sine position
                # embeddings, and run the transformer encoder.
                transformer = self.input_proj(x)
                pos = self.pe_layer(x)
                transformer = self.transformer(transformer, None, pos)
                y = output_conv(transformer)
                # save intermediate feature as input to Transformer decoder
                transformer_encoder_features = transformer
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
                y = output_conv(y)
        return self.mask_features(y), transformer_encoder_features

    def forward(self, features, targets=None):
        # Kept for interface compatibility; use forward_features directly.
        logger = logging.getLogger(__name__)
        logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
        return self.forward_features(features)
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/matcher.py b/ACT_DP_multitask/detr/models/mask_former/modeling/matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4d11706577bf353bb0df13fe57032da3c66e8f5
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/matcher.py
@@ -0,0 +1,174 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
+"""
+Modules to compute the matching cost and solve the corresponding LSAP.
+"""
+import torch
+import torch.nn.functional as F
+from scipy.optimize import linear_sum_assignment
+from torch import nn
+
+
+def batch_dice_loss(inputs, targets):
+ """
+ Compute the DICE loss, similar to generalized IOU for masks
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ """
+ inputs = inputs.sigmoid()
+ inputs = inputs.flatten(1)
+ numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
+ denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ return loss
+
+
+def batch_sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
+ """
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ alpha: (optional) Weighting factor in range (0,1) to balance
+ positive vs negative examples. Default = 0.25.
+ gamma: Exponent of the modulating factor (1 - p_t) to
+ balance easy vs hard examples.
+ Returns:
+ Loss tensor
+ """
+ hw = inputs.shape[1]
+
+ prob = inputs.sigmoid()
+ focal_pos = ((1 - prob) ** gamma) * F.binary_cross_entropy_with_logits(
+ inputs, torch.ones_like(inputs), reduction="none"
+ )
+ focal_neg = (prob ** gamma) * F.binary_cross_entropy_with_logits(
+ inputs, torch.zeros_like(inputs), reduction="none"
+ )
+ if alpha >= 0:
+ focal_pos = focal_pos * alpha
+ focal_neg = focal_neg * (1 - alpha)
+
+ loss = torch.einsum("nc,mc->nm", focal_pos, targets) + torch.einsum(
+ "nc,mc->nm", focal_neg, (1 - targets)
+ )
+
+ return loss / hw
+
+
+class HungarianMatcher(nn.Module):
+ """This class computes an assignment between the targets and the predictions of the network
+
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+ while the others are un-matched (and thus treated as non-objects).
+ """
+
+ def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1):
+ """Creates the matcher
+
+ Params:
+ cost_class: This is the relative weight of the classification error in the matching cost
+ cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
+ cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
+ """
+ super().__init__()
+ self.cost_class = cost_class
+ self.cost_mask = cost_mask
+ self.cost_dice = cost_dice
+ assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs cant be 0"
+
+ @torch.no_grad()
+ def memory_efficient_forward(self, outputs, targets):
+ """More memory-friendly matching"""
+ bs, num_queries = outputs["pred_logits"].shape[:2]
+
+ # Work out the mask padding size
+ masks = [v["masks"] for v in targets]
+ h_max = max([m.shape[1] for m in masks])
+ w_max = max([m.shape[2] for m in masks])
+
+ indices = []
+
+ # Iterate through batch size
+ for b in range(bs):
+
+ out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes]
+ out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred]
+
+ tgt_ids = targets[b]["labels"]
+ # gt masks are already padded when preparing target
+ tgt_mask = targets[b]["masks"].to(out_mask)
+
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+ # but approximate it in 1 - proba[target class].
+ # The 1 is a constant that doesn't change the matching, it can be omitted.
+ cost_class = -out_prob[:, tgt_ids]
+
+ # Downsample gt masks to save memory
+ tgt_mask = F.interpolate(tgt_mask[:, None], size=out_mask.shape[-2:], mode="nearest")
+
+ # Flatten spatial dimension
+ out_mask = out_mask.flatten(1) # [batch_size * num_queries, H*W]
+ tgt_mask = tgt_mask[:, 0].flatten(1) # [num_total_targets, H*W]
+
+ # Compute the focal loss between masks
+ cost_mask = batch_sigmoid_focal_loss(out_mask, tgt_mask)
+
+ # Compute the dice loss between masks
+ cost_dice = batch_dice_loss(out_mask, tgt_mask)
+
+ # Final cost matrix
+ C = (
+ self.cost_mask * cost_mask
+ + self.cost_class * cost_class
+ + self.cost_dice * cost_dice
+ )
+ C = C.reshape(num_queries, -1).cpu()
+
+ indices.append(linear_sum_assignment(C))
+ return [
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+ for i, j in indices
+ ]
+
+ @torch.no_grad()
+ def forward(self, outputs, targets):
+ """Performs the matching
+
+ Params:
+ outputs: This is a dict that contains at least these entries:
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+ "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
+
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+ "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+ objects in the target) containing the class labels
+ "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
+
+ Returns:
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
+ - index_i is the indices of the selected predictions (in order)
+ - index_j is the indices of the corresponding selected targets (in order)
+ For each batch element, it holds:
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+ """
+ return self.memory_efficient_forward(outputs, targets)
+
+ def __repr__(self):
+ head = "Matcher " + self.__class__.__name__
+ body = [
+ "cost_class: {}".format(self.cost_class),
+ "cost_mask: {}".format(self.cost_mask),
+ "cost_dice: {}".format(self.cost_dice),
+ ]
+ _repr_indent = 4
+ lines = [head] + [" " * _repr_indent + line for line in body]
+ return "\n".join(lines)
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/__init__.py b/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/position_encoding.py b/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..a189ce20d14f44b6b61b209acd2f62fc1d70814a
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/position_encoding.py
@@ -0,0 +1,52 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
+"""
+Various positional encodings for the transformer.
+"""
+import math
+
+import torch
+from torch import nn
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x, mask=None):
+ if mask is None:
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/transformer.py b/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea8caa0108f5e136a9739320ab69a3e1b6f40298
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/transformer.py
@@ -0,0 +1,369 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
+"""
+Transformer class.
+
+Copy-paste from torch.nn.Transformer with modifications:
+ * positional encodings are passed in MHattention
+ * extra LN at the end of encoder is removed
+ * decoder returns a stack of activations from all decoding layers
+"""
+import copy
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ d_model=512,
+ nhead=8,
+ num_encoder_layers=6,
+ num_decoder_layers=6,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=False,
+ ):
+ super().__init__()
+
+ encoder_layer = TransformerEncoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+ decoder_layer = TransformerDecoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
+ decoder_norm = nn.LayerNorm(d_model)
+ self.decoder = TransformerDecoder(
+ decoder_layer,
+ num_decoder_layers,
+ decoder_norm,
+ return_intermediate=return_intermediate_dec,
+ )
+
+ self._reset_parameters()
+
+ self.d_model = d_model
+ self.nhead = nhead
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, src, mask, query_embed, pos_embed):
+ # flatten NxCxHxW to HWxNxC
+ bs, c, h, w = src.shape
+ src = src.flatten(2).permute(2, 0, 1)
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+ if mask is not None:
+ mask = mask.flatten(1)
+
+ tgt = torch.zeros_like(query_embed)
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+ hs = self.decoder(
+ tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
+ )
+ return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, encoder_layer, num_layers, norm=None):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(
+ self,
+ src,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ output = src
+
+ for layer in self.layers:
+ output = layer(
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
+ )
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
+
+
+class TransformerDecoder(nn.Module):
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ output = tgt
+
+ intermediate = []
+
+ for layer in self.layers:
+ output = layer(
+ output,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ pos=pos,
+ query_pos=query_pos,
+ )
+ if self.return_intermediate:
+ intermediate.append(self.norm(output))
+
+ if self.norm is not None:
+ output = self.norm(output)
+ if self.return_intermediate:
+ intermediate.pop()
+ intermediate.append(output)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return output.unsqueeze(0)
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ ):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ q = k = self.with_pos_embed(src, pos)
+ src2 = self.self_attn(
+ q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+ )[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward_pre(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ src2 = self.norm1(src)
+ q = k = self.with_pos_embed(src2, pos)
+ src2 = self.self_attn(
+ q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+ )[0]
+ src = src + self.dropout1(src2)
+ src2 = self.norm2(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+ src = src + self.dropout2(src2)
+ return src
+
+ def forward(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
+ if self.normalize_before:
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+class TransformerDecoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ ):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.normalize_before = normalize_before
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(
+ q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+ )[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward_pre(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+ )[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
+
+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
+ if self.normalize_before:
+ return self.forward_pre(
+ tgt,
+ memory,
+ tgt_mask,
+ memory_mask,
+ tgt_key_padding_mask,
+ memory_key_padding_mask,
+ pos,
+ query_pos,
+ )
+ return self.forward_post(
+ tgt,
+ memory,
+ tgt_mask,
+ memory_mask,
+ tgt_key_padding_mask,
+ memory_key_padding_mask,
+ pos,
+ query_pos,
+ )
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/transformer_predictor.py b/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/transformer_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..45d04451444dc5532d60dd7a6b6072ba5b87357f
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/transformer_predictor.py
@@ -0,0 +1,171 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
+import fvcore.nn.weight_init as weight_init
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d
+
+from .position_encoding import PositionEmbeddingSine
+from .transformer import Transformer
+
+
+class TransformerPredictor(nn.Module):
+ @configurable
+ def __init__(
+ self,
+ in_channels,
+ mask_classification=True,
+ *,
+ num_classes: int,
+ hidden_dim: int,
+ num_queries: int,
+ nheads: int,
+ dropout: float,
+ dim_feedforward: int,
+ enc_layers: int,
+ dec_layers: int,
+ pre_norm: bool,
+ deep_supervision: bool,
+ mask_dim: int,
+ enforce_input_project: bool,
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ in_channels: channels of the input features
+ mask_classification: whether to add mask classifier or not
+ num_classes: number of classes
+ hidden_dim: Transformer feature dimension
+ num_queries: number of queries
+ nheads: number of heads
+ dropout: dropout in Transformer
+ dim_feedforward: feature dimension in feedforward network
+ enc_layers: number of Transformer encoder layers
+ dec_layers: number of Transformer decoder layers
+ pre_norm: whether to use pre-LayerNorm or not
+ deep_supervision: whether to add supervision to every decoder layers
+ mask_dim: mask feature dimension
+ enforce_input_project: add input project 1x1 conv even if input
+ channels and hidden dim is identical
+ """
+ super().__init__()
+
+ self.mask_classification = mask_classification
+
+ # positional encoding
+ N_steps = hidden_dim // 2
+ self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
+
+ transformer = Transformer(
+ d_model=hidden_dim,
+ dropout=dropout,
+ nhead=nheads,
+ dim_feedforward=dim_feedforward,
+ num_encoder_layers=enc_layers,
+ num_decoder_layers=dec_layers,
+ normalize_before=pre_norm,
+ return_intermediate_dec=deep_supervision,
+ )
+
+ self.num_queries = num_queries
+ self.transformer = transformer
+ hidden_dim = transformer.d_model
+
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
+
+ if in_channels != hidden_dim or enforce_input_project:
+ self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1)
+ weight_init.c2_xavier_fill(self.input_proj)
+ else:
+ self.input_proj = nn.Sequential()
+ self.aux_loss = deep_supervision
+
+ # output FFNs
+ if self.mask_classification:
+ self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
+
+ @classmethod
+ def from_config(cls, cfg, in_channels, mask_classification):
+ ret = {}
+ ret["in_channels"] = in_channels
+ ret["mask_classification"] = mask_classification
+
+ ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
+ ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
+ ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
+ # Transformer parameters:
+ ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
+ ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
+ ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
+ ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS
+ ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS
+ ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
+ ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
+ ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
+
+ ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
+
+ return ret
+
+ def forward(self, x, mask_features):
+ pos = self.pe_layer(x)
+
+ src = x
+ mask = None
+ hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)
+
+ if self.mask_classification:
+ outputs_class = self.class_embed(hs)
+ out = {"pred_logits": outputs_class[-1]}
+ else:
+ out = {}
+
+ if self.aux_loss:
+ # [l, bs, queries, embed]
+ mask_embed = self.mask_embed(hs)
+ outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
+ out["pred_masks"] = outputs_seg_masks[-1]
+ out["aux_outputs"] = self._set_aux_loss(
+ outputs_class if self.mask_classification else None, outputs_seg_masks
+ )
+ else:
+ # FIXME h_boxes takes the last one computed, keep this in mind
+ # [bs, queries, embed]
+ mask_embed = self.mask_embed(hs[-1])
+ outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
+ out["pred_masks"] = outputs_seg_masks
+ return out
+
+ @torch.jit.unused
+ def _set_aux_loss(self, outputs_class, outputs_seg_masks):
+ # this is a workaround to make torchscript happy, as torchscript
+ # doesn't support dictionary with non-homogeneous values, such
+ # as a dict having both a Tensor and a list.
+ if self.mask_classification:
+ return [
+ {"pred_logits": a, "pred_masks": b}
+ for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
+ ]
+ else:
+ return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
+
+
+class MLP(nn.Module):
+ """Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+ )
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
diff --git a/ACT_DP_multitask/detr/models/mask_former/test_time_augmentation.py b/ACT_DP_multitask/detr/models/mask_former/test_time_augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d250b6bb7792b54ddeaaab62cc6c170d74d3bb9
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/test_time_augmentation.py
@@ -0,0 +1,113 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+from itertools import count
+
+import numpy as np
+import torch
+from fvcore.transforms import HFlipTransform
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel
+
+from detectron2.data.detection_utils import read_image
+from detectron2.modeling import DatasetMapperTTA
+
+__all__ = [
+ "SemanticSegmentorWithTTA",
+]
+
+
+class SemanticSegmentorWithTTA(nn.Module):
+ """
+ A SemanticSegmentor with test-time augmentation enabled.
+ Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`.
+ """
+
+ def __init__(self, cfg, model, tta_mapper=None, batch_size=1):
+ """
+ Args:
+ cfg (CfgNode):
+ model (SemanticSegmentor): a SemanticSegmentor to apply TTA on.
+ tta_mapper (callable): takes a dataset dict and returns a list of
+ augmented versions of the dataset dict. Defaults to
+ `DatasetMapperTTA(cfg)`.
+ batch_size (int): batch the augmented images into this batch size for inference.
+ """
+ super().__init__()
+ if isinstance(model, DistributedDataParallel):
+ model = model.module
+ self.cfg = cfg.clone()
+
+ self.model = model
+
+ if tta_mapper is None:
+ tta_mapper = DatasetMapperTTA(cfg)
+ self.tta_mapper = tta_mapper
+ self.batch_size = batch_size
+
+ def _batch_inference(self, batched_inputs):
+ """
+ Execute inference on a list of inputs,
+ using batch size = self.batch_size, instead of the length of the list.
+ Inputs & outputs have the same format as :meth:`SemanticSegmentor.forward`
+ """
+ outputs = []
+ inputs = []
+ for idx, input in zip(count(), batched_inputs):
+ inputs.append(input)
+ if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1:
+ with torch.no_grad():
+ outputs.extend(self.model(inputs))
+ inputs = []
+ return outputs
+
+ def __call__(self, batched_inputs):
+ """
+ Same input/output format as :meth:`SemanticSegmentor.forward`
+ """
+
+ def _maybe_read_image(dataset_dict):
+ ret = copy.copy(dataset_dict)
+ if "image" not in ret:
+ image = read_image(ret.pop("file_name"), self.model.input_format)
+ image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))) # CHW
+ ret["image"] = image
+ if "height" not in ret and "width" not in ret:
+ ret["height"] = image.shape[1]
+ ret["width"] = image.shape[2]
+ return ret
+
+ return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs]
+
+ def _inference_one_image(self, input):
+ """
+ Args:
+ input (dict): one dataset dict with "image" field being a CHW tensor
+ Returns:
+ dict: one output dict
+ """
+ augmented_inputs, tfms = self._get_augmented_inputs(input)
+ # 1: forward with all augmented images
+ outputs = self._batch_inference(augmented_inputs)
+ # Delete now useless variables to avoid being out of memory
+ del augmented_inputs
+ # 2: merge the results
+ # handle flip specially
+ new_outputs = []
+ for output, tfm in zip(outputs, tfms):
+ if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
+ new_outputs.append(output.pop("sem_seg").flip(dims=[2]))
+ else:
+ new_outputs.append(output.pop("sem_seg"))
+ del outputs
+ # to avoid OOM with torch.stack
+ final_predictions = new_outputs[0]
+ for i in range(1, len(new_outputs)):
+ final_predictions += new_outputs[i]
+ final_predictions = final_predictions / len(new_outputs)
+ del new_outputs
+ return {"sem_seg": final_predictions}
+
+ def _get_augmented_inputs(self, input):
+ augmented_inputs = self.tta_mapper(input)
+ tfms = [x.pop("transforms") for x in augmented_inputs]
+ return augmented_inputs, tfms
diff --git a/ACT_DP_multitask/detr/models/mask_former/utils/__init__.py b/ACT_DP_multitask/detr/models/mask_former/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/utils/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
diff --git a/ACT_DP_multitask/detr/models/mask_former/utils/misc.py b/ACT_DP_multitask/detr/models/mask_former/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..874d9805b482f52bbffc1be620e36e0cffc07c46
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/utils/misc.py
@@ -0,0 +1,111 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+from typing import List, Optional
+
+import torch
+import torch.distributed as dist
+import torchvision
+from torch import Tensor
+
+
+def _max_by_axis(the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+
+class NestedTensor(object):
+ def __init__(self, tensors, mask: Optional[Tensor]):
+ self.tensors = tensors
+ self.mask = mask
+
+ def to(self, device):
+ # type: (Device) -> NestedTensor # noqa
+ cast_tensor = self.tensors.to(device)
+ mask = self.mask
+ if mask is not None:
+ assert mask is not None
+ cast_mask = mask.to(device)
+ else:
+ cast_mask = None
+ return NestedTensor(cast_tensor, cast_mask)
+
+ def decompose(self):
+ return self.tensors, self.mask
+
+ def __repr__(self):
+ return str(self.tensors)
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+ # TODO make this more general
+ if tensor_list[0].ndim == 3:
+ if torchvision._is_tracing():
+ # nested_tensor_from_tensor_list() does not export well to ONNX
+ # call _onnx_nested_tensor_from_tensor_list() instead
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+ # TODO make it support different-sized images
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+ batch_shape = [len(tensor_list)] + max_size
+ b, c, h, w = batch_shape
+ dtype = tensor_list[0].dtype
+ device = tensor_list[0].device
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ m[: img.shape[1], : img.shape[2]] = False
+ else:
+ raise ValueError("not supported")
+ return NestedTensor(tensor, mask)
+
+
+# _onnx_nested_tensor_from_tensor_list() is an implementation of
+# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+@torch.jit.unused
+def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+ max_size = []
+ for i in range(tensor_list[0].dim()):
+ max_size_i = torch.max(
+ torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
+ ).to(torch.int64)
+ max_size.append(max_size_i)
+ max_size = tuple(max_size)
+
+ # work around for
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ # m[: img.shape[1], :img.shape[2]] = False
+ # which is not yet supported in onnx
+ padded_imgs = []
+ padded_masks = []
+ for img in tensor_list:
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+ padded_imgs.append(padded_img)
+
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
+ padded_masks.append(padded_mask.to(torch.bool))
+
+ tensor = torch.stack(padded_imgs)
+ mask = torch.stack(padded_masks)
+
+ return NestedTensor(tensor, mask=mask)
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
diff --git a/ACT_DP_multitask/detr/models/mr_mg/LICENSE b/ACT_DP_multitask/detr/models/mr_mg/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..8c4d49b8628acece7ddcf25a6cc8ea24745a827c
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/LICENSE
@@ -0,0 +1,226 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright {yyyy} {name of copyright owner}
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+=======================================================================
+Apache SkyWalking Subcomponents:
+
+The Apache SkyWalking project contains subcomponents with separate copyright
+notices and license terms. Your use of the source code for the these
+subcomponents is subject to the terms and conditions of the following
+licenses.
+
+========================================================================
+Apache 2.0 licenses
+========================================================================
+
+The following components are provided under the Apache License. See project link for details.
+The text of each license is the standard Apache 2.0 license.
+
+ proto files from cncf/udpa: https://github.com/cncf/udpa Apache 2.0
+ proto files from envoyproxy/data-plane-api: https://github.com/envoyproxy/data-plane-api Apache 2.0
+ proto files from prometheus/client_model: https://github.com/prometheus/client_model Apache 2.0
+ proto files from opentelemetry: https://github.com/open-telemetry/opentelemetry-proto/tree/main/opentelemetry/proto Apache 2.0
+ proto files from opentelemetry: https://github.com/open-telemetry/opentelemetry-proto/tree/v0.7.0 Apache 2.0
+ flatbuffers files from istio/proxy: https://github.com/istio/proxy Apache 2.0
+ mvnw files from https://github.com/apache/maven-wrapper Apache 2.0
+ svg files from skywalking-ui/src/assets/icons: https://github.com/google/material-design-icons Apache 2.0
+ ZipkinQueryHandler.java reference from zipkin2.server.internal.ZipkinQueryApiV2: https://github.com/openzipkin/zipkin Apache 2.0
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/mr_mg/README.md b/ACT_DP_multitask/detr/models/mr_mg/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f5109c25686204f03e545d0655aeecee62f228c1
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/README.md
@@ -0,0 +1,98 @@
+
GR-MG
+
+This repo contains code for the paper:
+### Leveraging Partially Annotated Data via Multi-Modal Goal Conditioned Policy
+
+[Peiyan Li](https://github.com/LPY1219), [Hongtao Wu\*‡](https://scholar.google.com/citations?hl=zh-CN&user=7u0TYgIAAAAJ&view_op=list_works&sortby=pubdate), [Yan Huang\*](https://yanrockhuang.github.io/), [Chilam Cheang](https://github.com/bytedance/GR-MG/tree/main), [Liang Wang](https://scholar.google.com/citations?hl=zh-CN&user=8kzzUboAAAAJ&view_op=list_works&sortby=pubdate), [Tao Kong](https://www.taokong.org/)
+
+*Corresponding author ‡ Project lead
+
+### [🌠Project Website](https://gr-mg.github.io/) | [📄 Paper](https://arxiv.org/abs/2408.14368)
+
+
+
+
+
+
+
+## News
+- (🔥 New) **(2024.08.27)** We have released the code and checkpoints of GR-MG !
+## Preparation
+**Note:** We only test GR-MG with CUDA 12.1 and python 3.9
+
+```bash
+# clone this repository
+git clone https://github.com/bytedance/GR-MG.git
+cd GR-MG
+# install dependencies for goal image generation model
+bash ./goal_gen/install.sh
+# install dependencies for multi-modal goal conditioned policy
+bash ./policy/install.sh
+```
+Download the pretrained [InstructPix2Pix](https://huggingface.co/timbrooks/instruct-pix2pix) weights from Huggingface and save them in `resources/IP2P/`.
+Download the pretrained MAE encoder [mae_pretrain_vit_base.pth ](https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth) and save it in `resources/MAE/`.
+Download and unzip the [CALVIN](https://github.com/mees/calvin) dataset.
+
+
+## Checkpoints
+- [Multi-modal Goal Conditioned Policy](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/gr_mg_release/epoch=47-step=83712.ckpt)
+- [Goal Image Generation Model](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/gr_mg_release/goal_gen.ckpt)
+
+
+## Training
+
+### 1. Train Goal Image Generation Model
+```bash
+# modify the variables in the script before you execute the following instruction
+bash ./goal_gen/train_ip2p.sh ./goal_gen/config/train.json
+```
+### 2. Pretrain Multi-modal Goal Conditioned Policy
+We use the method described in [GR-1](https://arxiv.org/abs/2312.13139) and pretrain our policy with Ego4D videos. You can download the pretrained model checkpoint [here](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/gr_mg_release/pretrained.pt). You can also pretrain the policy yourself using the scripts we provide. Before doing this, you'll need to download the [Ego4D](https://ego4d-data.org/) dataset.
+
+```bash
+# pretrain multi-modal goal conditioned policy
+bash ./policy/main.sh ./policy/config/pretrain.json
+```
+### 3. Train Multi-modal Goal Conditioned Policy
+After pretraining, modify the pretrained_model_path in `/policy/config/train.json` and execute the following instruction to train the policy.
+```bash
+# train multi-modal goal conditioned policy
+bash ./policy/main.sh ./policy/config/train.json
+```
+
+
+## Evaluation
+To evaluate our model on CALVIN, you can execute the following instruction:
+```bash
+# Evaluate GR-MG on CALVIN
+bash ./evaluate/eval.sh ./policy/config/train.json
+```
+In the `eval.sh` script, you can specify which goal image generation model and policy to use. Additionally, we provide multi-GPU evaluation code, allowing you to evaluate different training epochs of the policy simultaneously.
+
+
+## Contact
+If you have any questions about the project, please contact peiyan.li@cripac.ia.ac.cn.
+
+
+## Acknowledgements
+
+We thank the authors of the following projects for making their code and dataset open source:
+
+- [CALVIN](https://github.com/mees/calvin)
+- [InstructPix2Pix](https://github.com/timothybrooks/instruct-pix2pix)
+- [T5](https://github.com/google-research/text-to-text-transfer-transformer)
+- [GR-1](https://github.com/bytedance/GR-1)
+- [CLIP](https://github.com/openai/CLIP)
+- [MAE](https://github.com/facebookresearch/mae)
+
+## Citation
+
+If you find this project useful, please star the repository and cite our paper:
+```
+@article{li2024gr,
+ title={GR-MG: Leveraging Partially Annotated Data via Multi-Modal Goal Conditioned Policy},
+ author={Li, Peiyan and Wu, Hongtao and Huang, Yan and Cheang, Chilam and Wang, Liang and Kong, Tao},
+ journal={arXiv preprint arXiv:2408.14368},
+ year={2024}
+}
+```
diff --git a/ACT_DP_multitask/detr/models/mr_mg/evaluate/eval.py b/ACT_DP_multitask/detr/models/mr_mg/evaluate/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f2c458dfd3f817f03b60e4e5ce8ce8391fc3570
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/evaluate/eval.py
@@ -0,0 +1,220 @@
+# MIT License
+
+# Copyright (c) 2021 Oier Mees
+# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import argparse
+import json
+import logging
+import os
+from pathlib import Path
+import sys
+import time
+import re
+import copy
+from copy import deepcopy
+import os
+# This is for using the locally installed repo clone when using slurm
+import matplotlib.pyplot as plt
+sys.path.insert(0, Path(__file__).absolute().parents[2].as_posix())
+from calvin_agent.evaluation.multistep_sequences import get_sequences
+from calvin_agent.evaluation.utils import (
+ count_success,
+ get_env_state_for_initial_condition
+)
+import hydra
+import numpy as np
+from omegaconf import OmegaConf
+from pytorch_lightning import seed_everything
+from termcolor import colored
+import torch
+from tqdm.auto import tqdm
+from utils.utils import print_and_save
+from wrapper.model_wrapper import CustomModel
+from goal_gen.evaluate import IP2PEvaluation
+
+logger = logging.getLogger(__name__)
+EP_LEN = 360
+NUM_SEQUENCES = 1000
+SAVE_DIR = None
+FAIL_COUNTER=0
+
def make_env(dataset_path, observation_space, device_id):
    """Build a CALVIN validation environment wrapped to return raw observations.

    Args:
        dataset_path: CALVIN dataset root; its ``validation`` split is used.
        observation_space: dict naming the rgb/depth/state/action keys to expose.
        device_id: CUDA device index the wrapper places tensors on.

    Returns:
        A CalvinEnvWrapperRaw instance over the validation split.
    """
    # Imported lazily (insert your own env wrapper here if needed).
    from wrapper.calvin_env_wrapper_raw import CalvinEnvWrapperRaw

    validation_dir = Path(dataset_path) / "validation"
    target_device = torch.device('cuda', device_id)
    return CalvinEnvWrapperRaw(validation_dir, observation_space, target_device)
+
+
def evaluate_policy(model, env, eval_sr_path, eval_result_path, ip2p_model):
    """Run this function to evaluate a model on the CALVIN challenge.

    Args:
        model: policy wrapper exposing reset()/step() (see CustomModel).
        env: CALVIN environment wrapper.
        eval_sr_path: text file that gets one running success-rate line appended
            per evaluated sequence.
        eval_result_path: JSON path for the aggregated results.
        ip2p_model: goal-image generation model forwarded to each rollout.

    Returns:
        List with one int per sequence: the number of consecutively
        completed subtasks (0-5).
    """
    # Task oracle and language annotations come from a local CALVIN checkout;
    # the relative path assumes the script runs from the repo root.
    conf_dir = Path("./calvin/calvin_models/conf")
    task_cfg = OmegaConf.load(conf_dir / "callbacks/rollout/tasks/new_playtable_tasks.yaml")
    task_oracle = hydra.utils.instantiate(task_cfg)
    val_annotations = OmegaConf.load(conf_dir / "annotations/new_playtable_validation.yaml")
    eval_sequences = get_sequences(NUM_SEQUENCES)
    results = []
    sequence_i = 0
    for index,(initial_state, eval_sequence) in enumerate(eval_sequences):
        result= evaluate_sequence(env, model, task_oracle, initial_state, eval_sequence, val_annotations, sequence_i,ip2p_model)
        results.append(result)
        # Recompute chained success rates over everything evaluated so far and
        # append a progress line, so partial results survive interruption.
        success_list = count_success(results)
        with open(eval_sr_path, 'a') as f:
            line =f"{sequence_i}/{NUM_SEQUENCES}: "
            for sr in success_list:
                line += f"{sr:.3f} | "
            sequence_i += 1
            line += "\n"
            f.write(line)

        if index%100==0 and index!=0: #save every 100 sequences
            # Periodic checkpoint of the JSON results under a suffixed name
            # (strips the ".json" extension before appending the index).
            print_and_save(results, eval_sequences[:index+1], eval_result_path[:-5]+f"_{index+1}"+".json", None)
    # Final, complete dump.
    print_and_save(results, eval_sequences, eval_result_path, None)
    return results
+
+
def evaluate_sequence(env, model, task_checker, initial_state, eval_sequence, val_annotations, sequence_i,ip2p_model):
    """Roll out one chain of language subtasks; return how many succeeded in a row."""
    robot_obs, scene_obs = get_env_state_for_initial_condition(initial_state)
    env.reset(robot_obs=robot_obs, scene_obs=scene_obs)
    completed = 0
    for subtask_i, subtask in enumerate(eval_sequence):
        # CALVIN scores consecutive successes only, so stop at the first failure.
        if not rollout(env, model, task_checker, subtask, val_annotations, subtask_i, sequence_i, ip2p_model):
            break
        completed += 1
    return completed
+
def rollout(env, model, task_oracle, subtask, val_annotations, subtask_i, sequence_i,ip2p_model):
    """Run the actual rollout on one subtask.

    Steps the policy for up to EP_LEN environment steps, regenerating the
    goal image every 20 steps from the current static camera view plus a
    progress-annotated instruction. Returns True as soon as the task oracle
    reports the subtask solved, False if the episode times out. On every
    30th failure (module-global FAIL_COUNTER) a debug contact sheet of
    (static view, generated goal) pairs is saved to SAVE_DIR.
    """
    obs = env.get_obs()
    # get lang annotation for subtask
    lang_annotation = val_annotations[subtask][0]
    model.reset()
    start_info = env.get_info()
    debug_image=[]
    progress=0
    for i in range(EP_LEN):
        if i % 20 == 0: # hardcode: refresh the goal image every 20 steps
            static_rgb = obs['rgb_obs']['rgb_static'] # (200, 200, 3)
            hand_rgb = obs['rgb_obs']['rgb_gripper']
            image_patch=[static_rgb]
            # Progress comes from the policy's previous step (0 on the first pass).
            text_patch=[lang_annotation + f".And {progress}% of the instruction has been finished."]
            print(text_patch)
            goal_image=ip2p_model.inference(image_patch,text_patch)
            temp_image=[static_rgb,goal_image[0],hand_rgb]
            debug_image.append(temp_image)

        # goal_image is always bound here: the i == 0 iteration takes the branch above.
        action,progress = model.step(obs,deepcopy(goal_image),[lang_annotation])
        obs, _, _, current_info = env.step(action)

        # check if current step solves a task
        current_task_info = task_oracle.get_task_info_for_set(start_info, current_info, {subtask})
        if len(current_task_info) > 0:
            print("success!")
            return True
    print("fail!")

    global FAIL_COUNTER
    FAIL_COUNTER+=1
    if FAIL_COUNTER % 30 ==0: # save every 30 failure cases
        length=len(debug_image)
        # figsize is tuned for EP_LEN/20 = 18 rows of (static, goal) pairs.
        fig, ax = plt.subplots(length, 2,figsize=(5.5, 46.58))
        for ax_ in ax.flat:
            # ax_.plot([1, 2, 3], [4, 5, 6])
            ax_.axis('off') # hide every subplot's ticks and frame
        for i in range(length):
            ax[i][0].imshow(debug_image[i][0])
            ax[i][1].imshow(debug_image[i][1])
            # ax[i][2].imshow(debug_image[i][2])
        plt.tight_layout()
        plt.axis('off')
        # SAVE_DIR is a module-level global set up by main().
        plt.savefig(os.path.join(SAVE_DIR, f"{sequence_i}-{subtask_i}-{subtask}.png"),dpi=100)
        plt.close()
    return False
+
+
def main():
    """Evaluate a trained GR-MG policy on CALVIN multistep sequences.

    Loads the policy config, resolves the checkpoint matching --epoch inside
    --ckpt_dir, builds the CALVIN environment and the IP2P goal-image model,
    then runs the evaluation. Success rates and per-task results are written
    under <ckpt_dir>/<flag>_<epoch>_epoch/.

    Raises:
        ValueError: if --config_path is not given.
        FileNotFoundError: if no checkpoint for --epoch exists in --ckpt_dir.
    """
    seed_everything(0, workers=True)  # type:ignore
    parser = argparse.ArgumentParser(description="Evaluate a trained model on multistep sequences with language goals.")
    parser.add_argument("--dataset_path", default='/mnt/bn/robotics/manipulation_data/calvin_data/task_ABC_D',
                        type=str, help="Path to the dataset root directory.")  # modify it before opensource
    # evaluation
    parser.add_argument('--config_path', type=str, default="", help='path to the policy config file')
    parser.add_argument('--ckpt_dir', type=str, default="", help="path to the policy ckpt file")
    parser.add_argument('--epoch', type=int, default=41, help="epoch index for evaluating")
    parser.add_argument('--device_id', default=0, type=int, help="CUDA device")
    parser.add_argument('--ip2p_ckpt_path', default="", type=str, help="ip2p ckpt path")
    args = parser.parse_args()
    config_path = args.config_path
    ckpt_dir = args.ckpt_dir
    epoch = args.epoch
    device_id = args.device_id
    ip2p_ckpt_path = args.ip2p_ckpt_path

    # The previous `assert config_path != None` could never fail (the argparse
    # default is ""), and asserts are stripped under `python -O`; validate
    # explicitly and raise instead.
    if not config_path:
        raise ValueError("--config_path is required")

    # Load config file
    with open(config_path, 'r') as f:
        configs = json.load(f)

    # Resolve the checkpoint whose filename carries the requested epoch index
    # (checkpoints are named like "epoch=47-step=83712.ckpt").
    ckpt_path = None
    for ckpt_file in os.listdir(ckpt_dir):
        match = re.search(r'epoch=(\d+)', ckpt_file)
        if match and int(match.group(1)) == epoch:
            ckpt_path = os.path.join(ckpt_dir, ckpt_file)
            break
    # Fail fast with a clear message instead of passing ckpt_path=None on
    # to the model constructor.
    if ckpt_path is None:
        raise FileNotFoundError(f"no checkpoint for epoch {epoch} found in {ckpt_dir}")

    device = torch.device('cuda', device_id)
    model = CustomModel(
        ckpt_path=ckpt_path,
        configs=configs,
        device=device)
    observation_space = {
        'rgb_obs': ['rgb_static', 'rgb_gripper'],
        'depth_obs': [],
        'state_obs': ['robot_obs'],
        'actions': ['rel_actions'],
        'language': ['language']}
    env = make_env(args.dataset_path, observation_space, device_id)

    # Success rate and result files go into a per-epoch subdirectory.
    flag = "opensourcesd"
    sub_dir = f"{flag}_{epoch}_epoch"
    # SAVE_DIR is a module-level global also read by rollout() for debug images.
    global SAVE_DIR
    SAVE_DIR = os.path.join(ckpt_dir, sub_dir)
    # exist_ok avoids a crash when the directory already exists (e.g. reruns,
    # or a race between parallel evaluation processes).
    os.makedirs(SAVE_DIR, exist_ok=True)
    sr_path = os.path.join(SAVE_DIR, "success_rate.txt")
    result_path = os.path.join(SAVE_DIR, "results.json")
    ip2p_model = IP2PEvaluation(ip2p_ckpt_path)
    evaluate_policy(
        model,
        env,
        eval_sr_path=sr_path,
        eval_result_path=result_path,
        ip2p_model=ip2p_model)


if __name__ == "__main__":
    main()
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/mr_mg/evaluate/eval.sh b/ACT_DP_multitask/detr/models/mr_mg/evaluate/eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2249b0810d6b69defbe96199be44e15f9892ccd4
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/evaluate/eval.sh
@@ -0,0 +1,21 @@
# Evaluate GR-MG policy checkpoints on CALVIN, one GPU per epoch.
cd /opt/tiger/GR_MG
# Epoch indices to evaluate; each entry is assigned its own CUDA device below.
export EPOCHS=(47)
export CKPT_DIR="PATH_TO_POLICY_DIR"
export SD_CKPT="/PATH_TO_GOAL_GEN_MODEL_CKPT/epoch=49-step=51450.ckpt"
# NOTE(review): presumably required by the CALVIN software renderer — confirm.
export MESA_GL_VERSION_OVERRIDE=3.3
echo $EPOCHS
echo $CKPT_DIR
# Result directories are created inside CKPT_DIR, so it must be writable.
sudo chmod 777 -R ${CKPT_DIR}

export COUNTER=-1
# Use a for loop to iterate through a list
for epoch in "${EPOCHS[@]}"; do
    export COUNTER=$((${COUNTER} + 1))
    # Pin each run to its own GPU so all epochs evaluate in parallel.
    export CUDA_VISIBLE_DEVICES=${COUNTER}
    # Remaining script arguments are forwarded as the policy config path.
    python3 evaluate/eval.py \
        --ckpt_dir ${CKPT_DIR} \
        --epoch ${epoch} \
        --ip2p_ckpt_path ${SD_CKPT} \
        --config_path ${@:1} &
done
# Block until every background evaluation has finished.
wait
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/mr_mg/evaluate/install.sh b/ACT_DP_multitask/detr/models/mr_mg/evaluate/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1d61fe29916418dda3fb66332bd229a278dee79d
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/evaluate/install.sh
@@ -0,0 +1,22 @@
#!/bin/bash
# You should first install the packages needed by policy and goal image generation model
# install calvin
cd /opt/tiger/GR_MG
git clone --recurse-submodules https://github.com/mees/calvin.git
export CALVIN_ROOT=$(pwd)/calvin
cd calvin
# Make sure the calvin_env submodule is on its main branch.
cd calvin_env; git checkout main
cd ..

cd $CALVIN_ROOT
# CALVIN's own install script needs this older setuptools to build.
pip3 install setuptools==57.5.0
sh install.sh
cd /opt/tiger/GR_MG
# NOTE(review): "EVALUTION_ROOT" is misspelled ("EVALUATION"); kept as-is in
# case other scripts reference the variable under this exact name.
export EVALUTION_ROOT=$(pwd)

# Install dependency for calvin
# System GL/EGL/OSMesa libraries required for headless rendering.
sudo apt-get -y install libegl1-mesa libegl1
sudo apt-get -y install libgl1
sudo apt-get -y install libosmesa6-dev
sudo apt install ffmpeg
sudo apt-get -y install patchelf
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/mr_mg/evaluate/utils/utils.py b/ACT_DP_multitask/detr/models/mr_mg/evaluate/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..36d6e3e5518eadddb39d96ddf611091446521028
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/evaluate/utils/utils.py
@@ -0,0 +1,158 @@
+# MIT License
+
+# Copyright (c) 2021 Oier Mees
+# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from collections import Counter
+import json
+import math
+
+import numpy as np
+
def count_success(results):
    """Convert per-sequence completion counts into chained success rates.

    Args:
        results: iterable of ints in [0, 5], each the number of consecutive
            subtasks a sequence completed.

    Returns:
        List of 5 floats; entry i-1 is the fraction of sequences that
        completed at least i subtasks in a row.
    """
    tally = Counter(results)
    total = len(results)
    # A sequence counts toward level i when it completed i or more subtasks.
    return [sum(tally[level] for level in range(i, 6)) / total for i in range(1, 6)]
+
def print_and_save(results, sequences, eval_result_path, epoch=None):
    """Print CALVIN evaluation metrics and dump them to a JSON file.

    Args:
        results: per-sequence counts of consecutively completed subtasks (0-5).
        sequences: (initial_state, subtask_chain) pairs aligned with `results`.
        eval_result_path: output JSON path (overwritten, not appended).
        epoch: optional label used as the top-level JSON key (None -> "null").
    """
    current_data = {}
    print(f"Results for Epoch {epoch}:")
    avg_seq_len = np.mean(results)
    # chain_sr[i] = fraction of sequences that completed at least i subtasks.
    chain_sr = {i + 1: sr for i, sr in enumerate(count_success(results))}
    print(f"Average successful sequence length: {avg_seq_len}")
    print("Success rates for i instructions in a row:")
    for i, sr in chain_sr.items():
        print(f"{i}: {sr * 100:.1f}%")

    cnt_success = Counter()
    cnt_fail = Counter()

    # Per-task breakdown: the first `result` subtasks of each chain succeeded;
    # the one right after (if any) is the task that failed.
    for result, (_, sequence) in zip(results, sequences):
        for successful_tasks in sequence[:result]:
            cnt_success[successful_tasks] += 1
        if result < len(sequence):
            failed_task = sequence[result]
            cnt_fail[failed_task] += 1

    # Counter + Counter merges counts, giving attempts per task.
    total = cnt_success + cnt_fail
    task_info = {}
    for task in total:
        task_info[task] = {"success": cnt_success[task], "total": total[task]}
        print(f"{task}: {cnt_success[task]} / {total[task]} | SR: {cnt_success[task] / total[task] * 100:.1f}%")

    data = {"number of seq": len(results),"avg_seq_len": avg_seq_len, "chain_sr": chain_sr, "task_info": task_info}

    current_data[epoch] = data

    print()
    # previous_data is always empty here; the merge is vestigial (kept in case
    # older results are ever loaded and merged before saving).
    previous_data = {}
    json_data = {**previous_data, **current_data}
    with open(eval_result_path, "w") as file:
        json.dump(json_data, file)
    print(
        f"Best model: epoch {max(json_data, key=lambda x: json_data[x]['avg_seq_len'])} "
        f"with average sequences length of {max(map(lambda x: x['avg_seq_len'], json_data.values()))}"
    )
+
def alpha2rotm(a):
    """Rotation matrix for a rotation of `a` radians about the x-axis."""
    cos_a, sin_a = np.cos(a), np.sin(a)
    return np.array([
        [1.0, 0.0, 0.0],
        [0.0, cos_a, -sin_a],
        [0.0, sin_a, cos_a],
    ])
+
def beta2rotm(b):
    """Rotation matrix for a rotation of `b` radians about the y-axis."""
    cos_b, sin_b = np.cos(b), np.sin(b)
    return np.array([
        [cos_b, 0.0, sin_b],
        [0.0, 1.0, 0.0],
        [-sin_b, 0.0, cos_b],
    ])
+
def gamma2rotm(c):
    """Rotation matrix for a rotation of `c` radians about the z-axis."""
    cos_c, sin_c = np.cos(c), np.sin(c)
    return np.array([
        [cos_c, -sin_c, 0.0],
        [sin_c, cos_c, 0.0],
        [0.0, 0.0, 1.0],
    ])
+
def euler2rotm(euler_angles):
    """Convert (alpha, beta, gamma) Euler angles to a rotation matrix.

    Composes R = Rz(gamma) @ Ry(beta) @ Rx(alpha), i.e. the ZYX convention
    that rotm2euler inverts.
    """
    alpha, beta, gamma = euler_angles[0], euler_angles[1], euler_angles[2]
    return gamma2rotm(gamma) @ beta2rotm(beta) @ alpha2rotm(alpha)
+
def isRotm(R):
    """Check whether R is orthonormal, i.e. R^T @ R is (numerically) identity.

    Forked from Andy Zeng. NOTE: only orthogonality is tested, so improper
    rotations (reflections, det = -1) also pass.
    """
    residual = np.identity(3, dtype=R.dtype) - np.dot(np.transpose(R), R)
    return np.linalg.norm(residual) < 1e-6
+
def rotm2euler(R):
    """Recover Euler angles (x, y, z) from a rotation matrix R = Rz * Ry * Rx.

    Forked from: https://learnopencv.com/rotation-matrix-to-euler-angles/
    Each returned angle is normalized into the half-open interval (-pi, pi].
    """
    assert(isRotm(R))

    def _wrap(angle):
        # Shift into (-pi, pi].
        while angle > np.pi:
            angle -= (2 * np.pi)
        while angle <= -np.pi:
            angle += (2 * np.pi)
        return angle

    sy = math.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0])
    if sy >= 1e-6:
        # Regular (non-degenerate) case.
        x = math.atan2(R[2, 1], R[2, 2])
        y = math.atan2(-R[2, 0], sy)
        z = math.atan2(R[1, 0], R[0, 0])
    else:
        # Gimbal lock: z is unrecoverable, conventionally set to 0.
        x = math.atan2(-R[1, 2], R[1, 1])
        y = math.atan2(-R[2, 0], sy)
        z = 0
    return np.array([_wrap(x), _wrap(y), _wrap(z)])
diff --git a/ACT_DP_multitask/detr/models/mr_mg/evaluate/wrapper/calvin_env_wrapper_raw.py b/ACT_DP_multitask/detr/models/mr_mg/evaluate/wrapper/calvin_env_wrapper_raw.py
new file mode 100644
index 0000000000000000000000000000000000000000..31d7093f874d18a8da774aed9033688fd90e1d3e
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/evaluate/wrapper/calvin_env_wrapper_raw.py
@@ -0,0 +1,117 @@
+# MIT License
+
+# Copyright (c) 2021 Oier Mees
+# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import logging
+import os
+from typing import Any, Dict, Tuple, Union
+from calvin_agent.datasets.utils.episode_utils import process_depth, process_rgb, process_state
+import gym
+import numpy as np
+import torch
+from calvin_env.envs.play_table_env import get_env
+from calvin_env.utils.utils import EglDeviceNotFoundError, get_egl_device_id
+logger = logging.getLogger(__name__)
class CalvinEnvWrapperRaw(gym.Wrapper):
    def __init__(self, abs_datasets_dir, observation_space, device, show_gui=False, **kwargs):
        """Environment wrapper which returns raw (untransformed) observations.

        Args:
            abs_datasets_dir: absolute dataset directory
            observation_space: e.g. {'rgb_obs': ['rgb_static', 'rgb_gripper'], 'depth_obs': [],
                'state_obs': ['robot_obs'], 'actions': ['rel_actions'], 'language': ['language']}
            device: torch device (used for EGL device selection / logging)
            show_gui: whether to open the PyBullet GUI
        """
        # self.set_egl_device(device)
        env = get_env(
            abs_datasets_dir, show_gui=show_gui, obs_space=observation_space, **kwargs
        )
        super().__init__(env)
        self.observation_space_keys = observation_space
        self.device = device
        # Actions are relative iff the dataset was recorded with "rel_actions".
        self.relative_actions = "rel_actions" in self.observation_space_keys["actions"]
        logger.info(f"Initialized PlayTableEnv for device {self.device}")

    @staticmethod
    def set_egl_device(device):
        """Pin EGL rendering to the EGL device that matches the given CUDA device."""
        if "EGL_VISIBLE_DEVICES" in os.environ:
            logger.warning("Environment variable EGL_VISIBLE_DEVICES is already set. Is this intended?")
        cuda_id = device.index if device.type == "cuda" else 0
        try:
            egl_id = get_egl_device_id(cuda_id)
        except EglDeviceNotFoundError:
            logger.warning(
                "Couldn't find correct EGL device. Setting EGL_VISIBLE_DEVICE=0. "
                "When using DDP with many GPUs this can lead to OOM errors. "
                "Did you install PyBullet correctly? Please refer to calvin env README"
            )
            egl_id = 0
        os.environ["EGL_VISIBLE_DEVICES"] = str(egl_id)
        logger.info(f"EGL_DEVICE_ID {egl_id} <==> CUDA_DEVICE_ID {cuda_id}")

    def step(
        self, action_tensor: torch.Tensor
    ) -> Tuple[Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]], int, bool, Dict]:
        """Run one env step with the given action tensor; returns the raw observation."""
        if self.relative_actions:
            action = action_tensor.squeeze().cpu().detach().numpy()
            assert len(action) == 7
        else:
            # Absolute actions are split into (xyz, rotation, gripper) chunks.
            if action_tensor.shape[-1] == 7:
                slice_ids = [3, 6]  # xyz (3) + euler angles (3) + gripper (1)
            elif action_tensor.shape[-1] == 8:
                slice_ids = [3, 7]  # xyz (3) + quaternion (4) + gripper (1)
            else:
                # BUGFIX: the message previously said "length 8 ... or 9", but the
                # accepted lengths (checked above) are 7 and 8.
                logger.error("actions are required to have length 7 (for euler angles) or 8 (for quaternions)")
                raise NotImplementedError
            action = np.split(action_tensor.squeeze().cpu().detach().numpy(), slice_ids)
            # Binarize the gripper command (open/close).
            # NOTE(review): the relative-action branch does not binarize the
            # gripper — confirm this asymmetry is intended.
            action[-1] = 1 if action[-1] > 0 else -1
        o, r, d, i = self.env.step(action)

        # obs = self.transform_observation(o)
        obs = o  # use raw observation
        return obs, r, d, i

    def reset(
        self,
        reset_info: Dict[str, Any] = None,
        batch_idx: int = 0,
        seq_idx: int = 0,
        scene_obs: Any = None,
        robot_obs: Any = None,
    ) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Reset the env, optionally to the recorded state of an episode step."""
        if reset_info is not None:
            # Reset to the exact robot/scene state of a recorded sequence step.
            obs = self.env.reset(
                robot_obs=reset_info["robot_obs"][batch_idx, seq_idx],
                scene_obs=reset_info["scene_obs"][batch_idx, seq_idx],
            )
        elif scene_obs is not None or robot_obs is not None:
            obs = self.env.reset(scene_obs=scene_obs, robot_obs=robot_obs)
        else:
            obs = self.env.reset()

        # return self.transform_observation(obs)
        return obs  # use raw observation

    def get_info(self):
        """Proxy to the wrapped env's info dict."""
        return self.env.get_info()

    def get_obs(self):
        """Return the current raw observation from the wrapped env."""
        obs = self.env.get_obs()
        # return self.transform_observation(obs)
        return obs  # use raw observation
diff --git a/ACT_DP_multitask/detr/models/mr_mg/evaluate/wrapper/model_wrapper.py b/ACT_DP_multitask/detr/models/mr_mg/evaluate/wrapper/model_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b90fac7ed464ebdc7787444fb6ed3a15eea450c
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/evaluate/wrapper/model_wrapper.py
@@ -0,0 +1,265 @@
+# Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from omegaconf import OmegaConf
+import numpy as np
+import torch
+import torchvision.transforms as T
+from PIL import Image
+import torch.nn.functional as F
+import policy.model.vision_transformer as vits
+from utils.utils import euler2rotm, rotm2euler
+from copy import deepcopy
+from calvin_agent.models.calvin_base_model import CalvinBaseModel
+from policy.model.model import GR_MG
+import time
+import clip
# Gripper state encodings used by this module.
GRIPPER_OPEN = 1
GRIPPER_CLOSE = 0
class CustomModel(CalvinBaseModel):
    """CALVIN rollout wrapper around a GR_MG policy.

    Maintains a sliding window (length ``seq_len``) of static-RGB, gripper-RGB
    and proprioceptive states, and converts the policy's local relative action
    back into the CALVIN relative-action format.
    """

    def __init__(self,
                 ckpt_path,
                 configs,
                 device):
        self.device = device
        # model config
        self.configs = configs
        self.seq_len = configs["policy"]['seq_len']
        self.act_len = configs["policy"]["act_len"]
        self.device = device  # NOTE(review): duplicate assignment (already set above)
        input_size = (224, 224)
        # Normalization constants (the 0.485/0.229 values are the standard
        # ImageNet statistics, despite the `clip_` naming).
        clip_mean = (0.485, 0.456, 0.406)
        clip_std = (0.229, 0.224, 0.225)
        self.preprocess = T.Compose([
            T.Resize(input_size, interpolation=Image.BICUBIC),
            T.Normalize(clip_mean, clip_std)])
        # MAE visual backbone (ViT-B, patch 16, no classification head).
        model_mae = vits.__dict__['vit_base'](patch_size=16, num_classes=0)
        model_mae.to(self.device)
        # Select which prediction heads the policy was trained with.
        training_target = []
        if configs["trainer"]['act_pred']:
            training_target.append('act_pred')
        if configs["trainer"]['fwd_pred']:
            training_target.append('fwd_pred')
        if configs["trainer"]['fwd_pred_hand']:
            training_target.append('fwd_pred_hand')
        if configs["trainer"]["progress_pred"]:
            training_target.append('progress_pred')
        print(f"training target: {training_target}")

        # modify before release
        # Frozen CLIP model for language conditioning.
        clip_name = "ViT-B/32"
        clip_model, clip_preprocess = clip.load(clip_name)
        # freeze clip
        for _, param in clip_model.named_parameters():
            param.requires_grad = False
        policy_config = self.configs['policy']
        input_config = self.configs["input"]
        trainer_config = self.configs["trainer"]
        self.policy = GR_MG(
            state_dim=input_config['state_dim'],
            act_dim=input_config['act_dim'],
            act_len=policy_config['act_len'],
            act_latent_dim=policy_config['act_latent_dim'],
            act_encoder_dim=policy_config['act_encoder_dim'],
            act_decoder_dim=policy_config['act_decoder_dim'],
            progress_decoder_dim=self.configs["policy"]["progress_decoder_dim"],
            hidden_size=policy_config['embed_dim'],
            model_mae=model_mae,
            clip_model=clip_model,
            img_feat_dim=policy_config["img_feat_dim"],
            lang_feat_dim = policy_config["lang_feat_dim"],
            patch_feat_dim=policy_config["patch_feat_dim"],
            resampler_params=policy_config['resampler_params'],
            max_length=policy_config['seq_len'],
            training_target=training_target,
            without_norm_pix_loss=trainer_config['without_norm_pix_loss'],
            use_hand_rgb=input_config['use_hand_rgb'],
            use_state=input_config['use_state'],
            use_resampler=policy_config['use_resampler'],
            n_layer=policy_config['n_layer'],
            n_head=policy_config['n_head'],
            n_inner=4*policy_config['embed_dim'],
            activation_function=policy_config['activation_function'],
            n_positions=1024,
            resid_pdrop=policy_config['dropout'],
            attn_pdrop=policy_config['dropout'],
            device=self.device)


        # Set up the model
        payload = torch.load(ckpt_path)
        epoch = payload['epoch']
        state_dict = payload['state_dict']
        print(f"loading state dict from epoch {epoch}...")

        del payload
        # Remove the prefix "model." from pl models
        pure_state_dict = dict()
        for k, v in state_dict.items():
            if "model." in k:
                # NOTE(review): substring test + fixed k[6:] assumes every matching
                # key starts with exactly "model." — confirm no nested ".model." keys.
                new_k = k[6:]
                pure_state_dict[new_k] = v
        msg = self.policy.load_state_dict(pure_state_dict, strict=True)
        print(msg)
        self.policy.to(self.device)
        self.policy.eval()

    def reset(self):
        """Reset function: clear the rolling observation buffers and step counter."""
        self.rgb_list = []
        self.hand_rgb_list = []
        self.state_list = []
        self.rollout_step_counter = 0

    @staticmethod
    def compute_rel_state(states):
        """Express each state relative to the first state of the window.

        Args:
            states: list of (xyz, rotation_matrix, gripper) tuples.
        Returns:
            arm_states: (seq_len, 6) relative [xyz, euler] w.r.t. frame 0.
            gripper_states: (seq_len,) raw gripper values.
        """
        first_xyz = states[0][0]
        first_rotm = states[0][1]
        first_gripper = states[0][2]
        seq_len = len(states)
        arm_states = np.zeros((seq_len, 6))
        gripper_states = np.zeros(seq_len)
        gripper_states[0] = first_gripper
        for i in range(1, seq_len):
            curr_xyz = states[i][0]
            curr_rotm = states[i][1]
            curr_gripper = states[i][2]
            # Pose of frame i expressed in the first frame's coordinates.
            rel_rotm = first_rotm.T @ curr_rotm
            rel_xyz = np.dot(first_rotm.T, curr_xyz - first_xyz)
            arm_states[i, 0:3] = rel_xyz
            arm_states[i, 3:6] = rotm2euler(rel_rotm)
            gripper_states[i] = curr_gripper
        return arm_states, gripper_states

    def step(self, obs, goal, text):
        """Step function: predict the next action from obs, goal image and text.

        Returns:
            action: torch tensor (7,) in CALVIN relative-action format.
            progress: int progress estimate (multiples of 10, presumably percent).
        """
        goal_rgb = goal[0]

        rgb = obs['rgb_obs']['rgb_static']  # (200, 200, 3)
        hand_rgb = obs['rgb_obs']['rgb_gripper']

        # Preprocess goal / static / gripper images to normalized (3, 224, 224).
        goal_rgb = Image.fromarray(goal_rgb)
        goal_rgb = T.ToTensor()(goal_rgb.convert("RGB"))
        goal_rgb = self.preprocess(goal_rgb)  # (3, 224, 224)

        rgb = Image.fromarray(rgb)
        rgb = T.ToTensor()(rgb.convert("RGB"))
        rgb = self.preprocess(rgb)  # (3, 224, 224)
        self.rgb_list.append(rgb)

        hand_rgb = Image.fromarray(hand_rgb)
        hand_rgb = T.ToTensor()(hand_rgb.convert("RGB"))
        hand_rgb = self.preprocess(hand_rgb)
        self.hand_rgb_list.append(hand_rgb)

        # Proprioceptive state kept as (xyz, rotation matrix, gripper).
        state = obs['robot_obs']  # (15,)
        xyz_state = state[:3]
        rpy_state = state[3:6]
        rotm_state = euler2rotm(rpy_state)
        gripper_state = state[-1]
        state = (xyz_state, rotm_state, gripper_state)
        self.state_list.append(state)

        # Sliding window: keep only the most recent `seq_len` frames.
        buffer_len = len(self.rgb_list)
        if buffer_len > self.seq_len:
            self.rgb_list.pop(0)
            self.hand_rgb_list.pop(0)
            self.state_list.pop(0)
            assert len(self.rgb_list) == self.seq_len
            assert len(self.hand_rgb_list) == self.seq_len
            assert len(self.state_list) == self.seq_len
            buffer_len = len(self.rgb_list)



        # Static RGB
        c, h, w = rgb.shape
        c2, h2, w2 = goal_rgb.shape
        assert c == c2 and h == h2 and w == w2
        rgb_data = torch.zeros((1, self.seq_len, c, h, w))
        rgb_tensor = torch.stack(self.rgb_list, dim=0)  # (len, c, h, w)
        rgb_data[0, :buffer_len] = rgb_tensor
        goal_rgb_data = torch.zeros((1, c, h, w))
        goal_rgb_data[0] = goal_rgb

        # Hand RGB
        c, h, w = hand_rgb.shape
        hand_rgb_data = torch.zeros((1, self.seq_len, c, h, w))
        hand_rgb_tensor = torch.stack(self.hand_rgb_list, dim=0)  # (len, c, h, w)
        hand_rgb_data[0, :buffer_len] = hand_rgb_tensor

        # State: relative to first frame of the window; gripper as one-hot.
        arm_state, gripper_state = CustomModel.compute_rel_state(self.state_list)
        arm_state_data = torch.zeros((1, self.seq_len, 6))
        arm_state_tensor = torch.from_numpy(arm_state)
        arm_state_data[0, :buffer_len] = arm_state_tensor
        gripper_state_tensor = torch.from_numpy(gripper_state)
        # Map gripper from {-1, 1} to {0, 1} class indices, then one-hot.
        gripper_state_tensor = (gripper_state_tensor + 1.0) / 2.0
        gripper_state_tensor = gripper_state_tensor.long()
        gripper_state_data = torch.zeros((1, self.seq_len)).long()
        gripper_state_data[0, :buffer_len] = gripper_state_tensor
        gripper_state_data = F.one_hot(gripper_state_data, num_classes=2).type_as(arm_state_data)

        # Attention mask: 1 marks valid (filled) timesteps.
        attention_mask = torch.zeros(1, self.seq_len).long()
        attention_mask[0, :buffer_len] = 1

        # Action placeholder
        arm_action_data = torch.zeros((1, self.seq_len, self.configs["policy"]['act_len'], 6))
        gripper_action_data = torch.zeros(1, self.seq_len, self.configs["policy"]['act_len'])

        # progress placeholder
        progress_data = torch.zeros(1, self.seq_len)

        input_dict = dict()
        input_dict['rgb'] = rgb_data.to(self.device)
        input_dict['hand_rgb'] = hand_rgb_data.to(self.device)
        input_dict["goal_rgb"] = goal_rgb_data.to(self.device)
        input_dict['arm_state'] = arm_state_data.to(self.device)
        input_dict['gripper_state'] = gripper_state_data.to(self.device)
        input_dict['arm_action'] = arm_action_data.to(self.device)
        input_dict['gripper_action'] = gripper_action_data.to(self.device)
        input_dict['attention_mask'] = attention_mask.to(self.device)
        input_dict["text"] = [text]
        input_dict["progress"] = progress_data
        # Forward pass
        with torch.no_grad():
            # action,action_traj = self.policy.evaluate(input_dict)
            action, progress = self.policy.evaluate(input_dict, return_progress=True)
            # Quantize progress to a multiple of 10.
            # NOTE(review): int(int(p * 10) * 10) == int(p * 10) * 10 — the outer
            # int() is redundant; confirm the intended scale (p presumably in [0, 1]).
            progress = int(int(progress * 10) * 10)
            action = action.numpy()



        # Action mode: ee_rel_pose_local — integrate the policy's local-frame
        # relative action through the current EE pose, then convert back into
        # the env's scaled relative-action format (x50 translation, x20 rotation).
        state = obs['robot_obs']  # (15,)
        xyz_state = state[:3]
        rpy_state = state[3:6]
        rotm_state = euler2rotm(rpy_state)
        rel_action = action
        xyz_action = rel_action[:3] / 50  # scale down by 50
        rpy_action = rel_action[3:6] / 20  # scale down by 20
        gripper_action = rel_action[6]
        rotm_action = euler2rotm(rpy_action)
        xyz_next_state = xyz_state + rotm_state @ xyz_action
        rotm_next_state = rotm_state @ rotm_action
        rpy_next_state = rotm2euler(rotm_next_state)
        action = np.zeros(7)
        action[:3] = (xyz_next_state - xyz_state) * 50
        action[3:6] = (rpy_next_state - rpy_state) * 20
        action[-1] = gripper_action
        action = torch.from_numpy(action)
        self.rollout_step_counter += 1

        return action, progress
diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/config/train.json b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/config/train.json
new file mode 100644
index 0000000000000000000000000000000000000000..b4f0f4536f7dab49c851007eca2e934154febe99
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/config/train.json
@@ -0,0 +1,37 @@
+{
+ "exp_name": "train_goal_gen",
+ "seed": 123,
+ "batch_size": 64,
+ "learning_rate": 8e-5,
+ "min_lr_scale": 1.0,
+ "warmup_steps": 800,
+ "device": "cuda",
+ "num_workers": 10,
+ "save_epoch": 1,
+ "pretrained_model_dir": "PATH_TO/resources/IP2P/instruct-pix2pix",
+ "ckpt_root": "SAVE_PATH/goal_gen/checkpoints/",
+ "log_root": "LOG_PATH/goal_gen/logs/",
+ "resume": null,
+ "color_aug": false,
+
+
+ "conditioning_dropout_prob": 0.05,
+ "use_ema": true,
+ "gradient_checkpointing":false,
+
+ "adam_beta1": 0.95,
+ "adam_beta2": 0.999,
+ "adam_weight_decay": 1e-2,
+ "adam_epsilon": 1e-08,
+
+ "trainer": {
+ "accelerator": "gpu",
+ "strategy": "ddp",
+ "precision": "bf16",
+ "logger": ["tensorboard"],
+ "use_distributed_sampler": true,
+ "gradient_clip_val": 0.7,
+ "log_every_n_steps": 50,
+ "max_epochs": 50
+ }
+}
diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/evaluate.py b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..880c551169a2856c6218f5e09efd5949c8ade489
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/evaluate.py
@@ -0,0 +1,142 @@
+# Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import yaml
+import numpy as np
+import matplotlib.pyplot as plt
+from PIL import Image
+import torch
+import torchvision.transforms as transforms
+from transformers import T5Tokenizer, T5EncoderModel
+from diffusers import AutoencoderKL, UNet2DConditionModel
+from goal_gen.utils.pipeline import Pipeline
+from goal_gen.data.calvindataset import CalvinDataset_Goalgen
class IP2PEvaluation(object):
    """Qualitative evaluation harness for the InstructPix2Pix goal-image generator."""

    def __init__(self,
                 ckpt_path,
                 res=256):
        # Init models
        # NOTE(review): hard-coded pretrained model dir — parameterize before release.
        pretrained_model_dir = "/mnt/bn/lpy-lq/stable_diffusion/instruct-pix2pix"

        # T5 text encoder plus the pretrained IP2P VAE/UNet.
        self.tokenizer = T5Tokenizer.from_pretrained("t5-base")
        self.text_encoder = T5EncoderModel.from_pretrained("t5-base")
        self.vae = AutoencoderKL.from_pretrained(
            pretrained_model_dir, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(
            pretrained_model_dir, subfolder="unet")

        # Load weight for unet (the EMA weights stored in the lightning checkpoint).
        payload = torch.load(ckpt_path)
        state_dict = payload['state_dict']
        del payload
        msg = self.unet.load_state_dict(state_dict['unet_ema'], strict=True)
        print(msg)

        self.pipe = Pipeline.from_pretrained(
            pretrained_model_dir,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            vae=self.vae,
            unet=self.unet,
            revision=None,
            variant=None,
            torch_dtype=torch.bfloat16
        ).to("cuda")

        self.pipe.safety_checker = None
        self.pipe.requires_safety_checker = False
        # Fixed seed for reproducible sampling.
        self.generator = torch.Generator("cuda").manual_seed(42)

        # Diffusion hyparams
        self.num_inference_steps = 50
        self.image_guidance_scale = 2.5
        self.guidance_scale = 7.5

        # Image transform
        self.res = res
        self.transform = transforms.Resize((res, res))

    def evaluate(self, eval_result_dir, eval_data_dir, is_training):
        """Generate a goal image for every 100th dataset sample and save a
        (input | ground-truth | prediction) side-by-side figure.

        NOTE(review): each iteration overwrites the same debug.png — presumably
        intended only for quick visual debugging; confirm.
        """
        os.makedirs(eval_result_dir, exist_ok=True)
        save_dir = os.path.join(eval_result_dir, "debug.png")
        dataset = CalvinDataset_Goalgen(
            eval_data_dir,
            resolution=256,
            resolution_before_crop=288,
            center_crop=True,
            forward_n_min_max=(20, 22),
            is_training=is_training,
            use_full=True,
            color_aug=False
        )
        for i in range(0, len(dataset), 100):
            example = dataset[i]
            text = example['input_text']
            original_pixel_values = example['original_pixel_values']
            edited_pixel_values = example['edited_pixel_values']
            progress = example["progress"]

            # Append the completed-progress hint used during training.
            progress = progress * 10
            text[0] = text[0] + f".And {progress}% of the instruction has been finished."
            print(text[0])
            input_image_batch = [original_pixel_values]
            predict_image = self.inference(input_image_batch, text)

            fig, ax = plt.subplots(1, 3)
            # NOTE(review): this k-loop redraws the same three panels three
            # times; a single pass would suffice.
            for k in range(3):
                # De-normalize from [-1, 1] back to uint8 [0, 255] for display.
                original_image = original_pixel_values.permute(1, 2, 0).numpy()
                original_image = (original_image + 1) / 2 * 255
                original_image = np.clip(original_image, 0, 255)
                original_image = original_image.astype(np.uint8)
                ax[0].imshow(original_image)


                edited_image = edited_pixel_values.permute(1, 2, 0).numpy()
                edited_image = (edited_image + 1) / 2 * 255
                edited_image = np.clip(edited_image, 0, 255)
                edited_image = edited_image.astype(np.uint8)
                ax[1].imshow(edited_image)

                ax[2].imshow(predict_image[0])

            plt.savefig(save_dir, dpi=300)
            plt.close()

    def inference(self, image_batch, text_batch):
        """Inference function: edit each input image according to its prompt.

        Returns a list of numpy arrays (one edited image per input).
        """
        input_images = []
        for image in image_batch:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            input_image = self.transform(image)
            input_images.append(input_image)
        edited_images = self.pipe(
            prompt=text_batch,
            image=input_images,
            num_inference_steps=self.num_inference_steps,
            image_guidance_scale=self.image_guidance_scale,
            guidance_scale=self.guidance_scale,
            generator=self.generator,
            safety_checker=None,
            requires_safety_checker=False).images
        edited_images = [np.array(image) for image in edited_images]

        return edited_images
+
+if __name__ == "__main__":
+ ckpt_path="PATH_TO_IP2P_CKPT/epoch=49-step=102900.ckpt"
+ eval = IP2PEvaluation(ckpt_path)
+ eval_data_dir = "PATH_TO_CALVIN/calvin/task_ABC_D/"
+ eval_result_dir = "SAVE_DIR"
+ eval.evaluate(eval_result_dir, eval_data_dir,is_training=False)
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/install.sh b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5026f369d5f9f7d3db5a50dfd4f6fbef8c43d344
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/install.sh
@@ -0,0 +1,18 @@
# Install dependencies for goal_gen (IP2P) training and evaluation.
# Fixes: removed the duplicate `pip3 install matplotlib` line and switched the
# CLIP install from `pip` to `pip3` for consistency with the other lines.
# NOTE: torch/torchvision/torchaudio are pinned for CUDA 11.8.
pip3 install moviepy
pip3 install matplotlib
pip3 install einops
pip3 install diffusers
pip3 install timm
pip3 install numpy
pip3 install sentencepiece
pip3 install accelerate
pip3 install transformers
pip3 install datasets
pip3 install ftfy
pip3 install tensorboard
pip3 install flamingo_pytorch
pip3 install git+https://github.com/openai/CLIP.git
pip3 install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
pip3 install lightning==2.1.0
pip3 install pytorch-lightning==2.1.0
diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/model/model.py b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/model/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..977ed5ef7678b88deb4531dd4feee921d628aeed
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/model/model.py
@@ -0,0 +1,128 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#!/usr/bin/env python
+# coding=utf-8
+import torch
+import torch.nn as nn
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from diffusers.training_utils import EMAModel
+from transformers import T5Tokenizer, T5EncoderModel
+import os
+# modified from https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
class IP2P(nn.Module):
    """InstructPix2Pix model.

    Wraps the pretrained IP2P scheduler/VAE/UNet with a T5 text encoder and
    implements one denoising training step in ``forward``.
    """
    def __init__(self,
                 pretrained_model_dir,
                 device,
                 seed=123,
                 conditioning_dropout_prob=None,
                 gradient_checkpointing=False):
        # Args:
        #   pretrained_model_dir: dir with IP2P "scheduler"/"vae"/"unet" subfolders.
        #   device: device used for the classifier-free-guidance dropout RNG.
        #   seed: seed for that RNG.
        #   conditioning_dropout_prob: CFG conditioning-dropout prob (None disables).
        #   gradient_checkpointing: trade compute for UNet activation memory.
        super().__init__()
        self.device = device
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            pretrained_model_dir, subfolder="scheduler")
        # T5 is used as the text encoder here (not the original CLIP one).
        text_encoder_name = "t5-base"
        self.tokenizer = T5Tokenizer.from_pretrained(text_encoder_name)
        self.text_encoder = T5EncoderModel.from_pretrained(text_encoder_name)
        self.vae = AutoencoderKL.from_pretrained(
            pretrained_model_dir, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(
            pretrained_model_dir, subfolder="unet")

        # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
        # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
        # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
        # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
        # initialized to zero.
        # NOTE(review): only the config is updated here; this assumes the loaded
        # checkpoint's conv_in already has 8 input channels — confirm.
        self.in_channels = 8
        self.unet.register_to_config(in_channels=self.in_channels)
        # Freeze vae and text_encoder
        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)


        # Conditioning dropout probability used for classifier free guidance
        self.conditioning_dropout_prob = conditioning_dropout_prob

        if gradient_checkpointing:
            self.unet.enable_gradient_checkpointing() # it will reduce GPU memory but add computing burden

        self.generator = torch.Generator(device=self.device).manual_seed(seed)

    def tokenize_texts(self, texts):
        """Tokenize to fixed-length (max_length=77) padded T5 input tensors."""
        inputs = self.tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=77)
        return inputs

    def forward(self, input_dict):
        """One denoising training step; returns (prediction, target) noise tensors."""
        original_pixel_values = input_dict['original_pixel_values']
        edited_pixel_values = input_dict['edited_pixel_values']
        input_text = input_dict['input_text'][0]
        # Progress comes in tenths; convert to percent for the prompt.  (b,)
        progress = input_dict["progress"] * 10
        input_text = [text + f".And {curr_progress}% of the instruction has been finished." for (text, curr_progress) in zip(input_text, progress)]
        input_ids = self.tokenize_texts(input_text)

        # We want to learn the denoising process w.r.t the edited images which
        # are conditioned on the original image (which was edited) and the edit instruction.
        # So, first, convert images to latent space.
        latents = self.vae.encode(edited_pixel_values).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (bsz,)).to(latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # Get the text embedding for conditioning
        encoder_hidden_states = self.text_encoder(**(input_ids.to(self.device))).last_hidden_state

        # Get the additional image embedding for conditioning.
        # Instead of getting a diagonal Gaussian here, we simply take the mode.
        original_image_embeds = self.vae.encode(original_pixel_values).latent_dist.mode()

        # Conditioning dropout to support classifier-free guidance during inference.
        if self.conditioning_dropout_prob is not None:
            random_p = torch.rand(bsz, device=latents.device, generator=self.generator)
            # Sample masks for the edit prompts.
            prompt_mask = random_p < 2 * self.conditioning_dropout_prob
            prompt_mask = prompt_mask.reshape(bsz, 1, 1)
            # Final text conditioning.
            null_conditioning = self.text_encoder(**(self.tokenize_texts([""]).to(self.device))).last_hidden_state
            encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)

            # Sample masks for the original images.
            image_mask_dtype = original_image_embeds.dtype
            image_mask = 1 - (
                (random_p >= self.conditioning_dropout_prob).to(image_mask_dtype)
                * (random_p < 3 * self.conditioning_dropout_prob).to(image_mask_dtype)
            )
            image_mask = image_mask.reshape(bsz, 1, 1, 1)
            # Final image conditioning.
            original_image_embeds = image_mask * original_image_embeds

        # Concatenate the `original_image_embeds` with the `noisy_latents`.
        concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)

        # Get the target for loss depending on the prediction type
        # (epsilon prediction: the target is the sampled noise itself).
        target = noise

        # Predict the noise residual and compute loss
        prediction = self.unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

        return prediction, target
diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/train.py b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..63ff713af47f7066cdfb7a528c444e505a55f702
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/train.py
@@ -0,0 +1,212 @@
+# Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import json
+from pathlib import Path
+import importlib
+import copy
+import numpy as np
+from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
+from lightning.pytorch.trainer import Trainer
+from lightning.pytorch.loggers import TensorBoardLogger
+from lightning.pytorch.strategies import DDPStrategy
+from lightning import seed_everything
+import torch
+import random
+from utils.utils import SetupCallback
+import datetime
+from data.calvindataset import CalvinDataset_Goalgen
+from goal_gen.training.trainer import Goalgen_Trainer
+from torch.utils.data import DataLoader
def get_date_str():
    """Today's date as an ISO string; used to namespace log/ckpt directories."""
    return datetime.date.today().isoformat()
+
+
def init_setup_callback(config):
    """Build the SetupCallback that records run metadata into log/ckpt dirs."""
    timestamp = str(datetime.datetime.now()).replace(' ', '_')
    return SetupCallback(
        now=timestamp,
        logdir=config['log_dir'],
        ckptdir=config['ckpt_dir'],
    )
+
def init_trainer_config(configs):
    """Translate the run config into lightning Trainer kwargs.

    Sets devices/nodes, a DDP strategy, per-day log and checkpoint directories
    (also written back into `configs`), a TensorBoard logger, and the standard
    callbacks (setup, LR monitor, per-epoch checkpointing).
    """
    trainer_config = copy.deepcopy(configs['trainer'])
    trainer_config['devices'] = configs.get('gpus', 'auto')
    trainer_config['num_nodes'] = configs.get('num_nodes', 1)
    if 'strategy' not in trainer_config or trainer_config['strategy'] == 'ddp':
        trainer_config['strategy'] = DDPStrategy(find_unused_parameters=False)
    exp_name = configs['exp_name']

    # init loggers — log dir layout: <log_root>/<date>/<exp_name>
    loggers = None
    log_dir = os.path.join(get_date_str(), exp_name)
    log_dir = os.path.join(configs['log_root'], log_dir)
    configs['log_dir'] = log_dir
    Path(configs['log_dir']).mkdir(parents=True, exist_ok=True)
    loggers = []
    loggers.append(TensorBoardLogger(log_dir, name=exp_name)) # you can also add other loggers
    trainer_config['logger'] = loggers
    # ckpt dir mirrors the log dir layout: <ckpt_root>/<date>/<exp_name>
    ckpt_dir = os.path.join(get_date_str(), exp_name)
    ckpt_dir = os.path.join(configs['ckpt_root'], ckpt_dir)
    configs['ckpt_dir'] = ckpt_dir
    Path(configs['ckpt_dir']).mkdir(parents=True, exist_ok=True)
    trainer_config['callbacks'] = [
        init_setup_callback(configs),
        LearningRateMonitor(logging_interval='step'),
        # save_top_k=-1 keeps every epoch's checkpoint.
        ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1, every_n_epochs=1)
    ]
    return trainer_config
+
+
def set_seed(seed=0):
    """Seed all RNG sources (python, numpy, torch incl. all GPUs, lightning)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    seed_everything(seed)
+
+
def experiment(variant):
    """Training entry point: build trainer, model, and dataloaders, then fit."""
    set_seed(variant['seed'])
    trainer_config = init_trainer_config(variant)

    trainer = Trainer(**trainer_config)
    variant['gpus'] = trainer.num_devices

    model = Goalgen_Trainer(variant)

    # dataset
    # NOTE(review): data_dir is hard-coded — replace the placeholder path.
    train_data = CalvinDataset_Goalgen(
        data_dir="/PATH_TO_CALVIN/calvin/task_ABC_D/",
        resolution=256,
        resolution_before_crop=288,
        center_crop=False,
        forward_n_min_max=[20, 22],
        use_full=True,
        is_training=True,
        color_aug=True)
    val_data = CalvinDataset_Goalgen(
        data_dir="/PATH_TO_CALVIN/calvin/task_ABC_D/",
        resolution=256,
        resolution_before_crop=288,
        center_crop=False,
        forward_n_min_max=[20, 22],
        use_full = True,
        is_training=False,
        color_aug=False)
    # NOTE(review): no shuffle=True on the train DataLoader — confirm the
    # dataset shuffles internally or that this is intended.
    train_dataloader = DataLoader(train_data,
                                  batch_size=variant["batch_size"],
                                  num_workers=variant["num_workers"])
    val_dataloader = DataLoader(val_data,
                                batch_size=variant["batch_size"],
                                num_workers=variant["num_workers"])

    _kwargs = {
        'model': model,
        'train_dataloaders': train_dataloader,
        'val_dataloaders': val_dataloader,
        'ckpt_path': variant['resume']  # None means train from scratch
    }
    if _kwargs['ckpt_path'] is not None:
        print(f"Resuming from {variant['resume']}...")
    trainer.fit(**_kwargs)
+
def deep_update(d1, d2):
    """Recursively merge `d2` into `d1` in place and return `d1`.

    Nested dicts present in both are merged key-by-key; all other values
    from `d2` overwrite those in `d1`.
    """
    for key, value in d2.items():
        if isinstance(value, dict) and key in d1:
            assert isinstance(d1[key], dict)
            deep_update(d1[key], value)
        else:
            d1[key] = value
    return d1
+
def load_config(config_file):
    """Load a JSON config, recursively merging any 'parent' config first.

    Child values override parent values (merge via deep_update), so the most
    derived config wins.
    """
    _config = json.load(open(config_file))
    config = {}
    if _config.get('parent', None):
        deep_update(config, load_config(_config['parent']))
    deep_update(config, _config)
    return config
+
def update_configs(configs, args):
    """Override entries of `configs` with non-None command-line `args`, in place.

    Rules:
      - Keys missing from `configs` are added (with a notice printed).
      - Dict-valued args (e.g. the 'trainer' group) are merged sub-key by
        sub-key; each sub-key must already exist in the config.
      - A value of None means "not provided on the CLI" and never overwrites.

    Returns the (mutated) `configs` dict.
    """
    for (k, v) in args.items():
        if k not in configs:
            print(f"{k} not in config. The value is {v}.")
            configs[k] = v
        if isinstance(v, dict):
            for (sub_k, sub_v) in v.items():
                assert sub_k in configs[k], f"{sub_k} - {configs[k]}"
                # Fixed idiom: `is not None` instead of `!= None`.
                if sub_v is not None:
                    configs[k][sub_k] = sub_v
        else:
            if v is not None:
                configs[k] = v
    return configs
+
def parse_args():
    """Parse CLI args into a dict of top-level overrides plus a nested
    'trainer' dict (all options registered after the trainer group marker).

    Defaults of None mean "not provided" and are filtered out later by
    update_configs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default="")
    parser.add_argument('--gpus', default=1, type=int)
    parser.add_argument('--num_nodes', default=1, type=int)
    parser.add_argument('--seed', default=None, type=int)
    parser.add_argument('--log_root', default=None, type=str)
    parser.add_argument('--ckpt_root', default=None, type=str)
    parser.add_argument('--exp_name', default=None, type=str)
    parser.add_argument('--resume', default=None, type=str)

    # Training
    parser.add_argument('--batch_size', default=None, type=int)
    parser.add_argument('--learning_rate', default=None, type=float)
    parser.add_argument('--min_lr_scale', default=None, type=float)
    parser.add_argument('--warmup_steps', default=None, type=int)
    parser.add_argument('--adam_weight_decay', default=None, type=float)
    parser.add_argument('--adam_beta1', default=None, type=float)
    parser.add_argument('--adam_beta2', default=None, type=float)
    parser.add_argument('--adam_epsilon', default=None, type=float)

    # Diffusion
    parser.add_argument("--conditioning_dropout_prob", default=None, type=float)
    # Snapshot of everything registered so far = the "global" options.
    global_names = set(vars(parser.parse_known_args()[0]).keys())

    # Trainer — options added below are routed into the nested 'trainer' dict.
    trainer_parser = parser.add_argument_group('trainer')
    trainer_parser.add_argument('--strategy', default=None, type=str)
    trainer_parser.add_argument('--precision', default=None, type=str)
    trainer_parser.add_argument('--gradient_clip_val', default=None, type=float)
    trainer_parser.add_argument('--max_epochs', default=None, type=int)
    # Trainer names = the newly registered options only.
    trainer_names = set(vars(parser.parse_known_args()[0]).keys()) - global_names

    # Split the parsed namespace back into the two dicts.
    args = {}
    trainer_args = {}
    temp_args = vars(parser.parse_args())
    for (k, v) in temp_args.items():
        if k in global_names:
            args[k] = v
        elif k in trainer_names:
            trainer_args[k] = v

    args['trainer'] = trainer_args

    return args
+
if __name__ == '__main__':
    # Load config (CLI --config path), apply CLI overrides, and train.
    args = parse_args()
    configs = load_config(args.pop('config'))
    configs = update_configs(configs, args)
    # NOTE(review): shelling out with sudo + f-string-built paths is risky
    # (shell injection if a path contains metacharacters, and requires
    # passwordless sudo); consider os.makedirs/os.chmod instead.
    os.system(f"sudo chmod 777 -R {configs['ckpt_root']}")
    os.system(f"sudo chmod 777 -R {configs['log_root']}")
    experiment(variant=configs)
diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/train_ip2p.sh b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/train_ip2p.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2eb484e7bb320ba2b71bfeb3815fd548fb7e4be5
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/train_ip2p.sh
@@ -0,0 +1,17 @@
#!/usr/bin/env bash
# Launch single-node distributed training of the IP2P goal-image generator
# via torchrun. All positional arguments are forwarded as --config values.
set -x
GPUS_PER_NODE=8 # number of gpus per machine
# NOTE(review): {master_address} and {port} are literal placeholders — the
# script cannot run until they are replaced with a real rendezvous endpoint.
MASTER_ADDR={master_address}":"{port} # modify it with your own address and port
NNODES=1 # number of machines
JOB_ID=107 # arbitrary rendezvous id shared by all participating nodes
torchrun \
    --nproc_per_node $GPUS_PER_NODE \
    --nnodes $NNODES \
    --node_rank 0 \
    --rdzv_endpoint $MASTER_ADDR \
    --rdzv_id $JOB_ID \
    --rdzv_backend c10d \
    goal_gen/train.py \
    --config ${@:1} \
    --gpus $GPUS_PER_NODE \
    --num_nodes $NNODES
diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/training/trainer.py b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/training/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e86ea80dc234e447a96a84cd1a95630fbdc7a1
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/training/trainer.py
@@ -0,0 +1,129 @@
+# Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn.functional as F
+import lightning.pytorch as pl
+from utils.ema import requires_grad,update_ema
+from utils.dist_train import get_rank
+from model.model import IP2P
class Goalgen_Trainer(pl.LightningModule):
    """Lightning module that trains the InstructPix2Pix (IP2P) goal-image generator.

    Only the UNet of the underlying diffusion model is optimized. When
    ``use_ema`` is enabled, a frozen EMA copy of the model is maintained and
    used for validation and checkpointing.
    """

    def __init__(self, configs):
        super().__init__()
        self._main_rank_print('--------------- model configs ---------------')
        self._main_rank_print(configs)
        self.configs = configs
        self._initialize()
        self.save_hyperparameters()
        # Tag prefixed onto every logged metric name.
        self.dataset_name = "calvin"

    @staticmethod
    def _main_rank_print(*args, **kwargs):
        # Print only on global rank 0 to avoid duplicated logs in DDP runs.
        if get_rank() == 0:
            print(*args, **kwargs)

    @property
    def num_gpus(self):
        # Total number of devices across all nodes.
        return self.trainer.num_devices * self.trainer.num_nodes

    def _initialize(self):
        # Trainable model, initialized from the pretrained InstructPix2Pix weights.
        self.model = IP2P(
            pretrained_model_dir=self.configs['pretrained_model_dir'],
            device=self.configs['device'],
            seed=self.configs['seed'],
            conditioning_dropout_prob=self.configs['conditioning_dropout_prob'],
            gradient_checkpointing=self.configs['gradient_checkpointing']
        )

        self.use_ema = self.configs["use_ema"]
        if self.use_ema:
            # The EMA copy starts from the same pretrained weights. It is never
            # trained directly; it is only moved toward the live model via
            # `update_ema`, and stays in eval mode.
            self.ema_model = IP2P(
                pretrained_model_dir=self.configs['pretrained_model_dir'],
                device=self.configs['device'],
                seed=self.configs['seed'],
                conditioning_dropout_prob=self.configs['conditioning_dropout_prob'],
                gradient_checkpointing=self.configs['gradient_checkpointing']
            )
            requires_grad(self.ema_model, False)  # EMA model is updated, never trained
            self.ema_model.eval()

    @classmethod
    def from_checkpoint(cls, ckpt_dir=None, configs=None):
        """Build a fresh trainer from ``configs``.

        Raises:
            NotImplementedError: if ``ckpt_dir`` is given. (Bug fix: the
                original silently returned ``None`` on that path, crashing
                callers later. Checkpoint restoration is handled by Lightning
                through ``on_load_checkpoint``.)
        """
        if ckpt_dir is None:
            assert configs is not None, "ckpt_dir and configs are both None for initialization."
            return cls(configs)
        raise NotImplementedError(
            "Loading from ckpt_dir is not supported; pass configs and let "
            "Lightning restore weights through on_load_checkpoint."
        )

    def configure_optimizers(self):
        """AdamW over the UNet parameters only (the rest of IP2P stays frozen)."""
        lr = self.configs['learning_rate']
        eff_bsz = self.configs['batch_size'] * self.num_gpus
        self._main_rank_print('-' * 40)
        self._main_rank_print(f"learning rate: {lr}, effective batch size: {eff_bsz}")

        optimizer_params = [
            {'params': self.model.unet.parameters(), 'lr': lr},
        ]  # only unet will be trained

        optimizer = torch.optim.AdamW(
            optimizer_params,
            betas=(self.configs['adam_beta1'], self.configs['adam_beta2']),
            weight_decay=self.configs['adam_weight_decay'],
            eps=self.configs['adam_epsilon']
        )

        return {
            'optimizer': optimizer
        }

    def _log_output(self, output, phase, dataset=None, **kwargs):
        # Log every entry of `output` as "<dataset>_<phase>_<key>".
        for k, v in output.items():
            log_name = f"{phase}_{k}"
            if dataset is not None:
                log_name = f"{dataset}_{log_name}"
            self.log(log_name, v, prog_bar=True, **kwargs)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Diffusion MSE loss on a validation batch (uses the EMA model if enabled)."""
        with torch.no_grad():
            if self.use_ema:
                prediction, target = self.ema_model.forward(batch)
            else:
                prediction, target = self.model.forward(batch)
            loss = F.mse_loss(prediction.float(), target.float(), reduction="mean")
            output = {'loss': loss}

        self._log_output(output, phase="val", sync_dist=True,
                         on_epoch=True, on_step=False, dataset=self.dataset_name)

    def training_step(self, batch, batch_idx):
        """One optimization step: diffusion MSE loss, then an EMA update."""
        prediction, target = self.model.forward(batch)
        loss = F.mse_loss(prediction.float(), target.float(), reduction="mean")
        output = {'loss': loss}
        self._log_output(output, phase="train", on_epoch=False, on_step=True, dataset=self.dataset_name)
        # Consistency fix: use the cached `self.use_ema` flag like the rest of
        # the class, instead of re-indexing the configs dict.
        if self.use_ema:
            # NOTE(review): this runs before the optimizer applies this step's
            # gradients, so the EMA tracks the *pre-update* weights — confirm
            # this is intended (on_before_zero_grad would track post-update).
            update_ema(self.ema_model, self.model, decay=0.999)
        return output['loss']

    def on_save_checkpoint(self, checkpoint):
        # Persist only the UNet weights (and their EMA copy); everything else
        # can be re-created from the pretrained IP2P checkpoint.
        if not self.use_ema:
            checkpoint['state_dict'] = {'unet': self.model.unet.state_dict()}
        else:
            checkpoint['state_dict'] = {'unet_ema': self.ema_model.unet.state_dict(), 'unet': self.model.unet.state_dict()}

    def on_load_checkpoint(self, checkpoint):
        # Only the parameters of the unet module are loaded.
        if not self.use_ema:
            self.model.unet.load_state_dict(checkpoint['state_dict']['unet'])
        else:
            # Both the live model and the EMA copy resume from the EMA weights.
            self.model.unet.load_state_dict(checkpoint['state_dict']['unet_ema'])
            self.ema_model.unet.load_state_dict(checkpoint['state_dict']['unet_ema'])
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/dist_train.py b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/dist_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..12842dccc023cde8ded10582700d62343111b87e
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/dist_train.py
@@ -0,0 +1,77 @@
+# Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import io
+import os
+
+import torch
+import torch.distributed as dist
+
+
_print = print


def get_world_size():
    """Total number of processes, read from the WORLD_SIZE env var (default 1)."""
    return int(os.environ.get('WORLD_SIZE', 1))


def get_rank():
    """Global rank of this process, read from the RANK env var (default 0)."""
    return int(os.environ.get('RANK', 0))


def get_local_rank():
    """Rank of this process on its node, from the LOCAL_RANK env var (default 0)."""
    return int(os.environ.get('LOCAL_RANK', 0))


def is_dist():
    """True when torch.distributed is usable, initialized and multi-process."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return get_world_size() > 1


def print(*argc, all=False, **kwargs):
    """Rank-aware replacement for the builtin ``print``.

    Outside distributed runs this is a plain print. In distributed runs, only
    local rank 0 prints unless ``all=True``, and every message is prefixed
    with the global rank.
    """
    if not is_dist():
        _print(*argc, **kwargs)
        return

    if not all and get_local_rank() != 0:
        return

    # Render the message into a buffer so the rank prefix can be prepended.
    buffer = io.StringIO()
    kwargs.update(end='', file=buffer, flush=True)
    _print(*argc, **kwargs)

    message = buffer.getvalue()
    buffer.close()

    _print('[rank {}] {}'.format(dist.get_rank(), message))


def _as_reduce_tensor(value):
    """Return (tensor, was_tensor) for a value about to enter all_reduce."""
    if isinstance(value, torch.Tensor):
        return value.clone(), True
    return torch.tensor(value, device=torch.cuda.current_device()), False


def reduce_mean(tensor, nprocs=None):
    """All-reduce ``tensor`` across processes and average over ``nprocs``.

    Non-tensor inputs are returned as plain Python numbers. Outside
    distributed runs the input is returned unchanged.
    """
    if not is_dist():
        return tensor
    reduced, was_tensor = _as_reduce_tensor(tensor)
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    reduced = reduced / (nprocs if nprocs else dist.get_world_size())
    return reduced if was_tensor else reduced.item()


def reduce_sum(tensor):
    """All-reduce sum of ``tensor`` across processes (no-op outside dist runs)."""
    if not is_dist():
        return tensor
    reduced, was_tensor = _as_reduce_tensor(tensor)
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    return reduced if was_tensor else reduced.item()
diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/ema.py b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..01c95b1c95c3a7d2d569aedd11fefd7d03789a13
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/ema.py
@@ -0,0 +1,35 @@
+# Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from torch import inf
+from collections import OrderedDict
@torch.no_grad()
def update_ema(ema_model, model, decay=0.999):
    """Move the EMA model's UNet weights toward the live model, in place.

    Per parameter: ``ema = decay * ema + (1 - decay) * live``.
    """
    ema_params = dict(ema_model.unet.named_parameters())
    for name, live in model.unet.named_parameters():
        # lerp_(end, w): ema += w * (live - ema)  ==  decay*ema + (1-decay)*live
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[name].lerp_(live.data, 1 - decay)
+
def requires_grad(model, flag=True):
    """Enable or disable gradient tracking on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad_(flag)
+
diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/format_calvin_data.py b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/format_calvin_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb34d9dd6c9d48990f80c9c8f66ec36cda62192c
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/format_calvin_data.py
@@ -0,0 +1,79 @@
+# Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Format calvin data for training ip2p
+import os
+import argparse
+import json
+from tqdm import tqdm
+
+import numpy as np
+import cv2
+
+
SPLITS = ['training']
# SPLITS = ['validation']
def main(data_dir, target_dir, num_trajs_per_task=10000):
    """Export CALVIN episodes as per-trajectory image folders for IP2P training.

    For each split in SPLITS, reads the language annotations, copies every
    frame of each annotated trajectory out as PNGs (static + gripper view),
    and writes a meta.json mapping trajectory index -> {text, num_frames}.

    Args:
        data_dir: root of the raw CALVIN dataset (contains <split>/ dirs).
        target_dir: output root; <target_dir>/<split>/<traj_idx>/ is created.
        num_trajs_per_task: cap on exported trajectories per task (training
            split only) — lower it for few-shot experiments.
    """
    for split in SPLITS:
        meta = dict()
        split_dir = os.path.join(target_dir, split)
        # Fix: makedirs also creates missing parents (os.mkdir fails when
        # target_dir itself does not exist yet).
        os.makedirs(split_dir)
        dataset_dir = os.path.join(data_dir, split)
        anns = np.load(
            os.path.join(dataset_dir, "lang_annotations", "auto_lang_ann.npy"),
            allow_pickle=True).item()
        n_trajs = len(anns['info']['indx'])
        task_dict = {}
        for traj_idx in tqdm(range(n_trajs)):
            if split == 'training':
                # Sample trajectories based on num_trajs_per_task.
                traj_task = anns['language']['task'][traj_idx]
                task_dict[traj_task] = task_dict.get(traj_task, 0) + 1
                if task_dict[traj_task] > num_trajs_per_task:
                    continue

            traj_dir = os.path.join(split_dir, f"{traj_idx}")
            os.mkdir(traj_dir)  # parent split_dir already exists
            traj_st, traj_ed = anns['info']['indx'][traj_idx]
            traj_text = anns['language']['ann'][traj_idx]
            for i in range(traj_st, traj_ed + 1):
                frame = np.load(os.path.join(dataset_dir, f"episode_{i:07d}.npz"))
                static_rgb = frame['rgb_static']
                hand_rgb = frame['rgb_gripper']
                # Arrays are presumably RGB; cv2.imwrite expects BGR, so swap
                # channels (BGR2RGB and RGB2BGR are the same permutation).
                cv2.imwrite(os.path.join(traj_dir, f"{i - traj_st}_static.png"), cv2.cvtColor(static_rgb, cv2.COLOR_BGR2RGB))
                cv2.imwrite(os.path.join(traj_dir, f"{i - traj_st}_hand.png"), cv2.cvtColor(hand_rgb, cv2.COLOR_BGR2RGB))
            meta[traj_idx] = {"text": traj_text, "num_frames": int(traj_ed - traj_st + 1)}
        with open(os.path.join(split_dir, "meta.json"), "w") as f:
            json.dump(meta, f)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data_dir",
+ type=str,
+ default="",
+ help="data directory")
+ parser.add_argument("--target_dir",
+ type=str,
+ default="",
+ help="target data directory")
+ parser.add_argument("--num_trajs_per_task",
+ type=int,
+ default=10000, # when you want to do few-shot experiments, change this number
+ help="number of trajectories per task")
+ args = parser.parse_args()
+
+ main(args.data_dir, args.target_dir, args.num_trajs_per_task)
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/pipeline.py b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..d098b8c5f13f75f8d746191f09d67067ca471c0c
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/pipeline.py
@@ -0,0 +1,445 @@
+# Copyright 2024 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+from diffusers import StableDiffusionInstructPix2PixPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from typing import Callable, Dict, List, Optional, Union
+import numpy as np
+import torch
+
+
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull latents out of a VAE encoder output.

    Handles three shapes of ``encoder_output``: an object carrying a
    ``latent_dist`` (sampled or argmax-ed depending on ``sample_mode``), an
    object exposing ``latents`` directly, or anything else (raises).
    """
    latent_dist = getattr(encoder_output, "latent_dist", None)
    if latent_dist is not None and sample_mode == "sample":
        return latent_dist.sample(generator)
    if latent_dist is not None and sample_mode == "argmax":
        return latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
+
class Pipeline(StableDiffusionInstructPix2PixPipeline):
    """Thin wrapper around diffusers' InstructPix2Pix pipeline.

    Differences from upstream: the prompt is always padded/truncated to 77
    tokens in `_encode_prompt`, and `__call__` omits upstream's height/width
    and timesteps arguments. Vendored code — keep diffs against upstream small.
    """
    model_cpu_offload_seq = "text_encoder->unet->vae"
    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
    _exclude_from_cpu_offload = ["safety_checker"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "image_latents"]
    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        safety_checker,
        feature_extractor,
        image_encoder=None,
        requires_safety_checker=True,
    ):
        # NOTE(review): `image_encoder` and `requires_safety_checker` are
        # accepted but NOT forwarded to super().__init__ — confirm whether the
        # parent's defaults are acceptable here.
        super().__init__(
            vae,
            text_encoder,
            tokenizer,
            unet,
            scheduler,
            safety_checker,
            feature_extractor,
        )

    @torch.no_grad()
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            # if isinstance(self, TextualInversionLoaderMixin):
            #     prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            # max_length is hard-coded to 77 (the CLIP context size) rather
            # than self.tokenizer.model_max_length as upstream does.
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt",
            )



            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                # NOTE(review): `logger` is never defined in this module — this
                # branch raises NameError when a prompt exceeds the token limit.
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            attention_mask = text_inputs.attention_mask.to(device)


            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            # if isinstance(self, TextualInversionLoaderMixin):
            #     uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            attention_mask = uncond_input.attention_mask.to(device)


            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
            prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds])

        return prompt_embeds


    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image =None,
        num_inference_steps: int = 100,
        guidance_scale: float = 7.5,
        image_guidance_scale: float = 1.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image= None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        """Run text+image conditioned generation; returns images (or latents).

        Mirrors upstream StableDiffusionInstructPix2PixPipeline.__call__; see
        the diffusers documentation for parameter semantics.
        """
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        # NOTE(review): `deprecate` is not imported in this module — passing the
        # legacy `callback`/`callback_steps` kwargs raises NameError here.
        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )

        # 0. Check inputs
        self.check_inputs(
            prompt,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )
        self._guidance_scale = guidance_scale
        self._image_guidance_scale = image_guidance_scale

        device = self._execution_device

        # NOTE(review): `ImageProjection` is not imported here either — the
        # ip_adapter path raises NameError if ever exercised.
        if ip_adapter_image is not None:
            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
            image_embeds, negative_image_embeds = self.encode_image(
                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
            )
            if self.do_classifier_free_guidance:
                image_embeds = torch.cat([image_embeds, negative_image_embeds, negative_image_embeds])

        if image is None:
            raise ValueError("`image` input cannot be undefined.")

        # 1. Define call parameters

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # check if scheduler is in sigmas space
        scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")

        # 2. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )


        # NOTE(review): `length` is computed but never used below.
        length=prompt_embeds.shape[0]

        # 3. Preprocess image
        image = self.image_processor.preprocess(image)

        # 4. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare Image latents
        image_latents = self.prepare_image_latents(
            image,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            self.do_classifier_free_guidance,
        )

        height, width = image_latents.shape[-2:]
        height = height * self.vae_scale_factor
        width = width * self.vae_scale_factor

        # 6. Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 7. Check that shapes of latents and image match the UNet channels
        num_channels_image = image_latents.shape[1]
        if num_channels_latents + num_channels_image != self.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_image`: {num_channels_image} "
                f" = {num_channels_latents+num_channels_image}. Please verify the config of"
                " `pipeline.unet` or your `image` input."
            )

        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8.1 Add image embeds for IP-Adapter
        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Expand the latents if we are doing classifier free guidance.
                # The latents are expanded 3 times because for pix2pix the guidance\
                # is applied for both the text and the input image.
                latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents

                # concat latents, image_latents in the channel dimension
                scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)

                # predict the noise residual
                noise_pred = self.unet(
                    scaled_latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # Hack:
                # For karras style schedulers the model does classifer free guidance using the
                # predicted_original_sample instead of the noise_pred. So we need to compute the
                # predicted_original_sample here if we are using a karras style scheduler.
                if scheduler_is_in_sigma_space:
                    step_index = (self.scheduler.timesteps == t).nonzero()[0].item()
                    sigma = self.scheduler.sigmas[step_index]
                    noise_pred = latent_model_input - sigma * noise_pred

                # perform guidance
                if self.do_classifier_free_guidance:
                    # Chunk order matches the [text, image, uncond] concat in _encode_prompt.
                    noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
                    noise_pred = (
                        noise_pred_uncond
                        + self.guidance_scale * (noise_pred_text - noise_pred_image)
                        + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond)
                    )

                # Hack:
                # For karras style schedulers the model does classifer free guidance using the
                # predicted_original_sample instead of the noise_pred. But the scheduler.step function
                # expects the noise_pred and computes the predicted_original_sample internally. So we
                # need to overwrite the noise_pred here such that the value of the computed
                # predicted_original_sample is correct.
                if scheduler_is_in_sigma_space:
                    noise_pred = (noise_pred - latents) / (-sigma)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    image_latents = callback_outputs.pop("image_latents", image_latents)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/utils.py b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1989ba9e3859bf9e429fba7d131385f76f28172c
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/utils.py
@@ -0,0 +1,32 @@
+# Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from lightning.pytorch.callbacks import Callback
+import json
+
+
class SetupCallback(Callback):
    """Lightning callback that prepares log/checkpoint directories at train start."""

    def __init__(self, now, logdir, ckptdir):
        super().__init__()
        self.now = now          # timestamp/identifier of this run
        self.logdir = logdir    # where logs go
        self.ckptdir = ckptdir  # where checkpoints go

    def on_train_start(self, trainer, model):
        # Only the global-rank-0 process touches the filesystem.
        if trainer.global_rank == 0:
            for directory in (self.logdir, self.ckptdir):
                os.makedirs(directory, exist_ok=True)
diff --git a/ACT_DP_multitask/detr/models/mr_mg/media/model.gif b/ACT_DP_multitask/detr/models/mr_mg/media/model.gif
new file mode 100644
index 0000000000000000000000000000000000000000..9eae2e1691c3364f60b96023e68904a5e9c3b5b5
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/media/model.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d96bce6dc9112fcfce234a15d48f8d3c1ee22354c3c62281a7bf10db25b2610
+size 1054873
diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/config/pretrain.json b/ACT_DP_multitask/detr/models/mr_mg/policy/config/pretrain.json
new file mode 100644
index 0000000000000000000000000000000000000000..990ace12eae46a0053244890983bcadea7743734
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/policy/config/pretrain.json
@@ -0,0 +1,77 @@
+{
+ "seed": 123,
+ "ckpt_root": "SAVE_PATH/policy/checkpoints/",
+ "log_root": "LOG_PATH/policy/logs/",
+ "exp_name":"pretrain_policy",
+ "resume":null,
+ "policy": {
+ "resampler_params": {
+ "depth": 3,
+ "dim_head": 128,
+ "heads": 4,
+ "num_latents": 9,
+ "num_media_embeds": 1
+ },
+ "seq_len": 10,
+ "act_len": 5,
+ "act_latent_dim": 32,
+ "act_encoder_dim": 128,
+ "act_decoder_dim": 128,
+ "progress_decoder_dim":128,
+ "patch_feat_dim": 768,
+ "img_feat_dim": 768,
+ "lang_feat_dim":512,
+ "embed_dim": 384,
+ "use_resampler": true,
+ "n_layer": 12,
+ "n_head": 12,
+ "activation_function": "relu",
+ "dropout": 0.1,
+ "forward_n_max": 25
+ },
+ "trainer":{
+ "pl_config":{
+ "accelerator": "gpu",
+ "strategy": "ddp",
+ "precision": "bf16",
+ "logger": ["tensorboard"],
+ "gradient_clip_val": 1.0,
+ "accumulate_grad_batches": 1,
+ "use_distributed_sampler": true,
+ "log_every_n_steps": 50,
+ "max_epochs": 50
+ },
+ "lr_decay":true,
+ "batch_size": 64,
+ "without_norm_pix_loss": false,
+ "start_epoch": 0,
+ "learning_rate": 3.6e-4,
+ "min_learning_rate_scale": 1e-2,
+ "weight_decay": 0.0,
+ "betas_1": 0.9,
+ "betas_2": 0.999,
+ "warmup_epochs": 5,
+ "optimizer":"adam",
+ "num_workers": 8,
+ "save_epoch": 1,
+ "gripper_loss_ratio": 0.01,
+ "fwd_loss_ratio": 1,
+ "kl_loss_ratio": 1.0,
+ "progress_loss_ratio":1.0,
+ "act_pred": false,
+ "fwd_pred": true,
+ "progress_pred":false,
+ "fwd_pred_hand": false,
+ "fwd_pred_next_n": 3,
+ "finetune": false,
+ "use_pretrain": false,
+ "pretrained_model_path": ""
+ },
+
+ "input": {
+ "state_dim": 7,
+ "act_dim": 7,
+ "use_hand_rgb": false,
+ "use_state": false
+ }
+}
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/config/train.json b/ACT_DP_multitask/detr/models/mr_mg/policy/config/train.json
new file mode 100644
index 0000000000000000000000000000000000000000..faf85374edc11cf550c4453e3385d3b75f8dcf21
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/policy/config/train.json
@@ -0,0 +1,77 @@
+{
+ "seed": 123,
+ "ckpt_root": "SAVE_PATH/policy/checkpoints/",
+ "log_root": "LOG_PATH/policy/logs/",
+ "exp_name":"train_policy",
+ "resume":null,
+ "policy": {
+ "resampler_params": {
+ "depth": 3,
+ "dim_head": 128,
+ "heads": 4,
+ "num_latents": 9,
+ "num_media_embeds": 1
+ },
+ "seq_len": 10,
+ "act_len": 5,
+ "act_latent_dim": 32,
+ "act_encoder_dim": 128,
+ "act_decoder_dim": 128,
+ "progress_decoder_dim":128,
+ "patch_feat_dim": 768,
+ "img_feat_dim": 768,
+ "lang_feat_dim":512,
+ "embed_dim": 384,
+ "use_resampler": true,
+ "n_layer": 12,
+ "n_head": 12,
+ "activation_function": "relu",
+ "dropout": 0.1,
+ "forward_n_max": 25
+ },
+ "trainer":{
+ "pl_config":{
+ "accelerator": "gpu",
+ "strategy": "ddp",
+ "precision": "bf16",
+ "logger": ["tensorboard"],
+ "gradient_clip_val": 1.0,
+ "accumulate_grad_batches": 1,
+ "use_distributed_sampler": true,
+ "log_every_n_steps": 50,
+ "max_epochs": 50
+ },
+ "lr_decay":true,
+ "batch_size": 32,
+ "without_norm_pix_loss": false,
+ "start_epoch": 0,
+ "learning_rate": 1e-3,
+ "min_learning_rate_scale": 1e-2,
+ "weight_decay": 0.0,
+ "betas_1": 0.9,
+ "betas_2": 0.999,
+ "warmup_epochs": 1,
+ "optimizer":"adam",
+ "num_workers": 13,
+ "save_epoch": 1,
+ "gripper_loss_ratio": 0.01,
+ "fwd_loss_ratio": 0.1,
+ "kl_loss_ratio": 1.0,
+ "progress_loss_ratio":1.0,
+ "act_pred": true,
+ "fwd_pred": true,
+ "progress_pred":true,
+ "fwd_pred_hand": true,
+ "fwd_pred_next_n": 3,
+ "finetune": true,
+ "use_pretrain": true,
+ "pretrained_model_path": "PATH_TO_PRETRAINED_CHECKPOINT"
+ },
+
+ "input": {
+ "state_dim": 7,
+ "act_dim": 7,
+ "use_hand_rgb": true,
+ "use_state": true
+ }
+}
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/install.sh b/ACT_DP_multitask/detr/models/mr_mg/policy/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f5569897c9998d664029298e24598642d9e1eea8
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/policy/install.sh
@@ -0,0 +1,14 @@
# System packages for off-screen (OSMesa) rendering and for patching shared libs.
sudo apt-get -y install libosmesa6-dev
sudo apt-get -y install patchelf
# NOTE(review): `pip` and `pip3` are mixed below — if they resolve to different
# interpreters, packages land in different environments; unify on one.
pip install --upgrade transformers
pip3 install flamingo_pytorch
pip3 install tensorboard
# cv2 without GUI dependencies (safe on headless servers).
pip install opencv-python-headless
# ftfy/regex/tqdm are required by the CLIP package installed below.
pip3 install ftfy regex tqdm
pip3 install matplotlib decord
pip install git+https://github.com/openai/CLIP.git
pip install sentencepiece
# CUDA 11.8 builds of the torch stack, pinned to matching versions.
pip3 install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
pip3 install lightning==2.1.0
pip3 install pytorch-lightning==2.1.0
# Pin numpy 1.x: the torch 2.1 wheels above are built against numpy<2.
pip install "numpy<2.0.0"
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/main.py b/ACT_DP_multitask/detr/models/mr_mg/policy/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..35dc18ff76a37eaa2fbdf55c14633f1f2d9bdb00
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/policy/main.py
@@ -0,0 +1,181 @@
+# Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import argparse
+import json
+from pathlib import Path
+import random
+import numpy as np
+import torch
+from pathlib import Path
+import copy
+import datetime
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
+from lightning.pytorch.trainer import Trainer
+from lightning.pytorch.loggers import TensorBoardLogger
+from lightning.pytorch.strategies import DDPStrategy
+from lightning import seed_everything
+from data.calvin_dataset import CalvinDataset_Policy
+from data.ego4d_dataset import Ego4DDataset_Policy
+from training.trainer import Policy_Trainer
+from torch.utils.data import DataLoader
def set_seed(seed=0):
    """Seed every RNG in play: python, numpy, torch (CPU+CUDA), lightning."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        seed_everything,
    )
    for seeder in seeders:
        seeder(seed)
+
+
def get_date_str():
    """Return today's date as an ISO-formatted string (YYYY-MM-DD)."""
    today = datetime.date.today()
    return str(today)
+
class SetupCallback(Callback):
    """Creates the experiment's log and checkpoint folders at train start.

    NOTE(review): this duplicates the identical helper in
    goal_gen/utils/utils.py — consider keeping a single shared version.
    """

    def __init__(self, now, logdir, ckptdir):
        super().__init__()
        # now: run timestamp; logdir/ckptdir: output directories to create.
        self.now, self.logdir, self.ckptdir = now, logdir, ckptdir

    def on_train_start(self, trainer, model):
        # One process (global rank 0) is enough to create the directories.
        if trainer.global_rank == 0:
            for path in (self.logdir, self.ckptdir):
                os.makedirs(path, exist_ok=True)
+
def init_setup_callback(config):
    """Build a SetupCallback from the resolved run configuration.

    Expects ``config['log_dir']`` and ``config['ckpt_dir']`` to be set
    (done by ``init_trainer_config``).
    """
    timestamp = str(datetime.datetime.now()).replace(' ', '_')
    return SetupCallback(
        now=timestamp,
        logdir=config['log_dir'],
        ckptdir=config['ckpt_dir'],
    )
+
def init_trainer_config(configs):
    """Derive the Lightning ``Trainer`` kwargs from the experiment config.

    Side effects: adds ``log_dir`` and ``ckpt_dir`` keys to ``configs``
    (read later by ``init_setup_callback``) and creates both directories.

    Returns a dict ready to be splatted into ``Trainer(**trainer_config)``.
    """
    trainer_config = copy.deepcopy(configs['trainer']["pl_config"])
    trainer_config['devices'] = configs.get('devices', 'auto')
    trainer_config['num_nodes'] = configs.get('num_nodes', 1)

    # Replace the plain 'ddp' string with an explicit DDPStrategy so that
    # find_unused_parameters can be disabled (faster when every param is used).
    if 'strategy' not in trainer_config or trainer_config['strategy'] == 'ddp':
        trainer_config['strategy'] = DDPStrategy(find_unused_parameters=False)

    exp_name = configs['exp_name']

    # Loggers: TensorBoard output under <log_root>/<date>/<exp_name>.
    log_dir = os.path.join(get_date_str(), exp_name)
    log_dir = os.path.join(configs['log_root'], log_dir)
    configs['log_dir'] = log_dir
    Path(configs['log_dir']).mkdir(parents=True, exist_ok=True)
    trainer_config['logger'] = [TensorBoardLogger(log_dir, name=exp_name)]


    # TODO: make callbacks configurable
    # Checkpoints under <ckpt_root>/<date>/<exp_name>.
    ckpt_dir = os.path.join(get_date_str(), exp_name)
    ckpt_dir = os.path.join(configs['ckpt_root'], ckpt_dir)
    configs['ckpt_dir'] = ckpt_dir
    Path(configs['ckpt_dir']).mkdir(parents=True, exist_ok=True)
    trainer_config['callbacks'] = [
        init_setup_callback(configs),
        LearningRateMonitor(logging_interval='step'),
        # save_top_k=-1 keeps a checkpoint for every saved epoch; if disk
        # space is limited, set save_top_k=1 to keep only the best model.
        ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1, every_n_epochs=configs["trainer"]["save_epoch"])
    ]

    return trainer_config
+
def experiment(variant):
    """Run one training experiment described by the ``variant`` config dict.

    Builds the Lightning trainer and the ``Policy_Trainer`` module, selects
    the dataset (CALVIN when ``trainer.finetune`` is true, Ego4D otherwise),
    and calls ``trainer.fit``.

    NOTE(review): the ``data_dir``/``annotation_file`` values below are
    literal placeholders ("PATH_TO_...") and must be edited before running.
    """
    set_seed(variant['seed'])
    trainer_config = init_trainer_config(variant)
    trainer = Trainer(**trainer_config)
    model = Policy_Trainer(variant)
    # Dataset selection: labeled CALVIN episodes for finetuning,
    # Ego4D video clips for pretraining.
    if variant["trainer"]["finetune"]:
        train_data= CalvinDataset_Policy(
            data_dir="PATH_TO_CALVIN/calvin_data",
            use_data_augmentation=True,
            subfolder= "task_ABC_D",
            mode= "train",
            forward_n_max=25,
            use_play=False,
            use_labeled=True)
        val_data= CalvinDataset_Policy(
            data_dir="PATH_TO_CALVIN/calvin_data",
            use_data_augmentation=False,
            subfolder= "task_ABC_D",
            mode= "validate",
            forward_n_max=25,
            use_play=False,
            use_labeled=True)
    else:
        train_data= Ego4DDataset_Policy(
            data_dir="PATH_TO_Ego4d_Videos",
            preprocess=None,
            video_sample_rate=2,
            seq_len=10,
            annotation_file= "PATH_TO_Ego4d_800k_annotations",
            use_data_augmentation=True,
            goal_interval=7)
        val_data= Ego4DDataset_Policy(
            data_dir="PATH_TO_Ego4d_Videos",
            preprocess=None,
            video_sample_rate=2,
            seq_len=10,
            annotation_file= "PATH_TO_Ego4d_800k_annotations",
            use_data_augmentation=False,
            goal_interval=7)
    # NOTE(review): no shuffle flag is passed — presumably relies on
    # Lightning's use_distributed_sampler (enabled in pl_config) to shuffle
    # under DDP; confirm behavior for single-process runs.
    train_dataloader= DataLoader(train_data,
                                 batch_size=variant["trainer"]["batch_size"],
                                 num_workers=variant["trainer"]["num_workers"])
    val_dataloader= DataLoader(val_data,
                               batch_size=variant["trainer"]["batch_size"],
                               num_workers=variant["trainer"]["num_workers"])

    _kwargs = {
        'model': model,
        'train_dataloaders':train_dataloader,
        'val_dataloaders':val_dataloader,
        'ckpt_path': variant['resume'] # when you want to restore your training, modify this variant
    }
    # The message is informational only; trainer.fit handles the actual
    # resume via ckpt_path (None means start from scratch).
    if _kwargs['ckpt_path'] is not None:
        print(f"Resuming from {variant['resume']}...")
    trainer.fit(**_kwargs)
+
+
def parse_args():
    """Parse CLI options, load the JSON config, and merge the two.

    Keys already present in the config file take precedence; CLI options
    (including their argparse defaults) only fill in missing keys.
    Returns the merged configuration dict.
    """
    parser = argparse.ArgumentParser()

    # Experiment-level options.
    parser.add_argument('--config', type=str,default="")
    parser.add_argument('--devices', default=1, type=int)
    parser.add_argument('--num_nodes', default=1, type=int)
    parser.add_argument('--seed', default=None, type=int)
    parser.add_argument('--log_root', default=None, type=str)
    parser.add_argument('--ckpt_root', default=None, type=str)
    parser.add_argument('--resume', type=str)

    cli_options = vars(parser.parse_args())
    config_path = cli_options.pop("config")
    with open(config_path) as fp:
        configs = json.load(fp)

    # Fill in only the keys the config file does not define.
    for key, value in cli_options.items():
        configs.setdefault(key, value)

    return configs
+
+
+
if __name__ == '__main__':
    configs=parse_args()
    # NOTE(review): shelling out to `sudo chmod 777 -R` with unquoted,
    # config-supplied paths is a security risk (shell injection via the
    # config file) and makes the trees world-writable. Prefer
    # os.makedirs(..., exist_ok=True) with explicit modes, or at minimum
    # shlex.quote() the paths.
    os.system(f"sudo chmod 777 -R {configs['ckpt_root']}")
    os.system(f"sudo chmod 777 -R {configs['log_root']}")
    experiment(variant=configs)
diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/main.sh b/ACT_DP_multitask/detr/models/mr_mg/policy/main.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e8117d6221b034d996a3bfd568b9ed6119ac9a05
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/policy/main.sh
@@ -0,0 +1,17 @@
#!/usr/bin/env bash
# Launch distributed policy training (policy/main.py) with torchrun.
cd /opt/tiger/GR_MG
GPUS_PER_NODE=8 # number of gpus per machine
# NOTE: {master_address} and {port} are literal placeholders — replace them
# with a real rendezvous endpoint before running.
MASTER_ADDR={master_address}":"{port} # modify it with your own address and port
NNODES=1 # number of machines
JOB_ID=107
# All remaining CLI arguments are forwarded as the --config value.
torchrun \
    --nproc_per_node $GPUS_PER_NODE \
    --nnodes $NNODES \
    --node_rank 0 \
    --rdzv_endpoint $MASTER_ADDR \
    --rdzv_id $JOB_ID \
    --rdzv_backend c10d \
    policy/main.py \
    --config ${@:1} \
    --devices $GPUS_PER_NODE \
    --num_nodes $NNODES
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-310.pyc b/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4666a58d37efb7ef0cb52589be3293a73c20a24d
Binary files /dev/null and b/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-310.pyc differ
diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-37.pyc b/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d10c8d0dee67119ccb59c61d73c6a9bc4911c528
Binary files /dev/null and b/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-37.pyc differ
diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-38.pyc b/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15e0c6a84ac4140e4fe9e6b5d950f9e1eafaccc5
Binary files /dev/null and b/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-38.pyc differ
diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/model/gpt2.py b/ACT_DP_multitask/detr/models/mr_mg/policy/model/gpt2.py
new file mode 100644
index 0000000000000000000000000000000000000000..d124e3d49ccf7be6127962b20c5e7650b12560b6
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/policy/model/gpt2.py
@@ -0,0 +1,944 @@
+# coding=utf-8
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch OpenAI GPT-2 model."""
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.cuda.amp import autocast
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel, SequenceSummary
+from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+from transformers.utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "gpt2"
+_CONFIG_FOR_DOC = "GPT2Config"
+
+GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "gpt2",
+ "gpt2-medium",
+ "gpt2-large",
+ "gpt2-xl",
+ "distilgpt2",
+ # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
+]
+
+
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """Copy weights from an OpenAI TensorFlow GPT-2 checkpoint into ``model``.

    Walks every variable in the TF checkpoint, maps its scope path onto the
    corresponding PyTorch submodule/parameter, and copies the data in place.
    ``config`` is unused here but kept for signature compatibility with the
    ``load_tf_weights`` hook. Raises if TensorFlow is not importable or a
    shape mismatch is found.
    """
    try:
        import re

        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split("/")
        pointer = model
        # Walk the scope path segment by segment, translating TF naming
        # conventions (w/g -> weight, b -> bias, wpe/wte -> embedding weight).
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                # e.g. "h0" -> attribute "h" indexed at 0.
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "w" or scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except ValueError as e:
            # Attach both shapes to the exception for easier debugging.
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
+
+
class GPT2Attention(nn.Module):
    """Multi-head (self- or cross-) attention block for GPT-2.

    Vendored from HuggingFace ``transformers``' GPT-2 implementation.
    With ``is_cross_attention=True`` queries come from the decoder stream
    and keys/values from ``encoder_hidden_states``; otherwise a single
    fused ``c_attn`` projection produces Q, K and V.
    """

    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()

        max_positions = config.max_position_embeddings
        # Lower-triangular causal mask buffer, sliced per forward call.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        if self.is_cross_attention:
            # Cross-attention: K/V come from the encoder stream (c_attn),
            # Q from the decoder stream (q_attn).
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            # Self-attention: one fused projection for Q, K and V.
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the fused projections."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
        # Q, K and V live side by side in c_attn, so the same column index
        # must be repeated at each of the three offsets.
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
        self.num_heads = self.num_heads - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Standard scaled dot-product attention path."""
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Numerically safer attention: scores computed in fp32 via baddbmm."""
        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute Scale Factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
        with autocast(enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Run attention; returns (attn_output, present[, attentions])."""
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Prepend cached K/V from previous decoding steps along the sequence dim.
        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)
+
+
class GPT2MLP(nn.Module):
    """GPT-2 feed-forward sub-block: Conv1D up-projection, activation,
    Conv1D down-projection, then residual dropout."""

    def __init__(self, intermediate_size, config):
        super().__init__()
        embed_dim = config.hidden_size
        # Attribute names (c_fc/c_proj) must stay as-is to match GPT-2
        # checkpoint state_dict keys.
        self.c_fc = Conv1D(intermediate_size, embed_dim)
        self.c_proj = Conv1D(embed_dim, intermediate_size)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        projected = self.c_fc(hidden_states)
        activated = self.act(projected)
        output = self.c_proj(activated)
        return self.dropout(output)
+
+
class GPT2Block(nn.Module):
    """One GPT-2 transformer layer (pre-LayerNorm style): self-attention,
    optional cross-attention, and MLP, each with a residual connection.

    Vendored from HuggingFace ``transformers``' GPT-2 implementation.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        # MLP inner width defaults to 4x the hidden size when n_inner is unset.
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        if config.add_cross_attention:
            self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = GPT2MLP(inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        """Run one layer; returns (hidden_states, present, attentions...)."""
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = residual + attn_output
            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        # With use_cache the K/V `present` tuple is kept in the outputs;
        # without it, outputs[0] (the present slot) is dropped.
        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions, cross_attentions)
+
+
class GPT2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.

    Vendored from HuggingFace ``transformers``' GPT-2 implementation.
    """

    config_class = GPT2Config
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _no_split_modules = ["GPT2Block"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        # > -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name == "c_proj.weight":
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle activation checkpointing on the base GPT2Model only; the
        # flag is read by the model's forward pass.
        if isinstance(module, GPT2Model):
            module.gradient_checkpointing = value
+
+
@dataclass
class GPT2DoubleHeadsModelOutput(ModelOutput):
    """
    Base class for outputs of models predicting if two sentences are consecutive or not.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
            Multiple choice classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
            sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    # All fields default to None; the model fills in only the ones requested
    # by the caller's flags (use_cache, output_attentions, ...).
    loss: Optional[torch.FloatTensor] = None
    mc_loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    mc_logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
# Class-level documentation injected into model classes by @add_start_docstrings.
GPT2_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Argument documentation injected into forward() by
# @add_start_docstrings_to_model_forward.
GPT2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
            `past_key_values`. In other words, the `attention_mask` always has to have the length:
            `len(past_key_values) + len(input_ids)`

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
            `past_key_values`).
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Documentation attached to GPT2Model.parallelize() (deprecated API).
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is a subject to change at a moment's notice.

    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
    it will evenly distribute blocks across all devices.

    Args:
        device_map (`Dict[int, list]`, optional, defaults to None):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
            following number of attention modules:

            - gpt2: 12
            - gpt2-medium: 24
            - gpt2-large: 36
            - gpt2-xl: 48

    Example:

    ```python
    # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
    model = GPT2LMHeadModel.from_pretrained("gpt2-xl")
    device_map = {
        0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
        1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
        2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
        3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
    }
    model.parallelize(device_map)
    ```
"""
# Documentation attached to GPT2Model.deparallelize() (deprecated API).
DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to cpu from a model parallel state.

    Example:

    ```python
    # On a 4 GPU machine with gpt2-large:
    model = GPT2LMHeadModel.from_pretrained("gpt2-large")
    device_map = {
        0: [0, 1, 2, 3, 4, 5, 6, 7],
        1: [8, 9, 10, 11, 12, 13, 14, 15],
        2: [16, 17, 18, 19, 20, 21, 22, 23],
        3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
    }
    model.parallelize(device_map)  # Splits the model across several devices
    model.deparallelize()  # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
    ```
"""
+
+
@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class GPT2Model(GPT2PreTrainedModel):
    """GPT-2 transformer backbone without embedding tables.

    This variant differs from the stock Hugging Face `GPT2Model` in two ways:
    the token embedding table (`wte`) and positional table (`wpe`) are never
    created, so callers must pass pre-computed `inputs_embeds` (with any
    positional information already added). Removing the unused tables also
    allows DDP training with `find_unused_parameters=False`.
    """

    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.hidden_size

        # NOTE: wte/wpe (token + positional embedding tables) are deliberately
        # not created here -- this model consumes pre-embedded inputs only.
        # (The original commented-out lines said, in Chinese, "not used, so
        # removed".)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Model parallel state (populated by parallelize()).
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # Check validity of device_map
        warnings.warn(
            "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
            " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
            " ...}",
            FutureWarning,
        )
        self.device_map = (
            get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.h))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        # BUGFIX: wte/wpe no longer exist on this stripped-down model (see
        # __init__), so moving them unconditionally raised AttributeError.
        # Only move them when a caller installed an embedding table via
        # set_input_embeddings().
        if hasattr(self, "wte"):
            self.wte = self.wte.to(self.first_device)
        if hasattr(self, "wpe"):
            self.wpe = self.wpe.to(self.first_device)
        # Load onto devices
        for k, v in self.device_map.items():
            for block in v:
                cuda_device = "cuda:" + str(k)
                self.h[block] = self.h[block].to(cuda_device)
        # ln_f to last
        self.ln_f = self.ln_f.to(self.last_device)

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        # BUGFIX: guard for the removed embedding tables (see parallelize()).
        if hasattr(self, "wte"):
            self.wte = self.wte.to("cpu")
        if hasattr(self, "wpe"):
            self.wpe = self.wpe.to("cpu")
        for index in range(len(self.h)):
            self.h[index] = self.h[index].to("cpu")
        self.ln_f = self.ln_f.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        # NOTE(review): `wte` is not created in __init__, so this raises
        # AttributeError unless set_input_embeddings() was called first.
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)
        # position_ids are still computed for interface compatibility, but the
        # positional embedding lookup below is disabled (wpe was removed).
        if position_ids is None:
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

        # GPT2Attention mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, None, None, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and the dtype's smallest value for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            # NOTE(review): works only when an embedding table was installed
            # via set_input_embeddings(); by default this model expects
            # pre-embedded inputs_embeds.
            inputs_embeds = self.wte(input_ids)
        # Positional embeddings are intentionally NOT added (wpe removed);
        # the caller is responsible for any positional information.
        hidden_states = inputs_embeds

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure layer_past is on same device as hidden_states (might not be correct)
                if layer_past is not None:
                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/model/model.py b/ACT_DP_multitask/detr/models/mr_mg/policy/model/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c567481df54273008197973adf0f478d765dbe35
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/policy/model/model.py
@@ -0,0 +1,611 @@
+# Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import numpy as np
+
+import transformers
+from flamingo_pytorch import PerceiverResampler
+import clip
+from policy.model.gpt2 import GPT2Model
+from policy.utils.dist_train import print as dis_print
+import os
+from policy.model.vision_transformer import Block,get_2d_sincos_pos_embed
+
def reparameterize(mu, logvar):
    """Sample from N(mu, exp(logvar)) with the reparameterization trick.

    Args:
        mu: Mean tensor of the latent Gaussian.
        logvar: Log-variance tensor, broadcastable to ``mu``'s shape.

    Returns:
        ``mu + std * eps`` with ``eps ~ N(0, I)``; the sample stays
        differentiable with respect to ``mu`` and ``logvar``.
    """
    std = logvar.div(2).exp()
    # torch.randn_like replaces the deprecated torch.autograd.Variable wrapper
    # (Variables were merged into Tensors in PyTorch 0.4).
    eps = torch.randn_like(std)
    return mu + std * eps
+
+
+class GR_MG(nn.Module):
+ def __init__(
+ self,
+ state_dim,
+ act_dim,
+ act_len,
+ act_latent_dim,
+ act_encoder_dim,
+ act_decoder_dim,
+ progress_decoder_dim,
+ hidden_size,
+ model_mae,
+ clip_model,
+ img_feat_dim,
+ lang_feat_dim,
+ patch_feat_dim,
+ resampler_params,
+ max_length=None,
+ training_target=['act_pred'],
+ without_norm_pix_loss=False,
+ use_hand_rgb=False,
+ use_state=False,
+ use_resampler=True,
+ **kwargs):
+ super(GR_MG, self).__init__()
+ self.state_dim = state_dim
+ self.act_dim = act_dim
+ self.max_length = max_length
+ self.hidden_size = hidden_size
+ config = transformers.GPT2Config(
+ vocab_size=1, # doesn't matter -- we don't use the vocab
+ n_embd=hidden_size,
+ **kwargs
+ )
+ # note: the difference between this GPT2Model and the default Huggingface version
+ # is that the positional embeddings are removed (since we'll add those ourselves)
+ # and wte is removed ( to set find_unused_parameters=False)
+ self.transformer = GPT2Model(config)
+ transformer_params = sum(p.numel() for p in self.transformer.parameters() if p.requires_grad)
+ dis_print(f"Transformer Parameters: {transformer_params / 1000000:.2f}M")
+
+ self.use_resampler = use_resampler
+ if self.use_resampler:
+ self.n_patch_latents = resampler_params['num_latents']
+ self.perceiver_resampler = PerceiverResampler(
+ dim=patch_feat_dim,
+ depth=resampler_params['depth'],
+ dim_head=resampler_params['dim_head'],
+ heads=resampler_params['heads'],
+ num_latents=self.n_patch_latents,
+ num_media_embeds=resampler_params['num_media_embeds'])
+
+ resampler_params = sum(p.numel() for p in self.perceiver_resampler.parameters() if p.requires_grad)
+ dis_print(f"Perceiver Resampler Parameters: {resampler_params / 1000000:.2f}M")
+
+ self.model_mae = model_mae
+
+ self.act_len = act_len
+ self.act_latent_dim = act_latent_dim
+ self.act_encoder_dim = act_encoder_dim
+ self.act_decoder_dim = act_decoder_dim
+ self.progress_decoder_dim=progress_decoder_dim
+
+ self.use_hand_rgb = use_hand_rgb
+ self.use_state = use_state
+
+ self.text_tokenizer=clip.tokenize
+ self.text_encoder=clip_model
+
+
+ self.lang_feat_dim=lang_feat_dim # hardcode
+ self.img_feat_dim = img_feat_dim
+ self.patch_feat_dim = patch_feat_dim
+ self.n_patches = 49 # TODO: hardcode
+ self.patch_size = 16 # TODO: hardcode
+ self.image_size = 224 # TODO: hardcode
+
+ self.act_pred = False
+ self.fwd_pred = False
+ self.fwd_pred_hand = False
+ self.progress_pred=False
+ if 'act_pred' in training_target:
+ self.act_pred = True
+ if 'fwd_pred' in training_target:
+ self.fwd_pred = True
+ if 'fwd_pred_hand' in training_target:
+ self.fwd_pred_hand = True
+ if 'progress_pred' in training_target:
+ self.progress_pred = True
+
+ self.without_norm_pix_loss = without_norm_pix_loss
+ if self.use_state:
+ # state embedding
+ self.embed_arm_state = torch.nn.Linear(self.state_dim-1, self.hidden_size)
+ self.embed_gripper_state = torch.nn.Linear(2, self.hidden_size) # one-hot gripper state
+ self.embed_state = torch.nn.Linear(2*self.hidden_size, self.hidden_size)
+
+
+ # Embedding function for languages
+ self.embed_lang = torch.nn.Linear(self.lang_feat_dim, self.hidden_size)
+ # relative timestep embedding
+ self.embed_timestep = nn.Embedding(self.max_length, self.hidden_size)
+
+ # image token embedding
+ if self.use_hand_rgb:
+ self.embed_hand_img = torch.nn.Linear(self.img_feat_dim, self.hidden_size)
+ self.embed_img = torch.nn.Linear(self.img_feat_dim, self.hidden_size)
+ self.embed_goal_image = torch.nn.Linear(self.img_feat_dim, self.hidden_size)
+
+ # patch token embedding
+ if self.use_hand_rgb:
+ self.embed_hand_patch = torch.nn.Linear(self.patch_feat_dim, self.hidden_size)
+ self.embed_patch = torch.nn.Linear(self.patch_feat_dim, self.hidden_size)
+ self.embed_goal_patch = torch.nn.Linear(self.patch_feat_dim, self.hidden_size)
+
+ # layer norm
+ self.embed_ln = nn.LayerNorm(self.hidden_size)
+ if self.act_pred:
+ # action query [ACT]
+ self.action_queries = nn.Embedding(1, self.hidden_size) # arm + gripper
+
+ # action encoder (embed action trajectory as style vector)
+ self.embed_arm_action = torch.nn.Linear(self.act_dim - 1, self.act_encoder_dim)
+ self.embed_gripper_action = torch.nn.Embedding(2, self.act_encoder_dim)
+ self.embed_action = nn.Linear(2 * self.act_encoder_dim, self.act_encoder_dim)
+ self.action_encoder_cls_token = torch.nn.Embedding(1, self.act_encoder_dim)
+ action_encoder_depth = 4
+ self.encode_action = nn.ModuleList([
+ Block(self.act_encoder_dim, 8, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm)
+ for i in range(action_encoder_depth)])
+ self.action_encoder_positional_embeddings = nn.Embedding(self.act_len + 1, self.act_encoder_dim)
+ self.pred_style_vector = nn.Linear(self.act_encoder_dim, 2 * self.act_latent_dim)
+ self.embed_style_vector = nn.Linear(self.act_latent_dim, self.act_decoder_dim)
+
+ # action decoder
+ self.proj_action_output_embed = nn.Linear(self.hidden_size, self.act_decoder_dim)
+ action_decoder_depth = 4
+ self.decode_action = nn.ModuleList([
+ Block(self.act_decoder_dim, 8, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm)
+ for i in range(action_decoder_depth)])
+ self.action_mask_token_embedding = nn.Embedding(1, self.act_decoder_dim)
+ self.action_decoder_positional_embeddings = nn.Embedding(self.act_len, self.act_decoder_dim)
+ self.pred_arm_act = nn.Linear(self.act_decoder_dim, self.act_dim - 1) # arm action
+ self.pred_gripper_act = nn.Linear(self.act_decoder_dim, 1) # gripper action (binary)
+
+ # predict future image
+ if self.fwd_pred:
+ # add observation query for fwd prediction
+ self.obs_queries = nn.Embedding(self.n_patch_latents+1, self.hidden_size) # cls+resampler
+ if self.use_hand_rgb:
+ self.obs_hand_queries = nn.Embedding(self.n_patch_latents+1, self.hidden_size) # cls+resampler
+ self.decoder_embed = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.hidden_size))
+ # fixed sin-cos embedding
+ self.decoder_pos_embed = nn.Parameter(torch.zeros(1, (self.image_size//self.patch_size)**2,
+ self.hidden_size), requires_grad=False) # (1, n_patch, h)
+
+ decoder_depth = 2 # hardcode
+ self.decoder_blocks = nn.ModuleList([
+ Block(self.hidden_size, 16, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm)
+ for i in range(decoder_depth)])
+
+ self.decoder_norm = nn.LayerNorm(self.hidden_size)
+ self.decoder_pred = nn.Linear(self.hidden_size, self.patch_size**2 * 3, bias=True) # decoder to patch
+
+ decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], (self.image_size//self.patch_size), cls_token=False)
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
+
+ fwd_params = sum(p.numel() for p in self.decoder_blocks.parameters() if p.requires_grad)
+ dis_print(f"Fwd Decoder Parameters: {fwd_params / 1000000:.2f}M")
+ if self.progress_pred:
+ self.progress_queries=nn.Embedding(1, self.hidden_size) # [PROG]
+ # progress decoder
+ self.proj_progress_output_embed = nn.Linear(self.hidden_size, self.progress_decoder_dim)
+ progress_decoder_depth = 2
+ self.decode_progress = nn.ModuleList([
+ Block(self.progress_decoder_dim, 8, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm)
+ for i in range(progress_decoder_depth)])
+ self.progress_mask_token_embedding = nn.Embedding(1, self.progress_decoder_dim)
+ self.pred_progress = nn.Linear(self.progress_decoder_dim, 1) # pred progress
+ self.sigmoid_progress=nn.Sigmoid()
+
+ def encode_texts(self, texts):
+ inputs = self.text_tokenizer(texts)
+ device = next(self.text_encoder.parameters()).device
+ with torch.no_grad():
+ encoder_hidden_states = self.text_encoder.encode_text(inputs.to(device))
+ return encoder_hidden_states
+
+ def forward(self, input_dict, is_training=True):
+ goal_rgb = input_dict['goal_rgb'] # (b, c, h, w)
+ rgb = input_dict['rgb'] # (b, l, c, h, w)
+ hand_rgb = input_dict['hand_rgb'] # (b, l, c, h, w)
+ attention_mask = input_dict['attention_mask'] # (b, l)
+ text=input_dict["text"][0]
+ progress_targets=input_dict["progress"] #(b,l,)
+
+ obs_preds = None
+ obs_hand_preds = None
+ obs_target = None
+ obs_hand_target = None
+ arm_action_preds = None
+ gripper_action_preds = None
+ action_mu_preds = None
+ action_logvar_preds = None
+ progress_preds=None
+
+ batch_size, seq_length, c, h, w = rgb.shape
+
+
+ if self.use_state:
+ arm_state = input_dict['arm_state'] # (b, l, state_dim - 1)
+ gripper_state = input_dict['gripper_state'] # (b, l, 2)
+ arm_state_embeddings = self.embed_arm_state(arm_state.view(batch_size, seq_length, self.state_dim-1)) # (b, l, h)
+ gripper_state_embeddings = self.embed_gripper_state(gripper_state) # (b, l, h)
+ state_embeddings = torch.cat((arm_state_embeddings, gripper_state_embeddings), dim=2) # (b, l, 2h)
+ state_embeddings = self.embed_state(state_embeddings) # (b, l, h)
+
+ # goal rgb mae feature
+ goal_obs_embeddings, goal_patch_embeddings = self.model_mae(goal_rgb) # (b, img_feat_dim), (b, 196, patch_feat_dim)
+
+ # rgb mae feature
+ obs_embeddings, patch_embeddings = self.model_mae(
+ rgb.view(batch_size*seq_length, c, h, w)) # (b * l, img_feat_dim), (b * l, 196, patch_feat_dim)
+ obs_embeddings = obs_embeddings.view(batch_size, seq_length, -1) # (b, l, img_feat_dim)
+
+ # hand rgb mae feature
+ if self.use_hand_rgb:
+ hand_obs_embeddings, hand_patch_embeddings = self.model_mae(
+ hand_rgb.view(batch_size*seq_length, c, h, w)) # (b * l, img_feat_dim), (b * l, 196, patch_feat_dim)
+ hand_obs_embeddings = hand_obs_embeddings.view(batch_size, seq_length, -1) # (b, l, img_feat_dim)
+
+ # compute obs target
+ if self.fwd_pred:
+ p = self.patch_size
+ h_p = h // p
+ w_p = w // p
+ rgb = rgb.reshape(shape=(batch_size, seq_length, 3, h_p, p, w_p, p)) # b,len,3,14,p,14,p
+ obs_target = rgb.permute(0, 1, 3, 5, 4, 6, 2) # b,len,14,14,p,p,3
+ obs_target = obs_target.reshape(shape=(batch_size, seq_length, h_p * w_p, (p**2) * 3)) # b,len,14x14,p*p*3
+ if not self.without_norm_pix_loss:
+ # norm the target
+ obs_target = (obs_target - obs_target.mean(dim=-1, keepdim=True)
+ ) / (obs_target.var(dim=-1, unbiased=True, keepdim=True).sqrt() + 1e-6)
+
+ if self.fwd_pred_hand:
+ hand_rgb = hand_rgb.reshape(shape=(batch_size, seq_length, 3, h_p, p, w_p, p)) # b,len,3,14,p,14,p
+ obs_hand_target = hand_rgb.permute(0, 1, 3, 5, 4, 6, 2) # b,len,14,14,p,p,3
+ obs_hand_target = obs_hand_target.reshape(shape=(batch_size, seq_length, h_p * w_p, (p**2)*3))
+ if not self.without_norm_pix_loss:
+ # norm the target
+ obs_hand_target = (obs_hand_target - obs_hand_target.mean(dim=-1, keepdim=True)
+ ) / (obs_hand_target.var(dim=-1, unbiased=True, keepdim=True).sqrt() + 1e-6)
+
+ if self.use_resampler:
+ goal_patch_embeddings = goal_patch_embeddings.unsqueeze(1) # (b, 1, 196, patch_feat_dim)
+ goal_patch_embeddings = self.perceiver_resampler(goal_patch_embeddings) # (b, 1, 9, patch_feat_dim)
+ goal_patch_embeddings = goal_patch_embeddings.squeeze(1) # (b, 9, patch_feat_dim)
+
+ patch_embeddings = patch_embeddings.unsqueeze(1) # (b * l, 1, 196, patch_feat_dim)
+ patch_embeddings = self.perceiver_resampler(patch_embeddings) # (b * l, 1, 9, patch_feat_dim)
+ patch_embeddings = patch_embeddings.squeeze(1) # (b * l, 9, patch_feat_dim)
+ patch_embeddings = patch_embeddings.view(batch_size, seq_length, self.n_patch_latents, self.patch_feat_dim) # (b, l, 9, patch_feat_dim)
+
+ if self.use_hand_rgb:
+ hand_patch_embeddings = hand_patch_embeddings.unsqueeze(1) # (b * l, 1, 196, patch_feat_dim)
+ hand_patch_embeddings = self.perceiver_resampler(hand_patch_embeddings) # (b * l, 1, 9, patch_feat_dim)
+ hand_patch_embeddings = hand_patch_embeddings.squeeze(1) # (b * l, 9, patch_feat_dim)
+ hand_patch_embeddings = hand_patch_embeddings.view(batch_size, seq_length, self.n_patch_latents, self.patch_feat_dim) # (b, l, 9, patch_feat_dim)
+ else:
+ raise NotImplementedError
+
+
+ # Embed language
+ lang_embeddings = self.encode_texts(text)
+ lang_embeddings = lang_embeddings / (lang_embeddings.norm(dim=1, keepdim=True) + 1e-6) # normalization
+ lang_embeddings = self.embed_lang(lang_embeddings.float()) # (b, h)
+
+ # embed images and patches
+ goal_obs_embeddings = self.embed_goal_image(goal_obs_embeddings) # (b, h)
+ goal_patch_embeddings = self.embed_goal_patch(goal_patch_embeddings) # (b, 9, h)
+ obs_embeddings = self.embed_img(obs_embeddings) # (b, l, h)
+ patch_embeddings = self.embed_patch(patch_embeddings) # (b, l, 9, h)
+ if self.use_hand_rgb:
+ hand_obs_embeddings = self.embed_hand_img(hand_obs_embeddings) # (b, l, h)
+ hand_patch_embeddings = self.embed_hand_patch(hand_patch_embeddings) # (b, l, 9, h)
+
+ # add timestep embeddings
+ time_embeddings = self.embed_timestep.weight # (l, h)
+
+ lang_embeddings = lang_embeddings.view(batch_size, 1, -1).repeat(1,seq_length,1) + time_embeddings# 注æ„debug
+ patch_embeddings = patch_embeddings + time_embeddings.view(seq_length, 1, self.hidden_size)
+ obs_embeddings = obs_embeddings + time_embeddings
+
+
+ if self.use_hand_rgb:
+ hand_obs_embeddings = hand_obs_embeddings + time_embeddings
+ hand_patch_embeddings = hand_patch_embeddings + time_embeddings.view(seq_length, 1, self.hidden_size)
+ if self.use_state:
+ state_embeddings = state_embeddings + time_embeddings
+
+ # Format sequence: lang, state, patch, obs, hand_patch, hand_obs, [ACT], [OBS], [OBS_HAND],[PROG]
+ lang_embeddings = lang_embeddings.view(batch_size, seq_length, 1, self.hidden_size)
+ obs_embeddings = obs_embeddings.view(batch_size, seq_length, 1, self.hidden_size)
+ if self.use_state:
+ state_embeddings = state_embeddings.view(batch_size, seq_length, 1, self.hidden_size)
+ stacked_inputs = torch.cat((lang_embeddings,state_embeddings, patch_embeddings, obs_embeddings), dim=2) # (b, l, n_tokens, h)
+ else:
+ stacked_inputs = torch.cat((lang_embeddings,patch_embeddings, obs_embeddings), dim=2) # (b, l, n_tokens, h)
+ if self.use_hand_rgb:
+ hand_obs_embeddings = hand_obs_embeddings.view(batch_size, seq_length, 1, self.hidden_size)
+ stacked_inputs = torch.cat((stacked_inputs, hand_patch_embeddings, hand_obs_embeddings), dim=2) # (b, l, n_tokens, h)
+ if self.act_pred:
+ action_queries = self.action_queries.weight # (1, h)
+ action_queries = action_queries.view(1, 1, 1, self.hidden_size).repeat(batch_size, seq_length, 1, 1) # (b, l, 1, h)
+ stacked_inputs = torch.cat((stacked_inputs, action_queries), dim=2) # (b, l, n_tokens, h)
+ if self.fwd_pred:
+ obs_queries = self.obs_queries.weight # (10, h)
+ obs_queries = obs_queries.view(1, 1, self.n_patch_latents + 1, self.hidden_size).repeat(batch_size, seq_length, 1, 1)
+ stacked_inputs = torch.cat((stacked_inputs, obs_queries), dim=2)
+ if self.fwd_pred_hand:
+ obs_hand_queries = self.obs_hand_queries.weight # (10, h)
+ obs_hand_queries = obs_hand_queries.view(1, 1, self.n_patch_latents + 1, self.hidden_size).repeat(batch_size, seq_length, 1, 1)
+ stacked_inputs = torch.cat((stacked_inputs, obs_hand_queries), dim=2) # (b, l, n_tokens, h)
+ if self.progress_pred:
+ progress_queries = self.progress_queries.weight # (1, h)
+ progress_queries = progress_queries.view(1, 1, 1, self.hidden_size).repeat(batch_size, seq_length, 1, 1) # (b, l, 1, h)
+ stacked_inputs = torch.cat((stacked_inputs, progress_queries), dim=2) # (b, l, n_tokens, h)
+ # number of tokens for different modalities
+ n_lang_tokens = 1
+ n_state_tokens = 1
+ n_patch_tokens = self.n_patch_latents
+ n_obs_tokens = 1
+ n_hand_patch_tokens = self.n_patch_latents
+ n_hand_obs_tokens = 1
+ n_act_pred_tokens = 1
+ n_fwd_pred_tokens = n_patch_tokens + n_obs_tokens
+ n_fwd_pred_hand_tokens = n_patch_tokens + n_obs_tokens
+ n_progress_pred_tokens=1
+ # compute number of tokens (does not include the conditioned goal image tokens)
+ n_tokens = n_lang_tokens
+ if self.use_state:
+ n_tokens += n_state_tokens
+ n_tokens += n_patch_tokens
+ n_tokens += n_obs_tokens
+ if self.use_hand_rgb:
+ n_tokens += n_hand_obs_tokens
+ n_tokens += n_hand_patch_tokens
+ if self.act_pred:
+ act_pred_token_i = n_tokens
+ n_tokens += n_act_pred_tokens
+ if self.fwd_pred:
+ obs_pred_token_i = n_tokens
+ n_tokens += n_fwd_pred_tokens
+ if self.fwd_pred_hand:
+ obs_pred_hand_token_i = n_tokens
+ n_tokens += n_fwd_pred_hand_tokens
+ if self.progress_pred:
+ progress_pred_token_i = n_tokens
+ n_tokens += n_progress_pred_tokens
+ # number of condtioned tokens (goal image)
+ n_condtioned_tokens = 1 + self.n_patch_latents
+
+ # add goal image conditions at the front
+ stacked_inputs = stacked_inputs.reshape(batch_size, n_tokens * seq_length, self.hidden_size)
+ goal_obs_embeddings = goal_obs_embeddings.view(batch_size, 1, self.hidden_size)
+ stacked_inputs = torch.cat((goal_patch_embeddings, goal_obs_embeddings, stacked_inputs), dim=1) # (b, l * n_tokens + n_patch_latents + 1, h)
+ assert stacked_inputs.shape == (batch_size, seq_length * n_tokens + n_condtioned_tokens, self.hidden_size)
+
+ # layer norm
+ stacked_inputs = self.embed_ln(stacked_inputs)
+
+ # generate attention mask
+ attn_mask = attention_mask.view(batch_size, 1, seq_length)
+ lang_attn_mask=attn_mask.repeat(1, n_lang_tokens, 1)
+ state_attn_mask = attn_mask.repeat(1, n_state_tokens, 1)
+ patch_attn_mask = attn_mask.repeat(1, n_patch_tokens, 1)
+ obs_attn_mask = attn_mask.repeat(1, n_obs_tokens, 1)
+ hand_patch_attn_mask = attn_mask.repeat(1, n_hand_patch_tokens, 1)
+ hand_obs_attn_mask = attn_mask.repeat(1, n_hand_obs_tokens, 1)
+
+ if self.use_state:
+ stacked_attn_mask = torch.cat((lang_attn_mask,state_attn_mask, patch_attn_mask, obs_attn_mask), dim=1)
+ else:
+ stacked_attn_mask = torch.cat((lang_attn_mask,patch_attn_mask, obs_attn_mask), dim=1)
+ if self.use_hand_rgb:
+ stacked_attn_mask = torch.cat((stacked_attn_mask, hand_patch_attn_mask, hand_obs_attn_mask), dim=1)
+ if self.act_pred:
+ act_pred_attn_mask = torch.zeros((batch_size, n_act_pred_tokens, seq_length), dtype=torch.long).cuda()
+ stacked_attn_mask = torch.cat((stacked_attn_mask, act_pred_attn_mask), dim=1)
+ if self.fwd_pred:
+ fwd_pred_attn_mask = torch.zeros((batch_size, n_fwd_pred_tokens, seq_length), dtype=torch.long).cuda()
+ stacked_attn_mask = torch.cat((stacked_attn_mask, fwd_pred_attn_mask), dim=1)
+ if self.fwd_pred_hand:
+ fwd_pred_hand_attn_mask = torch.zeros((batch_size, n_fwd_pred_hand_tokens, seq_length), dtype=torch.long).cuda()
+ stacked_attn_mask = torch.cat((stacked_attn_mask, fwd_pred_hand_attn_mask), dim=1)
+ if self.progress_pred:
+ progress_pred_attn_mask = torch.zeros((batch_size, n_progress_pred_tokens, seq_length), dtype=torch.long).cuda()
+ stacked_attn_mask = torch.cat((stacked_attn_mask, progress_pred_attn_mask), dim=1)
+
+ stacked_attn_mask = stacked_attn_mask.permute(0, 2, 1) # (b, l, n_tokens)
+ stacked_attn_mask = stacked_attn_mask.reshape(batch_size, n_tokens * seq_length) # (b, l * n_tokens)
+ goal_obs_attn_mask = torch.ones((batch_size, 1), dtype=torch.long).cuda()
+ goal_patch_attn_mask = torch.ones((batch_size, self.n_patch_latents), dtype=torch.long).cuda()
+ stacked_attn_mask = torch.cat((goal_patch_attn_mask, goal_obs_attn_mask, stacked_attn_mask), dim=1) # (b, l * n_tokens + n_patch_latens + 1)
+ assert stacked_attn_mask.shape == (batch_size, seq_length * n_tokens + n_condtioned_tokens)
+
+ # we feed in the input embeddings (not word indices as in NLP) to the model
+ transformer_outputs = self.transformer(
+ inputs_embeds=stacked_inputs,
+ attention_mask=stacked_attn_mask,
+ )
+ x = transformer_outputs['last_hidden_state']
+ x = x[:, n_condtioned_tokens:]
+ x = x.reshape(batch_size, seq_length, n_tokens, self.hidden_size) # (b, l, n_tokens, h)
+
+ # action prediction: predict next action given obs
+ # format sequence
+ if self.act_pred:
+ action_output_embedding = x[:, :, act_pred_token_i] # (b, l, h)
+
+ # encode action
+ arm_action = input_dict['arm_action'] # b,len,act_len,act_dim-1
+ gripper_action = input_dict['gripper_action'].long() # b,len,act_len
+ arm_action_embeddings = self.embed_arm_action(arm_action) # b,len,act_len,act_encoder_dim
+ gripper_action_embeddings = self.embed_gripper_action(gripper_action) # b,len,act_len,act_encoder_dim
+ action_embeddings = torch.cat((arm_action_embeddings, gripper_action_embeddings), dim=-1) # b,len,act_len,2*act_encoder_dim
+ action_embeddings = self.embed_action(action_embeddings) # b,len,act_len,act_encoder_dim
+ cls_token_embeddings = self.action_encoder_cls_token.weight # 1,act_encoder_dim
+ cls_token_embeddings = cls_token_embeddings.unsqueeze(0).unsqueeze(0).repeat(batch_size, seq_length, 1, 1) # b,len,1,act_encoder_dim
+ z = torch.cat((cls_token_embeddings, action_embeddings), dim=2) # b,len,1+act_len,act_encoder_dim
+ action_encoder_positional_embeddings = self.action_encoder_positional_embeddings.weight # 1+act_len,act_encoder_dim
+ z = z + action_encoder_positional_embeddings # b,len,1+act_len,act_encoder_dim
+ z = z.reshape(batch_size * seq_length, self.act_len + 1, self.act_encoder_dim) # b*len,1+1+act_len,act_encoder_dim
+ for blk in self.encode_action:
+ z = blk(z)
+ action_latent_embedding = z[:, 0] # b*len,act_encoder_dim
+ action_latent_embedding = action_latent_embedding.reshape(batch_size, seq_length, self.act_encoder_dim) # b,len,act_encoder_dim
+ action_latent_preds = self.pred_style_vector(action_latent_embedding) # b,len,2*act_latent_dim
+ action_mu_preds = action_latent_preds[:, :, :self.act_latent_dim] # b,len,act_latent_dim
+ action_logvar_preds = action_latent_preds[:, :, self.act_latent_dim:] # b,len,act_latent_dim
+ # sample style vector
+ action_mu_preds = action_mu_preds.view(-1, self.act_latent_dim)
+ action_logvar_preds = action_logvar_preds.view(-1, self.act_latent_dim)
+ action_style_vector = reparameterize(action_mu_preds, action_logvar_preds) # b*len,act_latent_dim
+ action_style_vector = action_style_vector.view(batch_size, seq_length, self.act_latent_dim) # b,len,act_latent_dim
+ if not is_training: # we set the mean=0 and var=1 during inference
+ action_style_vector = torch.zeros([batch_size, seq_length, self.act_latent_dim], dtype=torch.float32).to(rgb.device)
+ action_style_vector = action_style_vector.type_as(arm_action)
+ action_style_embeddings = self.embed_style_vector(action_style_vector) # b,len,act_decoder_dimv
+
+
+ action_output_embedding = self.proj_action_output_embed(action_output_embedding) # (b, l, act_decoder_dim)
+ action_mask_token = self.action_mask_token_embedding.weight # (1, act_decoder_dim)
+ action_mask_token = action_mask_token.unsqueeze(0).unsqueeze(0).repeat(batch_size, seq_length, self.act_len, 1) # (b, l, act_len, act_decoder_dim)
+ action_output_embedding = action_output_embedding.view(batch_size, seq_length, 1, self.act_decoder_dim)
+ action_decoder_positional_embeddings = self.action_decoder_positional_embeddings.weight
+
+ action_style_embeddings = action_style_embeddings.view(batch_size, seq_length, 1, self.act_decoder_dim)
+ y = torch.cat((action_style_embeddings, action_output_embedding, action_mask_token), dim=2) # (b, l, 2 + act_len, act_decoder_dim)
+ y[:, :, 2:] = y[:, :, 2:] + action_decoder_positional_embeddings
+ y = y.reshape(batch_size * seq_length, 2 + self.act_len, self.act_decoder_dim)
+
+
+ # forward transformer
+ for blk in self.decode_action:
+ y = blk(y)
+
+
+ action_decoder_output_embeddings = y[:, 2:] # (b * l, act_len, act_decoder_dim)
+
+
+
+ action_decoder_output_embeddings = action_decoder_output_embeddings.reshape(batch_size, seq_length, self.act_len, self.act_decoder_dim)
+ arm_action_preds = self.pred_arm_act(action_decoder_output_embeddings) # (b, l, act_len, act_dim - 1)
+ gripper_action_preds = self.pred_gripper_act(action_decoder_output_embeddings) # (b, l, act_len, 1)
+
+ # forward prediction: predict next obs
+ if self.fwd_pred:
+ mask_token = self.mask_token # (1, 1, 1, h)
+ mask_tokens = mask_token.repeat(batch_size, seq_length, (self.image_size//self.patch_size)**2, 1) # (b, l, n_patches, h)
+ mask_tokens = mask_tokens + self.decoder_pos_embed.unsqueeze(0).repeat(batch_size, seq_length, 1, 1) # (b, l, n_patch, h)
+
+ obs_pred = self.decoder_embed(x[:, :, obs_pred_token_i:(obs_pred_token_i + n_fwd_pred_tokens)]) # (b, l, n_patch_latents + 1, h)
+ obs_pred_ = torch.cat([obs_pred, mask_tokens], dim=2) # (b, l, n_patch + n_patch_latens + 1, h)
+ obs_pred_ = obs_pred_.reshape(-1, obs_pred_.shape[-2], obs_pred_.shape[-1]) # (b * l, n_patch + n_patch_latens + 1, h)
+ for blk in self.decoder_blocks:
+ obs_pred_ = blk(obs_pred_)
+ obs_pred_ = self.decoder_norm(obs_pred_)
+ obs_preds = self.decoder_pred(obs_pred_) # (b * l, n_patch + n_patch_latens + 1, h)
+ obs_preds = obs_preds.reshape(batch_size, seq_length, -1, obs_preds.shape[-1]) # (b, l, n_patch + n_patch_latens + 1, h)
+ obs_preds = obs_preds[:, :, n_fwd_pred_tokens:] # (b, len, n_patch, h)
+
+ if self.fwd_pred_hand:
+ obs_pred_hand = self.decoder_embed(x[:, :, obs_pred_hand_token_i:(obs_pred_hand_token_i + n_fwd_pred_hand_tokens)]) # (b, l, n_patch_latents + 1, h)
+ obs_pred_hand_ = torch.cat([obs_pred_hand, mask_tokens], dim=2) # (b, l, n_patch + n_patch_latens + 1, h)
+ obs_pred_hand_ = obs_pred_hand_.reshape(-1, obs_pred_hand_.shape[-2], obs_pred_hand_.shape[-1]) # (b * l, n_patch + n_patch_latens + 1, h)
+ for blk in self.decoder_blocks:
+ obs_pred_hand_ = blk(obs_pred_hand_)
+ obs_pred_hand_ = self.decoder_norm(obs_pred_hand_)
+ obs_hand_preds = self.decoder_pred(obs_pred_hand_) # (b * l, n_patch + n_patch_latens + 1, h)
+ obs_hand_preds = obs_hand_preds.reshape(batch_size, seq_length, -1, obs_hand_preds.shape[-1]) # (b, l, n_patch + n_patch_latens + 1, h)
+ obs_hand_preds = obs_hand_preds[:, :, n_fwd_pred_hand_tokens:] # (b, l, n_patch, h)
+
+ # progress prediction
+ # format sequence
+ if self.progress_pred:
+ progress_output_embedding = x[:, :, progress_pred_token_i] # (b, l, h)
+ progress_output_embedding = self.proj_progress_output_embed(progress_output_embedding) # (b, l, progress_decoder_dim)
+ progress_mask_token = self.progress_mask_token_embedding.weight # (1, progress_decoder_dim)
+ progress_mask_token = progress_mask_token.unsqueeze(0).unsqueeze(0).repeat(batch_size, seq_length, 1, 1) # (b, l, 1, progress_decoder_dim)
+ progress_output_embedding = progress_output_embedding.view(batch_size, seq_length, 1, self.progress_decoder_dim)
+ y = torch.cat((progress_output_embedding, progress_mask_token), dim=2) # (b, l, 1 + 1, progress_decoder_dim)
+ y = y.reshape(batch_size * seq_length, 1 + 1, self.act_decoder_dim)
+
+ # forward transformer
+ for blk in self.decode_progress:
+ y = blk(y)
+
+ # get output
+ progress_decoder_output_embeddings = y[:, 1:] # (b * l, 1, progress_decoder_dim)
+ progress_decoder_output_embeddings = progress_decoder_output_embeddings.reshape(batch_size, seq_length, 1, self.progress_decoder_dim)
+ progress_preds = self.pred_progress(progress_decoder_output_embeddings).squeeze() # (b, l)
+ progress_preds=self.sigmoid_progress(progress_preds)
+
+
+
+
+ prediction = {
+ 'obs_preds': obs_preds,
+ 'obs_target': obs_target,
+ 'obs_hand_preds': obs_hand_preds,
+ 'obs_hand_target': obs_hand_target,
+ 'arm_action_preds': arm_action_preds,
+ 'gripper_action_preds': gripper_action_preds,
+ 'action_mu_preds': action_mu_preds,
+ 'action_logvar_preds': action_logvar_preds,
+ 'progress_preds':progress_preds,
+ 'progress_targets':progress_targets,
+ }
+ return prediction
+
+
+    def evaluate(self, input_dict,original_gripper=False,return_progress=False):
+        """Run inference and return the action prediction for the last valid timestep.
+
+        Args:
+            input_dict: model inputs; must contain 'attention_mask' plus whatever
+                ``self.forward`` consumes.
+            original_gripper: if True, keep the continuous sigmoid gripper value;
+                otherwise binarize it at 0.5 and map to {-1.0, +1.0}.
+            return_progress: if True, also return the last progress prediction.
+
+        Returns:
+            action_pred: tensor of shape (act_dim,) — arm action concatenated
+            with the gripper action.  With ``return_progress`` a
+            (action_pred, progress) tuple is returned instead.
+        """
+        attention_mask = input_dict['attention_mask']
+        prediction = self.forward(input_dict, is_training=False)
+
+        arm_action_preds = prediction['arm_action_preds'] # (1, len, act_len, act_dim-1)
+        gripper_action_preds = prediction['gripper_action_preds'] # (1, len, act_len, 1)
+
+        # Drop the batch dimension and keep only timesteps marked valid by the mask.
+        # NOTE(review): .squeeze() removes ALL singleton dims — assumes batch size 1; confirm.
+        arm_action_preds = arm_action_preds.squeeze(0) # (len, act_len, act_dim-1)
+        gripper_action_preds = gripper_action_preds.squeeze() # (len, act_len)
+        arm_action_preds = arm_action_preds[attention_mask.flatten() > 0]
+        gripper_action_preds = gripper_action_preds[attention_mask.flatten() > 0]
+
+        # Take the last action
+        # Last valid timestep, first step of the predicted action chunk.
+        arm_action_pred = arm_action_preds[-1].cpu() # (act_len, act_dim-1)
+        arm_action_pred = arm_action_pred[0] # (act_dim-1, )
+        gripper_action_pred = gripper_action_preds[-1:].cpu() # (1, act_len)
+        gripper_action_pred = gripper_action_pred[:, 0] # (1, 1)
+
+        if original_gripper:
+            # Keep the raw probability in [0, 1].
+            gripper_action_pred = torch.nn.Sigmoid()(gripper_action_pred)
+        else:
+            # Binarize at 0.5, then map {0, 1} -> {-1, +1}.
+            gripper_action_pred = torch.nn.Sigmoid()(gripper_action_pred)
+            gripper_action_pred = gripper_action_pred > 0.5
+            gripper_action_pred = gripper_action_pred.int().float()
+            gripper_action_pred = gripper_action_pred * 2.0 - 1.0
+        action_pred = torch.cat((arm_action_pred, gripper_action_pred), dim=0) # (act_dim,)
+        if return_progress:
+            progress=prediction["progress_preds"]
+            progress=progress[-1]
+            return action_pred,progress
+        else:
+            return action_pred
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/model/vision_transformer.py b/ACT_DP_multitask/detr/models/mr_mg/policy/model/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fdd05befd1ce515fc8d7b468be798189c2e46f7
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/policy/model/vision_transformer.py
@@ -0,0 +1,381 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Mostly copy-paste from timm library.
+https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+"""
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+import numpy as np
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+# from utils import trunc_normal_
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    # type: (Tensor, float, float, float, float) -> Tensor
+    """Fill *tensor* in place with N(mean, std) samples truncated to [a, b]."""
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x, attn
+
+
+class Block(nn.Module):
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x, return_attention=False):
+ y, attn = self.attn(self.norm1(x))
+ if return_attention:
+ return attn
+ x = x + self.drop_path(y)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer """
+ def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
+ super().__init__()
+ self.num_features = self.embed_dim = embed_dim
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+ for i in range(depth)])
+ self.norm = norm_layer(embed_dim)
+
+ # Classifier head
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ class_pos_embed = self.pos_embed[:, 0]
+ patch_pos_embed = self.pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_embed.patch_size
+ h0 = h // self.patch_embed.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + 0.1, h0 + 0.1
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+ mode='bicubic',
+ )
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+ def prepare_tokens(self, x):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x) # patch linear embedding
+
+ # add the [CLS] token to the embed patch tokens
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ # add positional encoding to each token
+ pos = self.interpolate_pos_encoding(x, w, h)
+ x = x + pos
+
+ return self.pos_drop(x), pos
+
+ def forward(self, x):
+ x, pos = self.prepare_tokens(x)
+ for blk in self.blocks:
+ x = blk(x)
+ x = self.norm(x)
+ return x[:, 0], x[:, 1:], pos
+
+ def get_last_selfattention(self, x):
+ x = self.prepare_tokens(x)
+ for i, blk in enumerate(self.blocks):
+ if i < len(self.blocks) - 1:
+ x = blk(x)
+ else:
+ # return attention of the last block
+ return blk(x, return_attention=True)
+
+ def get_intermediate_layers(self, x, n=1):
+ x = self.prepare_tokens(x)
+ # we return the output tokens from the `n` last blocks
+ output = []
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if len(self.blocks) - i <= n:
+ output.append(self.norm(x))
+ return output
+
+
+def vit_tiny(patch_size=16, **kwargs):
+    """ViT-Tiny factory: embed_dim 192, depth 12, 3 heads; extra kwargs are
+    forwarded to VisionTransformer."""
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_small(patch_size=16, **kwargs):
+    """ViT-Small factory: embed_dim 384, depth 12, 6 heads; extra kwargs are
+    forwarded to VisionTransformer."""
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_base(patch_size=16, **kwargs):
+    """ViT-Base factory: embed_dim 768, depth 12, 12 heads; extra kwargs are
+    forwarded to VisionTransformer."""
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+class DINOHead(nn.Module):
+    """DINO projection head: MLP down to a bottleneck, L2-normalize, then a
+    weight-normalized, bias-free linear layer onto the output (prototype) space.
+    """
+    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
+        super().__init__()
+        # At least one layer; nlayers == 1 degenerates to a single Linear.
+        nlayers = max(nlayers, 1)
+        if nlayers == 1:
+            self.mlp = nn.Linear(in_dim, bottleneck_dim)
+        else:
+            # in_dim -> hidden_dim -> ... -> bottleneck_dim, optionally with
+            # BatchNorm after each hidden Linear.
+            layers = [nn.Linear(in_dim, hidden_dim)]
+            if use_bn:
+                layers.append(nn.BatchNorm1d(hidden_dim))
+            layers.append(nn.GELU())
+            for _ in range(nlayers - 2):
+                layers.append(nn.Linear(hidden_dim, hidden_dim))
+                if use_bn:
+                    layers.append(nn.BatchNorm1d(hidden_dim))
+                layers.append(nn.GELU())
+            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
+            self.mlp = nn.Sequential(*layers)
+        # NOTE: apply() runs before last_layer is created, so last_layer keeps
+        # the weight_norm initialization below rather than trunc_normal_.
+        self.apply(self._init_weights)
+        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+        # Fix the magnitude component at 1; with norm_last_layer it is frozen
+        # so only the direction of the prototypes is learned.
+        self.last_layer.weight_g.data.fill_(1)
+        if norm_last_layer:
+            self.last_layer.weight_g.requires_grad = False
+
+    def _init_weights(self, m):
+        # Truncated-normal weights and zero biases for every Linear in the MLP.
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        """Project to the bottleneck, L2-normalize, then map to out_dim logits."""
+        x = self.mlp(x)
+        x = nn.functional.normalize(x, dim=-1, p=2)
+        x = self.last_layer(x)
+        return x
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/training/trainer.py b/ACT_DP_multitask/detr/models/mr_mg/policy/training/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..369cdbec815e8d121e1cfbfb88a10087e27ddaea
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/policy/training/trainer.py
@@ -0,0 +1,506 @@
+# Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import torch
+import torch.nn.functional as F
+from functools import partial
+import math
+import lightning.pytorch as pl
+from utils.dist_train import get_rank
+import model.vision_transformer as vits
+import json
+from model.model import GR_MG
+import clip
def adjust_learning_rate(iter, configs):
    """Decay the learning rate with half-cycle cosine after linear warmup.

    Args:
        iter: Current optimizer step (0-based), supplied by ``LambdaLR``.
        configs: Dict with ``warmup_iters``, ``iters`` (total steps) and
            ``min_lr_scale`` (multiplier floor at the end of training).

    Returns:
        A multiplier in [0, 1]: ramps linearly from 0 to 1 over the warmup,
        then follows a cosine from 1 down to ``min_lr_scale``.
    """
    warmup_iters = configs['warmup_iters']
    total_iters = configs['iters']
    min_lr_scale = configs['min_lr_scale']

    if iter < warmup_iters:
        # Linear warmup; this branch can only run when warmup_iters > 0.
        return 1.0 * iter / warmup_iters

    decay_span = total_iters - warmup_iters
    if decay_span <= 0:
        # Degenerate schedule (no decay phase): the original code divided by
        # zero here. Treat it as already at the end of the schedule.
        return min_lr_scale
    # Clamp progress to 1 so training past ``iters`` stays at the minimum
    # scale instead of the cosine cycling back up (cos beyond pi increases).
    progress = min((iter - warmup_iters) / decay_span, 1.0)
    return min_lr_scale + (1.0 - min_lr_scale) * 0.5 * (1.0 + math.cos(math.pi * progress))
+
def compute_kl_divergence( mu, logvar):
    """KL divergence between N(mu, exp(logvar)) and the standard normal.

    Inputs are flattened to (N, latent_dim) and the per-sample KL terms are
    summed over the latent dimension, then averaged over samples. The math
    is done in float32 for stability and the result is cast back to the
    input dtype (e.g. bf16).

    Raises:
        ValueError: if either input contains NaN or infinite values.
    """
    if torch.isnan(mu).any() or torch.isnan(logvar).any():
        raise ValueError("Input contains NaN values")
    if torch.isinf(mu).any() or torch.isinf(logvar).any():
        raise ValueError("Input contains infinite values")

    orig_dtype = mu.dtype
    latent_dim = mu.shape[-1]
    # Convert mu and logvar to float32 and flatten to (N, latent_dim).
    mu32 = mu.float().view(-1, latent_dim)
    logvar32 = logvar.float().view(-1, latent_dim)

    # Per-dimension KL term, summed over latents and averaged over samples.
    kld = -0.5 * (1 + logvar32 - mu32.pow(2) - logvar32.exp())
    kld = kld.sum(1).mean(0, True).squeeze()

    # Transform back to the caller's dtype (e.g. bf16).
    return kld.to(orig_dtype)
+
+
class Policy_Trainer(pl.LightningModule):
    """LightningModule that trains the GR_MG policy.

    Combines a frozen MAE image encoder and a frozen CLIP text encoder with
    the trainable GR_MG transformer. Optimizes a weighted sum of losses
    depending on which targets are enabled in ``configs['trainer']``:
    action prediction (arm smooth-L1 + gripper BCE + KL on the action
    latent), future static/hand image reconstruction, and task-progress
    regression.
    """
    def __init__(self, configs):
        """Store ``configs``, build the model, and save hyperparameters."""
        super().__init__()
        self._main_rank_print('--------------- model configs ---------------')
        self._main_rank_print(configs)
        self.configs = configs
        self._initialize()
        self.save_hyperparameters()
        self.val_set_names = "calvin"

    @staticmethod
    def _main_rank_print(*args, **kwargs):
        """Print only on the global rank-0 process."""
        if get_rank() == 0:
            print(*args, **kwargs)

    @property
    def num_gpus(self):
        """Total number of training devices across all nodes."""
        return self.trainer.num_devices * self.trainer.num_nodes

    def _initialize(self):
        """Build GR_MG with frozen MAE/CLIP encoders; optionally load a
        pretrained checkpoint and dump the config to ``ckpt_dir``."""
        training_target = []
        if self.configs["trainer"]['act_pred']:
            training_target.append('act_pred')
        if self.configs["trainer"]['fwd_pred']: # predict future static image
            training_target.append('fwd_pred')
        if self.configs["trainer"]['fwd_pred_hand']: # predict future hand image
            training_target.append('fwd_pred_hand')
        if self.configs["trainer"]['progress_pred']: # predict progress information
            training_target.append('progress_pred')

        # mae model
        model_mae = vits.__dict__['vit_base'](patch_size=16, num_classes=0)
        model_mae.to(self.device)
        # NOTE(review): hard-coded checkpoint path — must exist on the machine.
        mae_ckpt = '/PATH_TO/resources/MAE/mae_pretrain_vit_base.pth'
        checkpoint = torch.load(mae_ckpt, map_location='cpu')
        model_mae.load_state_dict(checkpoint['model'], strict=True)
        # freeze mae
        for name, p in model_mae.named_parameters():
            p.requires_grad = False
        #clip model
        clip_name = "ViT-B/32"
        clip_model, clip_preprocess = clip.load(clip_name)
        # freeze clip
        for _, param in clip_model.named_parameters():
            param.requires_grad = False

        # resampler parameters
        resampler_params = dict()
        resampler_params['depth'] = self.configs["policy"]['resampler_params']["depth"]
        resampler_params['dim_head'] = self.configs["policy"]['resampler_params']['dim_head']
        resampler_params['heads'] = self.configs["policy"]['resampler_params']['heads']
        resampler_params['num_latents'] = self.configs["policy"]['resampler_params']['num_latents']
        resampler_params['num_media_embeds'] = self.configs["policy"]['resampler_params']['num_media_embeds']

        # main model
        self.model = GR_MG(
            state_dim=self.configs["input"]['state_dim'],
            act_dim=self.configs["input"]['act_dim'],
            act_len=self.configs["policy"]['act_len'],
            act_latent_dim=self.configs["policy"]['act_latent_dim'],
            act_encoder_dim=self.configs["policy"]['act_encoder_dim'],
            act_decoder_dim=self.configs["policy"]['act_decoder_dim'],
            progress_decoder_dim=self.configs["policy"]["progress_decoder_dim"],
            hidden_size=self.configs["policy"]['embed_dim'],
            model_mae=model_mae,
            clip_model=clip_model,
            img_feat_dim=self.configs["policy"]["img_feat_dim"],
            lang_feat_dim = self.configs["policy"]["lang_feat_dim"],
            patch_feat_dim=self.configs["policy"]["patch_feat_dim"],
            resampler_params=resampler_params,
            max_length=self.configs["policy"]['seq_len'],
            training_target=training_target,
            without_norm_pix_loss=self.configs["trainer"]['without_norm_pix_loss'],
            use_hand_rgb=self.configs["input"]['use_hand_rgb'],
            use_state=self.configs["input"]['use_state'],
            use_resampler=self.configs["policy"]['use_resampler'],
            n_layer=self.configs["policy"]['n_layer'],
            n_head=self.configs["policy"]['n_head'],
            n_inner=4*self.configs["policy"]['embed_dim'],
            activation_function=self.configs["policy"]['activation_function'],
            n_positions=1024,
            resid_pdrop=self.configs["policy"]['dropout'],
            attn_pdrop=self.configs["policy"]['dropout'])


        # if finetune, we need to load the pretrained model
        if self.configs["trainer"]["finetune"]:
            if self.configs["trainer"]["use_pretrain"]:
                trainer_config=self.configs["trainer"]
                self._main_rank_print(f"Loading pretrained model from: {trainer_config['pretrained_model_path']}")
                checkpoint = torch.load(self.configs["trainer"]['pretrained_model_path'], map_location='cpu')
                state_dict = dict()
                # Exclude action and state related weights
                for key, value in checkpoint['state_dict'].items():
                    if key[:6]=="model.":
                        key = key[6:] # remove "model." from pl checkpoint
                    state_dict[key] = value
                del checkpoint
                msg = self.model.load_state_dict(state_dict, strict=False)
                self._main_rank_print(msg)
                del state_dict

        # save config
        if get_rank() == 0:
            with open(os.path.join(self.configs['ckpt_dir'], 'hyperparameters.json'), 'w') as f:
                json.dump(self.configs, f)


        # these variables are used to indicate what information will be used or predicted
        self.act_pred = self.model.act_pred
        self.fwd_pred = self.model.fwd_pred
        self.fwd_pred_hand = self.model.fwd_pred_hand
        self.use_state = self.model.use_state
        self.use_hand_rgb = self.model.use_hand_rgb
        self.progress_pred=self.model.progress_pred

        # loss weighting ratios from the trainer config
        self.kl_loss_ratio =self.configs["trainer"]["kl_loss_ratio"]
        self.gripper_loss_ratio =self.configs["trainer"]["gripper_loss_ratio"]
        self.fwd_loss_ratio =self.configs["trainer"]["fwd_loss_ratio"]
        self.progress_loss_ratio=self.configs["trainer"]["progress_loss_ratio"]



    def configure_optimizers(self):
        """AdamW with a per-step linear-warmup + cosine-decay LambdaLR."""
        lr = self.configs["trainer"]['learning_rate']
        eff_bsz = self.configs["trainer"]['batch_size'] * self.num_gpus
        self._main_rank_print('-' * 40)
        self._main_rank_print("LR SCHEDULER CONFIGS:")
        self._main_rank_print(f"learning rate: {lr}, effective batch size: {eff_bsz}")

        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=lr,
            betas=(self.configs["trainer"]['betas_1'], self.configs["trainer"]['betas_2']),
            weight_decay=self.configs["trainer"]['weight_decay']
        )
        assert self.trainer.max_epochs is not None
        num_training_batches = self.trainer.estimated_stepping_batches
        iter_per_epoch = num_training_batches / self.trainer.max_epochs
        # Convert warmup from epochs to optimizer steps for the scheduler.
        self.configs["trainer"]['warmup_steps']=self.configs["trainer"]['warmup_epochs'] * iter_per_epoch
        lr_scheduler_configs = {
            'warmup_iters': self.configs["trainer"]['warmup_steps'],
            'iters': self.trainer.max_epochs * iter_per_epoch,
            'min_lr_scale': self.configs["trainer"]['min_learning_rate_scale']
        }
        lr_lambda = partial(adjust_learning_rate, configs=lr_scheduler_configs)
        self._main_rank_print(lr_scheduler_configs)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        return {
            'optimizer': optimizer,
            'lr_scheduler':
            {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }


    def _log_output(self, output, phase, dataset=None, **kwargs):
        """Log every entry of ``output`` as ``[{dataset}_]{phase}_{key}``."""
        for k, v in output.items():
            log_name = f"{phase}_{k}"
            if dataset is not None:
                log_name = f"{dataset}_{log_name}"
            self.log(log_name, v, prog_bar=True, **kwargs)

    def validation_step(self, batch, batch_idx):
        """Compute, log (per-epoch), and return the validation loss for one batch."""
        with torch.no_grad():
            goal_rgb = batch['goal_rgb']
            rgb = batch['rgb']
            hand_rgb = batch['hand_rgb']
            state = batch['rel_state']
            action = batch['action']
            action_mask = batch['action_mask']
            attention_mask = batch['attention_mask']
            text=batch["text"]
            progress=batch["progress"]

            # Split arm and gripper action
            arm_action = action[:, :, :, :6] # (b, l, act_len, act_dim - 1)
            gripper_action = action[:, :, :, 6] # (b, l, act_len)
            arm_action_target = torch.clone(arm_action) # (b, l, act_len, act_dim - 1)
            gripper_action_target = torch.clone(gripper_action) # (b, l, act_len)

            # Split arm and gripper state
            arm_state = state[:, :, :6] # (b, l, state_dim - 1)
            gripper_state = state[:, :, 6].long() # (b, l)
            gripper_state = F.one_hot(gripper_state, num_classes=2).type_as(arm_state) # (b, l, 2)

            seq_len = arm_action.size(1)
            act_len = arm_action.size(2)

            input_dict = {
                'goal_rgb': goal_rgb,
                'rgb': rgb,
                'hand_rgb': hand_rgb,
                'arm_action': arm_action,
                'gripper_action': gripper_action,
                'arm_state': arm_state,
                'gripper_state': gripper_state,
                'attention_mask': attention_mask, #(b,l)
                'text':text,
                'progress':progress
            }


            prediction = self.model(input_dict, is_training=False)


            obs_preds = prediction['obs_preds']
            obs_target = prediction['obs_target']
            arm_action_preds = prediction['arm_action_preds'] # (b, l, act_len, act_dim - 1)
            gripper_action_preds = prediction['gripper_action_preds'] # (b, l, act_len, 1)
            obs_hand_preds = prediction['obs_hand_preds']
            obs_hand_target = prediction['obs_hand_target']
            action_mu_preds = prediction['action_mu_preds'] # (b * l, act_latent_dim)
            action_logvar_preds = prediction['action_logvar_preds'] # (b * l, act_latent_dim)
            progress_preds=prediction["progress_preds"]
            progress_targets=prediction["progress_targets"]


            loss_act = 0
            loss_arm_act = 0
            loss_gripper_act = 0
            loss_kl_act = 0
            acc_gripper_act = 0
            loss_obs = 0
            loss_hand_obs = 0
            gripper_cnt = 0
            loss_progress=0

            # action prediction
            act_dim = self.model.act_dim
            if self.act_pred:
                # kl loss
                loss_kl_act = compute_kl_divergence(action_mu_preds, action_logvar_preds)


                # action smooth l1 loss
                arm_action_preds = arm_action_preds.view(-1, act_len, act_dim-1)[attention_mask.flatten() > 0] # b,len,act_len,6 -> b*len,act_len,6
                arm_action_target = arm_action_target.view(-1, act_len, act_dim-1)[attention_mask.flatten() > 0] # b,len,act_len,6 -> b*len,act_len,6
                action_mask = action_mask.view(-1, act_len)[attention_mask.flatten() > 0] # b,len,act_len -> b*len,act_len
                arm_action_preds = arm_action_preds.view(-1, act_dim-1)[action_mask.flatten() > 0] # b*len*act_len, 6
                arm_action_target = arm_action_target.view(-1, act_dim-1)[action_mask.flatten() > 0] # b*len*act_len, 6
                loss_arm_act = torch.nn.SmoothL1Loss()(arm_action_preds, arm_action_target)

                # gripper bce loss
                gripper_action_preds = gripper_action_preds.view(-1, act_len)[attention_mask.flatten() > 0] # b,len,act_len -> b*len,act_len
                gripper_action_target = gripper_action_target.view(-1, act_len)[attention_mask.flatten() > 0] # b,len,act_len -> b*len,act_len
                gripper_action_preds = gripper_action_preds.flatten()[action_mask.flatten() > 0] # b*len*act_len
                gripper_action_target = gripper_action_target.flatten()[action_mask.flatten() > 0] # b*len*act_len
                bce_with_logits_loss = torch.nn.BCEWithLogitsLoss()
                loss_gripper_act = bce_with_logits_loss(gripper_action_preds, gripper_action_target)
                loss_act = loss_arm_act + loss_gripper_act * self.gripper_loss_ratio + loss_kl_act * self.kl_loss_ratio

                # Compute gripper action acc
                gripper_action_preds = torch.nn.Sigmoid()(gripper_action_preds) # Sigmoid function
                gripper_action_preds = (gripper_action_preds > 0.5).float()
                acc_gripper_act = torch.eq(gripper_action_preds, gripper_action_target).sum().float()
                gripper_cnt = gripper_action_preds.shape[0]
                acc_gripper_act /= gripper_cnt

            # forward prediction
            if self.fwd_pred:
                fwd_pred_next_n = self.configs["trainer"]['fwd_pred_next_n']
                # Align predictions at t with targets at t + fwd_pred_next_n.
                obs_preds = obs_preds[:, :seq_len-fwd_pred_next_n, :, :]
                obs_target = obs_target[:, fwd_pred_next_n:, :, :]
                obs_attention_mask = attention_mask[:, fwd_pred_next_n:]
                loss_obs = (obs_preds - obs_target) ** 2
                loss_obs = loss_obs.mean(dim=-1).mean(dim=-1)
                loss_obs = (loss_obs * obs_attention_mask).sum() / obs_attention_mask.sum()
                if self.fwd_pred_hand:
                    obs_hand_preds = obs_hand_preds[:, :seq_len-fwd_pred_next_n, :, :]
                    obs_hand_target = obs_hand_target[:, fwd_pred_next_n:, :, :]
                    loss_hand_obs = (obs_hand_preds - obs_hand_target) ** 2
                    loss_hand_obs = loss_hand_obs.mean(dim=-1).mean(dim=-1)
                    loss_hand_obs = (loss_hand_obs * obs_attention_mask).sum() / obs_attention_mask.sum()
            if self.progress_pred:
                # Masked MSE over valid timesteps only.
                diff = progress_preds - progress_targets
                masked_diff = diff * attention_mask
                squared_error = masked_diff ** 2
                loss_progress = squared_error.sum() / attention_mask.sum()

            # compute loss
            loss = torch.tensor(0.0).to(self.device)
            if self.act_pred:
                loss += loss_act
            if self.fwd_pred:
                loss += self.fwd_loss_ratio * loss_obs
                if self.fwd_pred_hand:
                    loss += self.fwd_loss_ratio * loss_hand_obs
            if self.progress_pred:
                loss+=loss_progress*self.progress_loss_ratio
            output = {
                'loss': loss,
                'loss_act': loss_act,
                'loss_arm_act': loss_arm_act,
                'loss_gripper_act': loss_gripper_act,
                'loss_kl_act': loss_kl_act,
                'acc_gripper_act': acc_gripper_act,
                'loss_obs': loss_obs,
                'loss_hand_obs': loss_hand_obs,
                'loss_progress':loss_progress
            }
            self._log_output(output, phase="val", on_epoch=True, on_step=False)
            return output['loss']

    def training_step(self, batch, batch_idx):
        """Compute, log (per-step), and return the training loss for one batch.

        Mirrors ``validation_step`` but runs with gradients enabled and calls
        the model with ``is_training=True``.
        """
        goal_rgb = batch['goal_rgb']
        rgb = batch['rgb']
        hand_rgb = batch['hand_rgb']
        state = batch['rel_state']
        action = batch['action']
        action_mask = batch['action_mask']
        attention_mask = batch['attention_mask']
        text=batch["text"]
        progress=batch["progress"]
        # Split arm and gripper action
        arm_action = action[:, :, :, :6] # (b, l, act_len, act_dim - 1)
        gripper_action = action[:, :, :, 6] # (b, l, act_len)
        arm_action_target = torch.clone(arm_action) # (b, l, act_len, act_dim - 1)
        gripper_action_target = torch.clone(gripper_action) # (b, l, act_len)

        # Split arm and gripper state
        arm_state = state[:, :, :6] # (b, l, state_dim - 1)
        gripper_state = state[:, :, 6].long() # (b, l)
        gripper_state = F.one_hot(gripper_state, num_classes=2).type_as(arm_state) # (b, l, 2)

        seq_len = arm_action.size(1)
        act_len = arm_action.size(2)

        input_dict = {
            'goal_rgb': goal_rgb,
            'rgb': rgb,
            'hand_rgb': hand_rgb,
            'arm_action': arm_action,
            'gripper_action': gripper_action,
            'arm_state': arm_state,
            'gripper_state': gripper_state,
            'attention_mask': attention_mask,
            'text':text,
            'progress':progress
        }


        prediction = self.model(input_dict, is_training=True)


        obs_preds = prediction['obs_preds']
        obs_target = prediction['obs_target']
        arm_action_preds = prediction['arm_action_preds'] # (b, l, act_len, act_dim - 1)
        gripper_action_preds = prediction['gripper_action_preds'] # (b, l, act_len, 1)
        obs_hand_preds = prediction['obs_hand_preds']
        obs_hand_target = prediction['obs_hand_target']
        action_mu_preds = prediction['action_mu_preds'] # (b * l, act_latent_dim)
        action_logvar_preds = prediction['action_logvar_preds'] # (b * l, act_latent_dim)
        progress_preds=prediction["progress_preds"]
        progress_targets=prediction["progress_targets"]



        loss_act = 0
        loss_arm_act = 0
        loss_gripper_act = 0
        loss_kl_act = 0
        acc_gripper_act = 0
        loss_obs = 0
        loss_hand_obs = 0
        gripper_cnt = 0
        loss_progress= 0
        # action prediction
        act_dim = self.model.act_dim
        if self.act_pred:
            # kl loss

            loss_kl_act = compute_kl_divergence(action_mu_preds, action_logvar_preds)


            # action smooth l1 loss
            arm_action_preds = arm_action_preds.view(-1, act_len, act_dim-1)[attention_mask.flatten() > 0] # b,len,act_len,6 -> b*len,act_len,6
            arm_action_target = arm_action_target.view(-1, act_len, act_dim-1)[attention_mask.flatten() > 0] # b,len,act_len,6 -> b*len,act_len,6
            action_mask = action_mask.view(-1, act_len)[attention_mask.flatten() > 0] # b,len,act_len -> b*len,act_len
            arm_action_preds = arm_action_preds.view(-1, act_dim-1)[action_mask.flatten() > 0] # b*len*act_len, 6
            arm_action_target = arm_action_target.view(-1, act_dim-1)[action_mask.flatten() > 0] # b*len*act_len, 6
            loss_arm_act = torch.nn.SmoothL1Loss()(arm_action_preds, arm_action_target)

            # gripper bce loss
            gripper_action_preds = gripper_action_preds.view(-1, act_len)[attention_mask.flatten() > 0] # b,len,act_len -> b*len,act_len
            gripper_action_target = gripper_action_target.view(-1, act_len)[attention_mask.flatten() > 0] # b,len,act_len -> b*len,act_len
            gripper_action_preds = gripper_action_preds.flatten()[action_mask.flatten() > 0] # b*len*act_len
            gripper_action_target = gripper_action_target.flatten()[action_mask.flatten() > 0] # b*len*act_len
            bce_with_logits_loss = torch.nn.BCEWithLogitsLoss()
            loss_gripper_act = bce_with_logits_loss(gripper_action_preds, gripper_action_target)
            loss_act = loss_arm_act + loss_gripper_act * self.gripper_loss_ratio + loss_kl_act * self.kl_loss_ratio

            # Compute gripper action acc
            gripper_action_preds = torch.nn.Sigmoid()(gripper_action_preds)
            gripper_action_preds = (gripper_action_preds > 0.5).float()
            acc_gripper_act = torch.eq(gripper_action_preds, gripper_action_target).sum().float()
            gripper_cnt = gripper_action_preds.shape[0]
            acc_gripper_act /= gripper_cnt

        # predict future image
        if self.fwd_pred:
            fwd_pred_next_n = self.configs["trainer"]['fwd_pred_next_n']
            # Align predictions at t with targets at t + fwd_pred_next_n.
            obs_preds = obs_preds[:, :seq_len-fwd_pred_next_n, :, :]
            obs_target = obs_target[:, fwd_pred_next_n:, :, :]
            obs_attention_mask = attention_mask[:, fwd_pred_next_n:]
            loss_obs = (obs_preds - obs_target) ** 2
            loss_obs = loss_obs.mean(dim=-1).mean(dim=-1)
            loss_obs = (loss_obs * obs_attention_mask).sum() / obs_attention_mask.sum()
            if self.fwd_pred_hand:
                obs_hand_preds = obs_hand_preds[:, :seq_len-fwd_pred_next_n, :, :]
                obs_hand_target = obs_hand_target[:, fwd_pred_next_n:, :, :]
                loss_hand_obs = (obs_hand_preds - obs_hand_target) ** 2
                loss_hand_obs = loss_hand_obs.mean(dim=-1).mean(dim=-1)
                loss_hand_obs = (loss_hand_obs * obs_attention_mask).sum() / obs_attention_mask.sum()

        if self.progress_pred:
            # Masked MSE over valid timesteps only.
            diff = progress_preds - progress_targets
            masked_diff = diff * attention_mask
            squared_error = masked_diff ** 2
            loss_progress = squared_error.sum() / attention_mask.sum()

        # compute loss
        loss = torch.tensor(0.0).to(self.device)
        if self.act_pred:
            loss += loss_act
        if self.fwd_pred:
            loss += self.fwd_loss_ratio * loss_obs
            if self.fwd_pred_hand:
                loss += self.fwd_loss_ratio * loss_hand_obs
        if self.progress_pred:
            loss+=loss_progress*self.progress_loss_ratio
        output = {
            'loss': loss,
            'loss_act': loss_act,
            'loss_arm_act': loss_arm_act,
            'loss_gripper_act': loss_gripper_act,
            'loss_kl_act': loss_kl_act,
            'acc_gripper_act': acc_gripper_act,
            'loss_obs': loss_obs,
            'loss_hand_obs': loss_hand_obs,
            'loss_progress':loss_progress
        }
        self._log_output(output, phase="train", on_epoch=False, on_step=True)
        return output['loss']
+
+
+
diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/utils/dist_train.py b/ACT_DP_multitask/detr/models/mr_mg/policy/utils/dist_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dbcba812ca1941ce53e1e6e8ebd71451159bc4e
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mr_mg/policy/utils/dist_train.py
@@ -0,0 +1,77 @@
+# Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import io
+import os
+
+import torch
+import torch.distributed as dist
+
+
# Keep a handle to the builtin print before it is shadowed by the
# rank-aware print() defined below.
_print = print
+
def get_world_size():
    """Number of processes in the job, from the WORLD_SIZE env var (default 1)."""
    return int(os.getenv('WORLD_SIZE', 1))
def get_rank():
    """Global rank of this process, from the RANK env var (default 0)."""
    return int(os.getenv('RANK', 0))
def get_local_rank():
    """Rank of this process on its node, from LOCAL_RANK (default 0)."""
    return int(os.getenv('LOCAL_RANK', 0))
+
+
def is_dist():
    """True when torch.distributed is initialized with more than one process."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return get_world_size() > 1
+
def print(*argc, all=False, **kwargs):
    """Rank-aware replacement for the builtin ``print``.

    Outside distributed runs this is a plain pass-through. Inside one,
    only local rank 0 prints unless ``all=True``, and every message is
    prefixed with the global rank.
    """
    if not is_dist():
        _print(*argc, **kwargs)
        return

    if not all and get_local_rank() != 0:
        return

    # Render the message into a buffer so the rank prefix can be prepended.
    buffer = io.StringIO()
    kwargs.update(end='', file=buffer, flush=True)
    _print(*argc, **kwargs)

    message = buffer.getvalue()
    buffer.close()

    _print('[rank {}] {}'.format(dist.get_rank(), message))
+
def reduce_mean(tensor, nprocs=None):
    """Average ``tensor`` (or a Python number) across all processes.

    A no-op outside distributed runs. Tensors are cloned so the input is
    never modified; plain numbers are wrapped on the current CUDA device
    and returned as Python scalars. ``nprocs`` overrides the divisor
    (defaults to the world size).
    """
    if not is_dist():
        return tensor
    is_tensor_input = isinstance(tensor, torch.Tensor)
    if is_tensor_input:
        reduced = tensor.clone()
    else:
        reduced = torch.tensor(tensor, device=torch.cuda.current_device())
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    divisor = nprocs if nprocs else dist.get_world_size()
    reduced = reduced / divisor
    return reduced if is_tensor_input else reduced.item()
+
def reduce_sum(tensor):
    """Sum ``tensor`` (or a Python number) across all processes.

    A no-op outside distributed runs. Tensors are cloned so the input is
    never modified; plain numbers are wrapped on the current CUDA device
    and returned as Python scalars.
    """
    if not is_dist():
        return tensor
    is_tensor_input = isinstance(tensor, torch.Tensor)
    if is_tensor_input:
        reduced = tensor.clone()
    else:
        reduced = torch.tensor(tensor, device=torch.cuda.current_device())
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    return reduced if is_tensor_input else reduced.item()
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/position_encoding.py b/ACT_DP_multitask/detr/models/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..f62b15b802cb19f06b1c7988f9fde5ced357a0ec
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/position_encoding.py
@@ -0,0 +1,111 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Various positional encodings for the transformer.
+"""
+import math
+import torch
+from torch import nn
+
+from ..util.misc import NestedTensor
+
+import IPython
+
# Shorthand for dropping into an interactive IPython shell while debugging.
e = IPython.embed
+
+
class PositionEmbeddingSine(nn.Module):
    """
    Fixed 2-D sine/cosine positional encoding, the image generalization of
    the Attention-Is-All-You-Need embedding: row and column indices are each
    expanded into interleaved sin/cos features and concatenated.
    """

    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, tensor):
        # tensor is a feature map; only its H, W, dtype and device are used,
        # so the encoding is shared across the batch (leading dim is 1).
        feature_map = tensor
        ones = torch.ones_like(feature_map[0, [0]])  # (1, H, W)
        row_coord = ones.cumsum(1, dtype=torch.float32)
        col_coord = ones.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            row_coord = row_coord / (row_coord[:, -1:, :] + eps) * self.scale
            col_coord = col_coord / (col_coord[:, :, -1:] + eps) * self.scale

        freq_idx = torch.arange(
            self.num_pos_feats, dtype=torch.float32, device=feature_map.device
        )
        # Frequencies are paired (sin/cos share each exponent).
        divisor = self.temperature ** (2 * (freq_idx // 2) / self.num_pos_feats)

        def interleaved(coord):
            # Expand a coordinate grid to interleaved sin/cos channels.
            scaled = coord[:, :, :, None] / divisor
            return torch.stack(
                (scaled[:, :, :, 0::2].sin(), scaled[:, :, :, 1::2].cos()), dim=4
            ).flatten(3)

        encoding = torch.cat(
            (interleaved(row_coord), interleaved(col_coord)), dim=3
        )
        return encoding.permute(0, 3, 1, 2)  # 1 D H W
+
+
class PositionEmbeddingLearned(nn.Module):
    """
    Learned absolute positional embedding: separate row/column tables
    (up to 50 positions each) concatenated into a per-pixel embedding.
    """

    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize both embedding tables uniformly in [0, 1)."""
        for table in (self.row_embed, self.col_embed):
            nn.init.uniform_(table.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        col_ids = torch.arange(w, device=x.device)
        row_ids = torch.arange(h, device=x.device)
        col_vecs = self.col_embed(col_ids)  # (w, num_pos_feats)
        row_vecs = self.row_embed(row_ids)  # (h, num_pos_feats)
        # Broadcast column vectors down rows and row vectors across columns,
        # then concatenate along the feature axis.
        grid = torch.cat(
            [
                col_vecs.unsqueeze(0).repeat(h, 1, 1),
                row_vecs.unsqueeze(1).repeat(1, w, 1),
            ],
            dim=-1,
        )
        # (H, W, 2F) -> (B, 2F, H, W)
        return grid.permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+
+
def build_position_encoding(args):
    """Create the positional encoding selected by ``args.position_embedding``.

    "v2"/"sine" gives the fixed sinusoidal variant (normalized); "v3"/"learned"
    gives the learned table. Raises ValueError for anything else.
    """
    N_steps = args.hidden_dim // 2
    kind = args.position_embedding
    if kind in ("v2", "sine"):
        # TODO find a better way of exposing other arguments
        return PositionEmbeddingSine(N_steps, normalize=True)
    if kind in ("v3", "learned"):
        return PositionEmbeddingLearned(N_steps)
    raise ValueError(f"not supported {kind}")
diff --git a/ACT_DP_multitask/detr/models/resnet_film.py b/ACT_DP_multitask/detr/models/resnet_film.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc58f38f17e94c6cd7eb6afba19e0e1e2d3b151f
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/resnet_film.py
@@ -0,0 +1,450 @@
+from typing import Type, Any, Callable, Union, List, Mapping, Optional
+
+import copy
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+
def is_torch_version_lower_than_17():
    """Return True when the installed torch release is 1.x with x < 7.

    Only the leading digits of each version component are compared, so
    suffixed versions such as "1.7rc1" or "2.1.0+cu118" — where the
    original ``float()`` cast could raise ValueError — are handled.
    """
    def leading_int(component):
        # Extract the run of digits at the start of a version component.
        digits = ''
        for ch in component:
            if ch.isdigit():
                digits += ch
            else:
                break
        return int(digits) if digits else 0

    parts = torch.__version__.split('.')
    major_version = leading_int(parts[0])
    minor_version = leading_int(parts[1]) if len(parts) > 1 else 0
    return major_version == 1 and minor_version < 7
+
+
# The torchvision weight enums are only importable alongside newer torch
# releases, so guard the import on the detected torch version.
# NOTE(review): assumes a torchvision matching the torch version is
# installed — see the TODO below.
if not is_torch_version_lower_than_17():
    # TODO: Make sure the torchvision version is similarly updated.
    from torchvision.models import ResNet18_Weights, ResNet34_Weights, ResNet101_Weights
+
+
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution; padding equals dilation so spatial size is kept at
    stride 1. Bias is disabled (a norm layer follows in the blocks here)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
+
+
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 (pointwise) convolution without bias."""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=False
    )
+
+
class BasicBlock(nn.Module):
    """ResNet basic block (two 3x3 convs) with optional FiLM conditioning
    applied after the second batch norm, before the residual addition."""

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # The first conv (and the downsample path) carries any spatial stride.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor, film_features: Optional[Tensor] = None) -> Tensor:
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        if film_features is not None:
            # Split into per-channel scale and shift; each is (B, 1, 1, planes)
            # and reshaped to broadcast over the spatial dimensions.
            gamma, beta = torch.split(film_features, 1, dim=1)
            gamma = gamma.squeeze().view(x.size(0), -1, 1, 1)
            beta = beta.squeeze().view(x.size(0), -1, 1, 1)
            out = (1 + gamma) * out + beta

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
+
+
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand.

    As in torchvision, the spatial stride sits on the 3x3 conv (self.conv2)
    rather than the first 1x1 conv of the original paper
    (https://arxiv.org/abs/1512.03385). This variant, known as ResNet V1.5,
    improves accuracy per
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.0)) * groups
        # conv2 (and the downsample path) carries the spatial stride.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut

        return self.relu(out)
+
+
class ResNetWithExtraModules(nn.Module):
    """Update standard ResNet image classification models with FiLM.

    When ``film_config['use']`` is truthy, the four residual stages are kept
    as ``nn.ModuleList``s so each block can be called individually with its
    own FiLM features; otherwise the stages are plain ``nn.Sequential`` and
    the model behaves like torchvision's ResNet.
    """
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        film_config: Optional[Mapping[str, Any]] = None,) -> None:
        """Build the backbone.

        Args:
            block: Residual block class (``BasicBlock`` or ``Bottleneck``).
            layers: Number of blocks in each of the four stages.
            num_classes: Output size of the final fully-connected layer.
            zero_init_residual: Zero-init the last BN of each residual branch.
            groups, width_per_group: Grouped-convolution configuration.
            replace_stride_with_dilation: Per-stage flag to swap stride-2 for
                dilation (3 entries for stages 2-4).
            norm_layer: Normalization layer factory; defaults to BatchNorm2d.
            film_config: Mapping with at least ``'use'`` and ``'film_planes'``;
                may also carry ``'append_object_mask'`` (adds a 4th input
                channel).
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Save how many blocks in each layer (needed to split FiLM features).
        self.layers = layers

        # FiLM only implemented for BasicBlock for now
        self.use_film = film_config is not None and film_config['use']
        if self.use_film:
            self.film_config = film_config
            # Per-stage channel counts used to reshape incoming FiLM features.
            self.film_planes = film_config['film_planes']

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )

        # A 4th input channel is used when an object mask is appended to RGB.
        in_channels_conv1 = 4 if (
            film_config is not None and
            film_config.get('append_object_mask', None) is not None) else 3

        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(in_channels_conv1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Standard torchvision ResNet weight init: He-normal for convs,
        # unit-gamma / zero-beta for norm layers.
        for m_name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,) -> nn.Sequential:
        """Build one residual stage of ``blocks`` blocks.

        Returns an ``nn.ModuleList`` when FiLM is enabled (so the forward pass
        can feed per-block FiLM features), otherwise an ``nn.Sequential``.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade stride for dilation to keep spatial resolution.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut whenever shape changes across the stage.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        # First block may downsample; the remaining blocks keep the shape.
        layers = [
            block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer,)
        ]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        if self.use_film:
            return nn.ModuleList(layers)
        else:
            return nn.Sequential(*layers)

    def _forward_impl_film(self, x: Tensor, film_features: List[Optional[Tensor]], flatten: bool = True):
        """Forward pass through the four stages with per-block FiLM features.

        Args:
            x: Feature map produced by the stem (conv1/bn1/relu/maxpool).
            film_features: One entry per stage; ``None`` disables FiLM for
                that stage, otherwise a tensor reshapeable to
                ``(batch, 2, num_blocks, num_planes)`` — the 2 presumably
                being (scale, shift); TODO confirm against the block impl.
            flatten: Flatten pooled features before the final FC layer.
        """
        assert self.use_film and film_features is not None

        def _extract_film_features_for_layer(film_feat: Optional[Tensor], layer_idx: int):
            # No FiLM for this stage: feed one None per block.
            if film_features[layer_idx] is None:
                return [None] * self.layers[layer_idx]

            num_planes = self.film_planes[layer_idx]
            num_blocks = self.layers[layer_idx]
            # Split the stage's FiLM tensor into one slice per block.
            film_feat = film_feat.view(-1, 2, num_blocks, num_planes)
            film_feat_per_block = torch.split(film_feat, 1, dim=2)
            return film_feat_per_block

        for layer_idx, layer in enumerate([self.layer1, self.layer2, self.layer3, self.layer4]):
            film_feat_per_block = _extract_film_features_for_layer(
                film_features[layer_idx], layer_idx)
            for block_idx, block in enumerate(layer):
                if film_feat_per_block[block_idx] is not None:
                    assert x.shape[0] == film_feat_per_block[block_idx].shape[0], ('FiLM batch size does not match')
                x = block(x, film_features=film_feat_per_block[block_idx])

        x = self.avgpool(x)
        if flatten:
            x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def _forward_impl(self,
                      x: Tensor,
                      film_features: List[Optional[Tensor]],
                      flatten: bool = True) -> Tensor:
        """Shared forward: stem, then FiLM or plain stage traversal."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        if self.use_film:
            return self._forward_impl_film(x, film_features, flatten=flatten)
        else:
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)

            x = self.avgpool(x)
            if flatten:
                x = torch.flatten(x, 1)
            x = self.fc(x)

            return x

    def forward(self,
                x: Tensor,
                film_features: List[Optional[Tensor]], **kwargs) -> Tensor:
        """Run the backbone; ``film_features`` is ignored unless FiLM is on."""
        return self._forward_impl(x, film_features, **kwargs)
+
+
def _resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    weights,
    progress: bool,
    **kwargs: Any,
) -> ResNetWithExtraModules:
    """Construct a ``ResNetWithExtraModules`` and optionally load weights.

    Args:
        block: Residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: Number of blocks in each of the four stages.
        weights: New-style torchvision weights enum entry, or ``None``.
        progress: Show a download progress bar when fetching checkpoints.
        **kwargs: Forwarded to ``ResNetWithExtraModules``. The keys
            ``'pretrained'`` and ``'arch'`` are consumed here: on torch < 1.7
            they trigger legacy URL-based checkpoint loading.

    Returns:
        The constructed model, with pretrained weights loaded when requested.
    """
    # Strip loader-only kwargs before handing the rest to the model class.
    model_kwargs = copy.deepcopy(kwargs)
    model_kwargs.pop('pretrained', None)
    model_kwargs.pop('arch', None)
    model = ResNetWithExtraModules(block, layers, **model_kwargs)

    if weights is not None:
        # New-style torchvision Weights API.
        model.load_state_dict(weights.get_state_dict(progress=progress))
    elif kwargs.get('pretrained', False) and kwargs.get('arch') is not None:
        # BUG FIX: the previous check `float(torch.__version__.split('.')[1]) < 7`
        # looked only at the minor version, so torch 2.x ("2.0" -> 0 < 7) was
        # wrongly treated as older than 1.7. Compare (major, minor) instead.
        parts = torch.__version__.split('.')
        major = int(''.join(ch for ch in parts[0] if ch.isdigit()) or 0)
        minor = int(''.join(ch for ch in parts[1] if ch.isdigit()) or 0)
        if (major, minor) < (1, 7):
            # Copied from https://pytorch.org/vision/0.11/_modules/torchvision/models/resnet.html#resnet18
            model_urls = {
                'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
                'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
                'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
                'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
                'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
                'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
                'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
                'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
                'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
            }

            state_dict = torch.hub.load_state_dict_from_url(model_urls[kwargs.get('arch')],
                                                            progress=progress)
            model.load_state_dict(state_dict)

    return model
+
+
def resnet18(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNetWithExtraModules:
    """ResNet-18 backbone (optionally FiLM-extended).

    Args:
        weights: ``torchvision.models.ResNet18_Weights`` entry or ``None``
            (random init). On torch < 1.7 this argument is ignored and the
            legacy URL-based loading path inside ``_resnet`` is used instead.
        progress: If True, display a download progress bar on stderr.
        **kwargs: Forwarded to ``ResNetWithExtraModules``.
    """
    stage_blocks = [2, 2, 2, 2]
    if is_torch_version_lower_than_17():
        # Old torch: no Weights API — _resnet resolves the checkpoint by arch.
        kwargs["arch"] = "resnet18"
        return _resnet(BasicBlock, stage_blocks, None, progress, **kwargs)
    return _resnet(BasicBlock, stage_blocks, ResNet18_Weights.verify(weights), progress, **kwargs)
+
+
def resnet34(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNetWithExtraModules:
    """ResNet-34 backbone (optionally FiLM-extended).

    Args:
        weights: ``torchvision.models.ResNet34_Weights`` entry or ``None``
            (random init). On torch < 1.7 this argument is ignored and the
            legacy URL-based loading path inside ``_resnet`` is used instead.
        progress: If True, display a download progress bar on stderr.
        **kwargs: Forwarded to ``ResNetWithExtraModules``.
    """
    stage_blocks = [3, 4, 6, 3]
    if is_torch_version_lower_than_17():
        # Old torch: no Weights API — _resnet resolves the checkpoint by arch.
        kwargs["arch"] = "resnet34"
        return _resnet(BasicBlock, stage_blocks, None, progress, **kwargs)
    return _resnet(BasicBlock, stage_blocks, ResNet34_Weights.verify(weights), progress, **kwargs)
+
+
def resnet101(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNetWithExtraModules:
    """ResNet-101 backbone (optionally FiLM-extended).

    .. note::
        As in torchvision, the downsampling stride sits on the second 3x3
        convolution of each bottleneck rather than the first 1x1 — the
        "ResNet V1.5" variant.

    Args:
        weights: ``torchvision.models.ResNet101_Weights`` entry or ``None``
            (random init). On torch < 1.7 this argument is ignored and the
            legacy URL-based loading path inside ``_resnet`` is used instead.
        progress: If True, display a download progress bar on stderr.
        **kwargs: Forwarded to ``ResNetWithExtraModules``.
    """
    stage_blocks = [3, 4, 23, 3]
    if is_torch_version_lower_than_17():
        # Old torch: no Weights API — _resnet resolves the checkpoint by arch.
        kwargs["arch"] = "resnet101"
        return _resnet(Bottleneck, stage_blocks, None, progress, **kwargs)
    return _resnet(Bottleneck, stage_blocks, ResNet101_Weights.verify(weights), progress, **kwargs)
diff --git a/ACT_DP_multitask/detr/models/transformer.py b/ACT_DP_multitask/detr/models/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3632f6fdf84bd1feb6952189536824f4796ae66f
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/transformer.py
@@ -0,0 +1,2358 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+DETR Transformer class.
+
+Copy-paste from torch.nn.Transformer with modifications:
+ * positional encodings are passed in MHattention
+ * extra LN at the end of encoder is removed
+ * decoder returns a stack of activations from all decoding layers
+"""
+import copy
+from typing import Optional, List
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+import IPython
+
+e = IPython.embed
+
+
class Transformer(nn.Module):
    """DETR-style encoder/decoder transformer adapted for ACT.

    The encoder consumes flattened image features with a latent token and a
    proprioception token prepended; the decoder cross-attends to the encoder
    memory from a set of learned queries.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        # Final LayerNorm only for pre-norm encoders.
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        # Xavier-init every matrix-shaped parameter; 1-D params (biases,
        # norm gains) keep their defaults.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src,
        mask,
        query_embed,
        pos_embed,
        latent_input=None,
        proprio_input=None,
        additional_pos_embed=None,
    ):
        """Encode image/state tokens and decode the learned queries.

        Args:
            src: Image features, either (bs, c, h, w) or pre-flattened
                (bs, hw, c).
            mask: Key-padding mask for the encoder memory (or None).
            query_embed: (num_queries, d) learned decoder query embeddings.
            pos_embed: Positional embedding matching ``src``'s layout.
            latent_input, proprio_input: Extra per-sample tokens prepended to
                the sequence (4-D path only).
            additional_pos_embed: Positional embeddings for those extra tokens.

        Returns:
            Decoder activations, transposed so batch precedes sequence.
        """
        # TODO flatten only when input has H and W
        if len(src.shape) == 4:  # has H and W
            # flatten NxCxHxW to HWxNxC
            bs, c, h, w = src.shape
            src = src.flatten(2).permute(2, 0, 1)
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
            # mask = mask.flatten(1)

            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
                1, bs, 1
            )  # seq, bs, dim
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

            # Prepend the latent and proprioception tokens to the sequence.
            addition_input = torch.stack([latent_input, proprio_input], axis=0)
            src = torch.cat([addition_input, src], axis=0)
        else:
            assert len(src.shape) == 3
            # flatten NxHWxC to HWxNxC
            bs, hw, c = src.shape
            src = src.permute(1, 0, 2)
            pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)

        # Decoder input starts at zero; content comes from query_pos + memory.
        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        # Block attention between the first 100 queries and the rest, in both
        # directions, so the two query groups decode independently.
        # NOTE(review): the split point 100 is hard-coded — presumably the
        # action-chunk query count; confirm against the caller's num_queries.
        shape_len = tgt.shape[0]
        tgt_mask = torch.zeros(shape_len, shape_len).to(tgt.device)
        tgt_mask[:100, 100:] = float(
            "-inf"
        )
        tgt_mask[100:, :100] = float("-inf")
        hs = self.decoder(
            tgt,
            memory,
            tgt_mask,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            query_pos=query_embed,
        )
        hs = hs.transpose(1, 2)  # swap sequence and batch dims for the caller
        return hs
+
+
class Transformer_Denoise(nn.Module):
    """DETR-style transformer used as a diffusion denoiser over action chunks.

    The decoder's target sequence is the embedded noisy action chunk; the
    diffusion timestep is embedded and concatenated to the encoder memory as
    one extra token with its own learned positional embedding.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        causal_mask=False,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )

        # MLP on top of the sinusoidal timestep embedding.
        self.time_embed = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        self._reset_parameters()

        # NOTE(review): created AFTER _reset_parameters(), so action_embed
        # keeps default (non-Xavier) init — confirm this is intentional.
        # The input dim 14 hard-codes the action dimension.
        self.action_embed = nn.Sequential(
            nn.Linear(14, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        self.d_model = d_model
        self.nhead = nhead
        # Positional embedding for the single timestep token in the memory.
        self.denoise_step_pos_embed = nn.Embedding(1, d_model)
        self.causal_mask = causal_mask
        print("apply causal_mask:", causal_mask)

    def _reset_parameters(self):
        # Xavier-init every matrix-shaped parameter created so far.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src,
        mask,
        query_embed,
        pos_embed,
        latent_input=None,
        proprio_input=None,
        additional_pos_embed=None,
        noisy_actions=None,
        denoise_steps=None,
        task_emb=None,  # B D
    ):
        """One denoising step: predict from noisy actions + conditioning.

        Args:
            src: image embedding; mask: encoder key-padding mask.
            query_embed: decoder positional embedding.
            pos_embed: encoder positional embedding.
            latent_input: VAE latent or tactile latent token(s).
            proprio_input: proprioception token(s), (B, T, D).
            additional_pos_embed: positional embeddings for the prepended tokens.
            noisy_actions: (B, T, 14) noisy action chunk to denoise.
            denoise_steps: (B,) diffusion timesteps.
            task_emb: optional (B, D) task embedding appended to the encoder input.
        """
        if len(src.shape) == 4:  # has H and W: b d h (w n_view)
            # flatten NxCxHxW to HWxNxC
            bs, c, h, w = src.shape
            src = src.flatten(2).permute(2, 0, 1)  # h*w, bs, c
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)  # N bs dim
            # N_add Dim
            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
                1, bs, 1
            )  # seq, bs, dim
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

            latent_input = latent_input.unsqueeze(1) if len(latent_input.shape) == 2 else latent_input  # B 1 D
            addition_input = torch.cat([latent_input, proprio_input], axis=1).permute(
                1, 0, 2
            )  # B T+1 D -> T+1 B D
            if task_emb is not None:
                # Concat task embedding to the encoder input: T+2 B D.
                addition_input = torch.cat([addition_input, task_emb.unsqueeze(0)], axis=0)
            src = torch.cat([addition_input, src], axis=0)
        else:
            assert len(src.shape) == 3
            # flatten NxHWxC to HWxNxC
            bs, hw, c = src.shape
            src = src.permute(1, 0, 2)
            pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)

        # Decoder target is the embedded noisy action chunk: B T D -> T B D.
        tgt = self.action_embed(noisy_actions).permute(
            1, 0, 2
        )
        # get_timestep_embedding: module-level helper (presumably sinusoidal).
        denoise_embed = get_timestep_embedding(denoise_steps, self.d_model)  # B -> B D
        denoise_embed = self.time_embed(denoise_embed).unsqueeze(0)  # B D -> 1 B D
        memory = self.encoder(
            src, src_key_padding_mask=mask, pos=pos_embed
        )
        denoise_step_pos_embed = self.denoise_step_pos_embed.weight.unsqueeze(1).repeat(
            1, bs, 1
        )  # 1 D -> 1 B D
        # Append the timestep token (and its PE) to the encoder memory.
        memory = torch.cat([memory, denoise_embed], axis=0)
        pos_embed = torch.cat([pos_embed, denoise_step_pos_embed], axis=0)
        seq_len = tgt.shape[0]
        if self.causal_mask:
            # Strictly upper-triangular -inf: each action attends to the past only.
            tgt_mask = torch.triu(
                torch.full((seq_len, seq_len), float("-inf")), diagonal=1
            ).to(tgt.device)
        else:
            tgt_mask = torch.zeros(seq_len, seq_len).to(tgt.device)
        hs = self.decoder(
            tgt,
            memory,
            tgt_mask,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            query_pos=query_embed,
        )
        hs = hs.transpose(1, 2)  # 1 T B D -> 1 B T D
        return hs
+
+
class Transformer_Denoise_AdLN(nn.Module):
    """Diffusion denoiser whose decoder is conditioned via adaptive LayerNorm.

    Unlike ``Transformer_Denoise``, the encoder output is average-pooled to a
    single vector, summed with the timestep embedding, and handed to an
    AdLN-style decoder (``TransformerEncoder_AdLN``) as the modulation
    condition instead of being cross-attended.

    NOTE(review): ``causal_mask`` and ``return_intermediate_dec`` are accepted
    for signature compatibility but never used in this class.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        causal_mask=False,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        # Decoder is a self-attention stack with adaptive LayerNorm modulation.
        decoder_layer = TransformerEncoderLayer_AdLN(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerEncoder_AdLN(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
        )

        # MLP on top of the sinusoidal timestep embedding.
        self.time_embed = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        self._reset_parameters()

        # NOTE(review): created AFTER _reset_parameters(), so action_embed
        # keeps default (non-Xavier) init — confirm this is intentional.
        # The input dim 14 hard-codes the action dimension.
        self.action_embed = nn.Sequential(
            nn.Linear(14, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        self.d_model = d_model
        self.nhead = nhead
        self.denoise_step_pos_embed = nn.Embedding(1, d_model)

        # Pool the encoder memory to one conditioning vector per sample.
        self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
        self.norm_after_pool = nn.LayerNorm(d_model)

    def _reset_parameters(self):
        # Xavier-init every matrix-shaped parameter created so far.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src,  # B D H W*num_view
        mask,  # None
        query_embed,  # H D
        pos_embed,  # 1 D H W*num_view
        latent_input=None,  # B 1 D
        proprio_input=None,  # B 1 D
        additional_pos_embed=None,  # 1+1 D
        noisy_actions=None,  # B H D
        denoise_steps=None,  # B
    ):
        """Denoise the action chunk with pooled-memory AdLN conditioning."""
        if len(src.shape) == 4:  # has H and W: b d h (w n_view)
            # flatten NxCxHxW to HWxNxC
            bs, c, h, w = src.shape
            src = src.flatten(2).permute(2, 0, 1)  # h*w, bs, c
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)  # N bs dim
            # N_add Dim
            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
                1, bs, 1
            )  # seq, bs, dim
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)  # proprio + visual token PE

            latent_input = latent_input.unsqueeze(1) if len(latent_input.shape) == 2 else latent_input  # B 1 D
            addition_input = torch.cat([latent_input, proprio_input], axis=1).permute(
                1, 0, 2
            )  # B T+1 D -> T+1 B D
            src = torch.cat([addition_input, src], axis=0)
        else:
            assert len(src.shape) == 3
            # flatten NxHWxC to HWxNxC
            bs, hw, c = src.shape
            src = src.permute(1, 0, 2)
            pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)

        memory = self.encoder(
            src, src_key_padding_mask=mask, pos=pos_embed
        )  # N B D
        # Pool over the token axis: N B D => B D N => B D 1 => B D.
        memory = self.global_1d_pool(memory.permute(1, 2, 0)).squeeze(-1)  # B D
        memory = self.norm_after_pool(memory)
        # get_timestep_embedding: module-level helper (presumably sinusoidal).
        denoise_embed = get_timestep_embedding(denoise_steps, self.d_model)  # B -> B D
        denoise_embed = self.time_embed(denoise_embed)  # B -> B D

        # Pooled memory + timestep embedding drives the AdLN modulation.
        condition = memory + denoise_embed  # B D

        # Decoder input: embedded noisy actions, B T D -> T B D.
        tgt = self.action_embed(noisy_actions).permute(
            1, 0, 2
        )

        hs = self.decoder(
            tgt,
            condition,
            pos=query_embed
        )  # Actually need is_pad?

        hs = hs.transpose(1, 2)  # 1 T B D -> 1 B T D
        return hs
+
class Transformer_Denoise_Tactile(nn.Module):
    """Diffusion denoiser with a second, tactile memory stream.

    The decoder (``TransformerDecoder_alter``) alternates/combines attention
    over two memories: the encoded vision+state sequence and a raw
    tactile+state sequence, each carrying the appended timestep token.

    NOTE(review): only 4-D ``src`` is supported — ``bs`` and the tactile
    sequences are built exclusively inside the 4-D branch.
    """
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        causal_mask=False,
    ):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        # Decoder variant that takes an alternate (tactile) memory stream.
        self.decoder = TransformerDecoder_alter(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )

        # MLP on top of the sinusoidal timestep embedding.
        self.time_embed = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        self._reset_parameters()

        # NOTE(review): created AFTER _reset_parameters(), so action_embed
        # keeps default (non-Xavier) init — confirm this is intentional.
        # The input dim 14 hard-codes the action dimension.
        self.action_embed = nn.Sequential(
            nn.Linear(14, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        self.d_model = d_model
        self.nhead = nhead
        self.denoise_step_pos_embed = nn.Embedding(1, d_model)
        self.causal_mask = causal_mask
        print("apply causal_mask:", causal_mask)

    def _reset_parameters(self):
        # Xavier-init every matrix-shaped parameter created so far.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src,
        mask,
        query_embed,
        pos_embed,
        tactile_input,
        tactile_pos,
        proprio_input=None,
        additional_pos_embed=None,
        noisy_actions=None,
        denoise_steps=None,):
        """Denoise actions attending to vision and tactile memories."""
        if len(src.shape) == 4:  # has H and W: b d h (w n_view)
            # flatten NxCxHxW to HWxNxC
            bs, c, h, w = src.shape
            src = src.flatten(2).permute(2, 0, 1)  # h*w, bs, c
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)  # N bs dim

            tactile_input = tactile_input.flatten(2).permute(2, 0, 1)  # h*w*4, bs, c
            tactile_pos = tactile_pos.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)  # h*w, bs, c
            # N_add Dim
            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
                1, bs, 1
            )  # seq, bs, dim
            proprio_input = proprio_input.permute(1, 0, 2)  # B 1 D -> 1 B D
            # vision-state input
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
            src = torch.cat([proprio_input, src], axis=0)
            # tactile-state input (shares the proprio token and its PE)
            src_tactile = torch.cat([proprio_input, tactile_input], axis=0)
            src_tactile_pos = torch.cat([additional_pos_embed, tactile_pos], axis=0)

        # Decoder target: embedded noisy actions, B T D -> T B D.
        tgt = self.action_embed(noisy_actions).permute(
            1, 0, 2
        )
        # get_timestep_embedding: module-level helper (presumably sinusoidal).
        denoise_embed = get_timestep_embedding(denoise_steps, self.d_model)  # B -> B D
        denoise_embed = self.time_embed(denoise_embed).unsqueeze(0)  # B D -> 1 B D
        denoise_step_pos_embed = self.denoise_step_pos_embed.weight.unsqueeze(1).repeat(
            1, bs, 1
        )  # 1 D -> 1 B D

        # Encode vision-state information; append the timestep token.
        memory = self.encoder(
            src, src_key_padding_mask=mask, pos=pos_embed
        )
        memory = torch.cat([memory, denoise_embed], axis=0)
        pos_embed = torch.cat([pos_embed, denoise_step_pos_embed], axis=0)
        seq_len = tgt.shape[0]

        # Tactile memory is NOT passed through the encoder — raw tokens + timestep.
        memory_tactile = torch.cat([src_tactile, denoise_embed], axis=0)
        tactile_pos_embed = torch.cat([src_tactile_pos, denoise_step_pos_embed], axis=0)
        tgt_mask = torch.zeros(seq_len, seq_len).to(tgt.device)
        hs = self.decoder(
            tgt,
            memory,
            memory_tactile,
            tgt_mask,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            pos_alter=tactile_pos_embed,
            query_pos=query_embed,
        )
        hs = hs.transpose(1, 2)
        return hs
+
+
+class Transformer_diffusion_prediction(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        num_queries=100,
        share_decoder=True,
        patch_size=5,
        diffusion_timestep_type="cat",
        attention_type="v1",
        predict_frame=16,
        predict_only_last=False,
        token_dim=6,
    ):
        """Joint action + future-token diffusion transformer.

        Args:
            num_queries: Action-chunk length (number of action queries).
            share_decoder: If False, a deep-copied decoder handles the token
                (future-frame) queries separately from the action queries.
            patch_size: Spatial patch edge for the predicted visual tokens;
                token_embed consumes ``token_dim * patch_size**2`` features.
            diffusion_timestep_type: "cat" concatenates the timestep token to
                the encoder memory; any other value adds it element-wise.
            attention_type: Selects the decoder self-attention mask variant
                (used in forward).
            predict_frame: Number of future frames whose tokens are predicted.
            predict_only_last: Stored flag; semantics defined by the caller.
            token_dim: Channel dim of the predicted visual tokens.
        """
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )
        # Separate decoder for token queries when not sharing.
        if share_decoder == False:
            self.decoder_token = copy.deepcopy(self.decoder)

        self.share_decoder = share_decoder
        self.diffusion_timestep_type = diffusion_timestep_type
        # MLP on top of the sinusoidal timestep embedding.
        self.time_embed = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        self._reset_parameters()

        # NOTE(review): action_embed/token_embed are created AFTER
        # _reset_parameters(), so they keep default (non-Xavier) init —
        # confirm this is intentional.
        self.action_embed = nn.Sequential(
            nn.Linear(14, d_model),  # TODO action dim
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        self.token_embed = nn.Sequential(
            nn.Linear(
                token_dim * patch_size * patch_size, d_model
            ),  # Hardcode patch size * path size * patch dim
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        self.d_model = d_model
        self.nhead = nhead
        self.chunksize = num_queries
        self.attention_type = attention_type
        self.predict_frame = predict_frame
        self.predict_only_last = predict_only_last
        self.token_dim = token_dim
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ # hs_action, hs_token = self.transformer(src, None, self.query_embed.weight, self.query_embed_toekn.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight, noisy_actions, noise_tokens, denoise_steps, self.denoise_step_pos_embed.weight)
+ def forward(
+ self,
+ src,
+ mask,
+ query_embed,
+ query_embed_token,
+ pos_embed,
+ latent_input=None,
+ proprio_input=None,
+ additional_pos_embed=None,
+ noisy_actions=None,
+ noisy_tokens=None,
+ denoise_steps=None,
+ denoise_step_pos_embed=None,
+ is_pad=None,
+ is_pad_token=None,
+ ):
+ # src: B D H W*num_view
+ # mask:None
+ # query_embed: chunksize D for action query
+ # query_embed_token: Num_tokens D for token query
+ # pos_embed: 1 D H W*num_view for current frame token
+ # latent_input: B D
+ # proprio_input: B T' D
+ # additional_pos_embed: B T'+1 D, include proprio and latent
+ # noisy_actions: B chunksize D
+ # noisy_tokens: B T' N D H' W', T' = predict_frame / temporal_compression_rate
+ # denoise_steps: B
+ # denoise_step_pos_embed:1 D
+ # is_pad: B chunksize
+ # is_pad_token: B T' N H' W'
+ if len(src.shape) == 4: # has H and W
+ bs, c, h, w = src.shape # B D H W*num_view*T
+ src = src.flatten(2).permute(
+ 2, 0, 1
+ ) # H*W*num_view*T, bs, c deal with visual features from resnet
+ pos_embed = (
+ pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
+ ) # H*W*num_view*T, bs, c
+
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+ query_embed_token = query_embed_token.unsqueeze(1).repeat(
+ 1, bs, 1
+ ) # TODO temporal-spatial position embedding FOR predicted frame token
+ # print('transformer query_embed_token', query_embed_token.shape,'query_embed', query_embed.shape)
+ query_embed_all = torch.cat(
+ [query_embed, query_embed_token], axis=0
+ ) # chunksize + num_tokens, bs, c
+ additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
+ 1, bs, 1
+ ) # seq, bs, dim
+ pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
+
+ latent_input = latent_input.unsqueeze(1) # B 1 D
+ addition_input = torch.cat([latent_input, proprio_input], axis=1).permute(
+ 1, 0, 2
+ ) # B T+1 D -> T+1 B D
+ src = torch.cat([addition_input, src], axis=0)
+
+ denoise_step_pos_embed = denoise_step_pos_embed.unsqueeze(1).repeat(
+ 1, bs, 1
+ ) # 1 D -> 1 B D
+ else:
+ assert len(src.shape) == 3
+
+ # encoder
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+
+ # add denoise timestep embedding for decoder denoise step
+ denoise_embed = self.time_embed(
+ get_timestep_embedding(denoise_steps, self.d_model)
+ ).unsqueeze(
+ 0
+ ) # B D -> 1 B D
+ if (
+ self.diffusion_timestep_type == "cat"
+ ): # TODO add tokenizer visual token & timestep embedding
+ memory = torch.cat([memory, denoise_embed], axis=0)
+ pos_embed = torch.cat([pos_embed, denoise_step_pos_embed], axis=0)
+ # elif self.diffusion_timestep_type == 'vis_cat':
+ # memory = torch.cat([memory, visual_token, denoise_embed], axis=0) # visual token map
+ # pos_embed = torch.cat([pos_embed, pos_embed_visual_token, denoise_step_pos_embed], axis=0) # pos_embed_visual_token
+ else:
+ memory = memory + denoise_embed
+
+ tgt_action = self.action_embed(noisy_actions).permute(1, 0, 2) # H B D
+
+ # if noisy_tokens is None:
+
+ # tgt_key_padding_mask = torch.zeros_like(tgt_token).sum(-1).bool() # T'*N*H'*W' B
+
+ if self.share_decoder and noisy_tokens is not None:
+ noisy_tokens = noisy_tokens.permute(
+ 0, 1, 2, 4, 5, 3
+ ) # B T' N D H' W' -> B T' N H' W' D
+ tgt_token = (
+ self.token_embed(noisy_tokens)
+ .reshape(bs, -1, self.d_model)
+ .permute(1, 0, 2)
+ ) # B T' N H' W' D -> T'*N*H'*W' B D
+ # print('transformer',tgt_action.shape,tgt_token.shape )
+ tgt = torch.cat([tgt_action, tgt_token], axis=0) # H1+ H2 B D
+ seq_len = tgt.shape[0]
+ bs = tgt.shape[1]
+ # tgt_key_padding_mask = torch.cat([is_pad, is_pad_token], axis=1) # B N+M is
+ is_pad_token_zero = torch.zeros_like(is_pad_token).bool() # HARD CODE
+ if is_pad is None:
+ is_pad = torch.zeros(bs, self.chunksize).bool().to(src.device)
+ tgt_key_padding_mask = torch.cat(
+ [is_pad, is_pad_token_zero], axis=1
+ ) # HARD CODE avoid nan when all token is pad
+
+ seq_len = tgt.shape[0]
+ tgt_mask = torch.zeros(seq_len, seq_len).to(
+ tgt.device
+ ) # chunksize + num_pred_token_per_frame*predict_frame
+ # TODO design mask
+ if self.attention_type == "v1":
+ tgt_mask[0 : self.chunksize, self.chunksize :] = float(
+ "-inf"
+ ) # action prediction cannot attend to token prediction
+ tgt_mask[self.chunksize :, self.predict_frame : self.chunksize] = float(
+ "-inf"
+ ) # token prediction cannot attend to action token after token prediction
+ elif self.attention_type == "v2":
+ tgt_mask[0 : self.chunksize, self.chunksize :] = float(
+ "-inf"
+ ) # action prediction cannot attend to token prediction
+ tgt_mask[self.chunksize :, 0 : self.chunksize] = float(
+ "-inf"
+ ) # token prediction cannot attend to action prediction
+ elif self.attention_type == "v3": # v1= v3 sad
+ tgt_mask[0 : self.chunksize, self.chunksize :] = float(
+ "-inf"
+ ) # action prediction cannot attend to token prediction
+ if (
+ self.predict_frame < self.chunksize
+ ): # don't need attent action after target frame
+ tgt_mask[self.chunksize :, self.predict_frame : self.chunksize] = (
+ float("-inf")
+ )
+ elif self.attention_type == "causal":
+ predict_frame = noisy_tokens.shape[1]
+ num_pred_token_per_frame = tgt_token.shape[0] // predict_frame
+ chunk_size = query_embed.shape[0]
+ temporal_compression_rate = chunk_size // predict_frame
+ tgt_mask = torch.full((seq_len, seq_len), float("-inf")).to(
+ tgt.device
+ ) # seq_len = chunksize + num_pred_token_per_frame*predict_frame
+ # tgt_mask[:chunk_size, :chunk_size] = torch.triu(torch.full((chunk_size, seq_len), float('-inf')), diagonal=1).to(tgt.device)
+ tgt_mask[:chunk_size, :chunk_size] = torch.zeros(
+ (chunk_size, chunk_size)
+ ).to(tgt.device)
+ for t in range(predict_frame):
+ tgt_mask[
+ chunk_size
+ + t * num_pred_token_per_frame : chunk_size
+ + (t + 1) * num_pred_token_per_frame,
+ 0 : t * temporal_compression_rate,
+ ] = 0
+ tgt_mask[
+ chunk_size
+ + t * num_pred_token_per_frame : chunk_size
+ + (t + 1) * num_pred_token_per_frame,
+ chunk_size : chunk_size + (t + 1) * num_pred_token_per_frame,
+ ] = 0
+ # print(tgt.shape, memory.shape, tgt_key_padding_mask.shape, pos_embed.shape, query_embed_all.shape)
+ hs = self.decoder(
+ tgt,
+ memory,
+ tgt_mask,
+ memory_key_padding_mask=mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ pos=pos_embed,
+ query_pos=query_embed_all,
+ ).transpose(1, 2)[
+ 0
+ ] # TODO 1 H B D
+ hs_action = hs[:, : self.chunksize] # B H D
+ hs_token = hs[:, self.chunksize :] # B H D -> None
+ # simplify version
+ # tgt_mask = torch.full((seq_len, seq_len), float('-inf')).to(tgt.device)
+ # # Action-to-Action mask
+ # tgt_mask[:chunk_size, :chunk_size] = torch.triu(torch.full((chunk_size, chunk_size), float('-inf')), diagonal=1)
+
+ # # Frame-to-All mask
+ # frame_indices = torch.arange(chunk_size, seq_len).view(predict_frame, num_pred_token_per_frame).to(tgt.device)
+ # action_indices = torch.arange(chunk_size).to(tgt.device)
+
+ # for t in range(predict_frame):
+ # tgt_mask[frame_indices[t], action_indices[:t * temporal_compression_rate]] = 0
+ # tgt_mask[frame_indices[t], frame_indices[:t + 1].flatten()] = 0
+
+ elif self.share_decoder and noisy_tokens is None:
+ hs = self.decoder(
+ tgt_action,
+ memory,
+ memory_key_padding_mask=mask,
+ tgt_key_padding_mask=is_pad,
+ pos=pos_embed,
+ query_pos=query_embed,
+ ).transpose(1, 2)[0]
+ hs_action = hs
+ hs_token = None
+
+ else:
+ hs_action = self.decoder(
+ tgt_action,
+ memory,
+ memory_key_padding_mask=mask,
+ tgt_key_padding_mask=is_pad,
+ pos=pos_embed,
+ query_pos=query_embed,
+ ).transpose(1, 2)[0]
+ hs_token = self.decoder_token(
+ tgt_token,
+ memory,
+ memory_key_padding_mask=mask, # tgt_key_padding_mask = is_pad_token, # is_pad_token is nonetype ! fix
+ pos=pos_embed,
+ query_pos=query_embed_token,
+ ).transpose(1, 2)[0]
+ # hs_action B chunksize D
+ # hs_token B T'*N*H'*W' D
+ # print("tgt_token has NaN:", torch.isnan(tgt_token).any())
+ # print("memory has NaN:", torch.isnan(memory).any())
+ # print('tgt_key_padding_mask has NaN:', torch.isnan(is_pad_token).any())
+ # print("pos_embed has NaN:", torch.isnan(pos_embed).any())
+ # print("query_embed_token has NaN:", torch.isnan(query_embed_token).any())
+ # print("hs_token has NaN:", torch.isnan(hs_token).any()) # has nan
+ return hs_action, hs_token
+
+
class Transformer_diffusion_prediction_with_dual_visual_token(nn.Module):
    """Diffusion transformer that jointly denoises an action chunk and future
    visual tokens.

    "Dual visual token": the encoder consumes ResNet feature-map tokens
    (plus latent/proprio inputs), while tokenizer-produced visual tokens can
    additionally be appended to the decoder memory when
    ``diffusion_timestep_type == 'vis_cat'``.

    Changes vs. the original implementation (behavior-compatible):
      * removed the dead no-op statement ``token_dim = token_dim``;
      * ``action_dim`` keyword (default 14, the previously hard-coded value)
        generalizes the action embedding width, resolving the in-code TODO.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        num_queries=100,
        share_decoder=True,
        patch_size=5,
        diffusion_timestep_type="cat",
        attention_type="v1",
        predict_frame=16,
        predict_only_last=False,
        token_dim=6,
        action_dim=14,
    ):
        """Build encoder/decoder stacks and the action / visual-token embedders.

        Args:
            num_queries: action chunk size (stored as ``self.chunksize``).
            share_decoder: if False, a deep-copied second decoder handles the
                token stream.
            patch_size, token_dim: one flattened token patch has
                ``token_dim * patch_size**2`` features.
            diffusion_timestep_type: 'cat' | 'vis_cat' | 'add' — how the
                timestep embedding conditions the decoder memory.
            attention_type: 'v1' | 'v2' | 'v3' target self-attention layout.
            action_dim: per-step action dimensionality (new, default keeps the
                original hard-coded 14).
        """
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )
        if not share_decoder:
            # Structurally identical but independent decoder for the token stream.
            self.decoder_token = copy.deepcopy(self.decoder)

        self.share_decoder = share_decoder
        self.diffusion_timestep_type = diffusion_timestep_type
        # MLP lifting the sinusoidal diffusion-timestep embedding to d_model.
        self.time_embed = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        # NOTE(review): invoked *before* action_embed/token_embed are created,
        # so those keep PyTorch's default init — preserved from the original.
        self._reset_parameters()
        # Embeds a noisy per-step action vector into the model width.
        self.action_embed = nn.Sequential(
            nn.Linear(action_dim, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        # Embeds one flattened visual-token patch (token_dim * patch_size^2).
        self.token_embed = nn.Sequential(
            nn.Linear(token_dim * patch_size * patch_size, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        self.d_model = d_model
        self.nhead = nhead
        self.chunksize = num_queries
        self.attention_type = attention_type
        self.predict_frame = predict_frame
        self.predict_only_last = predict_only_last

    def _reset_parameters(self):
        """Xavier-init every >1-dim parameter registered so far."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src,
        mask,
        query_embed,
        query_embed_token,
        pos_embed,
        latent_input=None,
        proprio_input=None,
        additional_pos_embed=None,
        noisy_actions=None,
        noisy_tokens=None,
        denoise_steps=None,
        denoise_step_pos_embed=None,
        is_pad=None,
        is_pad_token=None,
        addition_visual_token=None,
        addition_visual_token_pos=None,
    ):
        """Denoise actions and future visual tokens in one decoder pass.

        Shapes (T' = predict_frame / temporal compression rate):
            src: (B, D, H, W*num_view*T) ResNet feature map
            mask: encoder key-padding mask or None
            query_embed: (chunksize, D); query_embed_token: (num_tokens, D)
            pos_embed: (1, D, H, W*num_view*T)
            latent_input: (B, D); proprio_input: (B, T', D)
            additional_pos_embed: (T'+1, D) for the latent+proprio slots
            noisy_actions: (B, chunksize, action_dim)
            noisy_tokens: (B, T', N, D, H', W')
            denoise_steps: (B,) diffusion timesteps
            is_pad: (B, chunksize), True = padded action step
            is_pad_token: token padding mask; assumed already flattened to
                (B, num_tokens) so it can be concatenated after is_pad —
                TODO confirm against callers.
            addition_visual_token(_pos): tokenizer features used only when
                diffusion_timestep_type == 'vis_cat'.

        Returns:
            (hs_action, hs_token): (B, chunksize, D) and (B, num_tokens, D).
        """
        if len(src.shape) == 4:  # spatial map: flatten H/W into a sequence
            bs, c, h, w = src.shape  # B D H W*num_view*T
            # For encoder
            src = src.flatten(2).permute(
                2, 0, 1
            )  # -> (H*W*num_view*T, B, D), ResNet visual features
            addition_visual_token = addition_visual_token.flatten(2).permute(
                2, 0, 1
            )  # -> (H'*W'*num_view, B, D), tokenizer visual tokens
            latent_input = latent_input.unsqueeze(1)  # (B, 1, D)
            addition_input = torch.cat([latent_input, proprio_input], axis=1).permute(
                1, 0, 2
            )  # (B, T'+1, D) -> (T'+1, B, D)

            pos_embed = (
                pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
            )  # (H*W*num_view*T, B, D)
            addition_visual_token_pos = addition_visual_token_pos.unsqueeze(1).repeat(
                1, bs, 1
            )  # (H'*W'*num_view, B, D)
            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
                1, bs, 1
            )  # (seq, B, D)

            # The encoder only sees robot state + ResNet visual tokens.
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
            src = torch.cat([addition_input, src], axis=0)

            # Decoder-side queries: actions first, then token queries.
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
            query_embed_token = query_embed_token.unsqueeze(1).repeat(
                1, bs, 1
            )  # TODO temporal-spatial position embedding for predicted frames
            query_embed_all = torch.cat(
                [query_embed, query_embed_token], axis=0
            )  # (chunksize + num_tokens, B, D)

            denoise_step_pos_embed = denoise_step_pos_embed.unsqueeze(1).repeat(
                1, bs, 1
            )  # (1, D) -> (1, B, D)
        else:
            assert len(src.shape) == 3

        # Encoder over ResNet visual features + proprio + latent.
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)

        # Diffusion-timestep conditioning for the decoder memory.
        denoise_embed = self.time_embed(
            get_timestep_embedding(denoise_steps, self.d_model)
        ).unsqueeze(
            0
        )  # (B, D) -> (1, B, D)
        if self.diffusion_timestep_type == "cat":
            memory = torch.cat([memory, denoise_embed], axis=0)
            pos_embed = torch.cat([pos_embed, denoise_step_pos_embed], axis=0)
        elif self.diffusion_timestep_type == "vis_cat":
            # Additionally expose tokenizer visual tokens to the decoder.
            memory = torch.cat(
                [memory, addition_visual_token, denoise_embed], axis=0
            )
            pos_embed = torch.cat(
                [pos_embed, addition_visual_token_pos, denoise_step_pos_embed], axis=0
            )
        elif self.diffusion_timestep_type == "add":
            memory = memory + denoise_embed

        # Embed the noisy targets.
        tgt_action = self.action_embed(noisy_actions).permute(1, 0, 2)  # (chunksize, B, D)
        noisy_tokens = noisy_tokens.permute(
            0, 1, 2, 4, 5, 3
        )  # (B, T', N, D, H', W') -> (B, T', N, H', W', D)
        tgt_token = (
            self.token_embed(noisy_tokens)
            .reshape(bs, -1, self.d_model)
            .permute(1, 0, 2)
        )  # -> (T'*N*H'*W', B, D)

        if self.share_decoder:
            tgt = torch.cat([tgt_action, tgt_token], axis=0)  # (H1+H2, B, D)
            seq_len = tgt.shape[0]
            bs = tgt.shape[1]
            # Token positions are never marked as padding: a fully-masked
            # attention row would make softmax emit NaNs.
            is_pad_token_zero = torch.zeros_like(is_pad_token).bool()
            if is_pad is None:  # no action padding info supplied
                is_pad = torch.zeros(bs, self.chunksize).bool().to(src.device)
            tgt_key_padding_mask = torch.cat([is_pad, is_pad_token_zero], axis=1)
            tgt_mask = torch.zeros(seq_len, seq_len).to(tgt.device)
            # Attention layout between actions [0, chunksize) and tokens
            # [chunksize, seq_len). TODO: mask design still under review.
            if self.attention_type == "v1":
                tgt_mask[0 : self.chunksize, self.chunksize :] = float(
                    "-inf"
                )  # actions cannot attend to token predictions
                tgt_mask[self.chunksize :, self.predict_frame : self.chunksize] = float(
                    "-inf"
                )  # tokens ignore action steps beyond predict_frame
            elif self.attention_type == "v2":
                tgt_mask[0 : self.chunksize, self.chunksize :] = float(
                    "-inf"
                )  # actions cannot attend to token predictions
                tgt_mask[self.chunksize :, 0 : self.chunksize] = float(
                    "-inf"
                )  # tokens cannot attend to actions at all
            elif self.attention_type == "v3":
                tgt_mask[0 : self.chunksize, self.chunksize :] = float(
                    "-inf"
                )  # actions cannot attend to token predictions

            hs = self.decoder(
                tgt,
                memory,
                tgt_mask,
                memory_key_padding_mask=mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                pos=pos_embed,
                query_pos=query_embed_all,
            ).transpose(1, 2)[
                0
            ]  # (1, H, B, D) -> (B, H, D)
            hs_action = hs[:, : self.chunksize]
            hs_token = hs[:, self.chunksize :]
        else:
            hs_action = self.decoder(
                tgt_action,
                memory,
                memory_key_padding_mask=mask,
                tgt_key_padding_mask=is_pad,
                pos=pos_embed,
                query_pos=query_embed,
            ).transpose(1, 2)[0]
            # No tgt_key_padding_mask here: is_pad_token may be None on this path.
            hs_token = self.decoder_token(
                tgt_token,
                memory,
                memory_key_padding_mask=mask,
                pos=pos_embed,
                query_pos=query_embed_token,
            ).transpose(1, 2)[0]
        return hs_action, hs_token
+
+
class Transformer_diffusion_prediction_pixel(nn.Module):
    """Diffusion transformer that jointly denoises an action chunk and future
    *pixel* patches.

    Mirrors ``Transformer_diffusion_prediction_with_dual_visual_token`` but
    the predicted tokens are raw 3-channel pixel patches (channel dim is
    hard-coded to 3 in ``token_embed``).
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        num_queries=100,
        share_decoder=True,
        patch_size=5,
        diffusion_timestep_type="cat",
        attention_type="v1",
        predict_frame=16,
        predict_only_last=False,
    ):
        """Build encoder/decoder stacks plus the action/pixel-patch embedders.

        num_queries is the action chunk size (stored as ``self.chunksize``);
        attention_type ('v1'|'v2'|'v3') selects the target self-attention mask.
        """
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )
        if share_decoder == False:
            # Independent (deep-copied) decoder for the pixel-token stream.
            self.decoder_token = copy.deepcopy(self.decoder)

        self.share_decoder = share_decoder
        self.diffusion_timestep_type = diffusion_timestep_type
        # MLP lifting the sinusoidal diffusion-timestep embedding to d_model.
        self.time_embed = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        # NOTE(review): called before action_embed/token_embed exist, so those
        # keep PyTorch's default init rather than Xavier — confirm intended.
        self._reset_parameters()

        # Noisy-action embedder; input dim 14 is the hard-coded action dim.
        self.action_embed = nn.Sequential(
            nn.Linear(14, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        self.token_embed = nn.Sequential(
            nn.Linear(
                3 * patch_size * patch_size, d_model
            ),  # hard-coded: 3 channels * patch_size * patch_size per patch
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        self.d_model = d_model
        self.nhead = nhead
        self.chunksize = num_queries
        self.attention_type = attention_type
        self.predict_frame = predict_frame
        self.predict_only_last = predict_only_last

    def _reset_parameters(self):
        """Xavier-init every >1-dim parameter registered so far."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src,
        mask,
        query_embed,
        query_embed_token,
        pos_embed,
        latent_input=None,
        proprio_input=None,
        additional_pos_embed=None,
        noisy_actions=None,
        noisy_tokens=None,
        denoise_steps=None,
        denoise_step_pos_embed=None,
        is_pad=None,
        is_pad_token=None,
    ):
        """Denoise actions and pixel tokens; returns (hs_action, hs_token).

        Shapes (T' = predict_frame / temporal compression rate):
            src: (B, D, H, W*num_view*T); mask: None or key-padding mask
            query_embed: (chunksize, D); query_embed_token: (num_tokens, D)
            pos_embed: (1, D, H, W*num_view*T)
            latent_input: (B, D); proprio_input: (B, T', D)
            additional_pos_embed: (T'+1, D) for latent+proprio slots
            noisy_actions: (B, chunksize, 14)
            noisy_tokens: (B, T', N, D, H', W'); denoise_steps: (B,)
            is_pad: (B, chunksize); is_pad_token: token pad mask (assumed
                pre-flattened to (B, num_tokens) — TODO confirm)

        NOTE(review): ``bs`` is only bound in the 4-D ``src`` branch; a 3-D
        ``src`` would hit a NameError at the token reshape below — confirm
        the 3-D path is never used with token prediction.
        """
        if len(src.shape) == 4:  # has H and W
            bs, c, h, w = src.shape  # B D H W*num_view*T
            src = src.flatten(2).permute(
                2, 0, 1
            )  # H*W*num_view*T, bs, c — visual features from resnet
            # ?TODO check the shape of pos_embed
            pos_embed = (
                pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
            )  # H*W*num_view*T, bs, c

            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
            query_embed_token = query_embed_token.unsqueeze(1).repeat(
                1, bs, 1
            )  # TODO temporal-spatial position embedding for predicted frames
            query_embed_all = torch.cat(
                [query_embed, query_embed_token], axis=0
            )  # chunksize + num_tokens, bs, c
            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
                1, bs, 1
            )  # seq, bs, dim
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

            latent_input = latent_input.unsqueeze(1)  # B 1 D
            addition_input = torch.cat([latent_input, proprio_input], axis=1).permute(
                1, 0, 2
            )  # B T+1 D -> T+1 B D
            src = torch.cat([addition_input, src], axis=0)

            denoise_step_pos_embed = denoise_step_pos_embed.unsqueeze(1).repeat(
                1, bs, 1
            )  # 1 D -> 1 B D
        else:
            assert len(src.shape) == 3

        # Encoder over visual features + proprio + latent.
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)

        # Diffusion-timestep conditioning for the decoder memory.
        denoise_embed = self.time_embed(
            get_timestep_embedding(denoise_steps, self.d_model)
        ).unsqueeze(
            0
        )  # B D -> 1 B D
        if (
            self.diffusion_timestep_type == "cat"
        ):  # TODO add tokenizer visual token & timestep embedding
            memory = torch.cat([memory, denoise_embed], axis=0)
            pos_embed = torch.cat([pos_embed, denoise_step_pos_embed], axis=0)
        # elif self.diffusion_timestep_type == 'vis_cat':
        #     memory = torch.cat([memory, visual_token, denoise_embed], axis=0) # visual token map
        #     pos_embed = torch.cat([pos_embed, pos_embed_visual_token, denoise_step_pos_embed], axis=0) # pos_embed_visual_token
        else:
            # Every non-'cat' timestep type falls through to additive conditioning.
            memory = memory + denoise_embed

        tgt_action = self.action_embed(noisy_actions).permute(1, 0, 2)  # H B D
        noisy_tokens = noisy_tokens.permute(
            0, 1, 2, 4, 5, 3
        )  # B T' N D H' W' -> B T' N H' W' D
        tgt_token = (
            self.token_embed(noisy_tokens)
            .reshape(bs, -1, self.d_model)
            .permute(1, 0, 2)
        )  # B T' N H' W' D -> T'*N*H'*W' B D

        if self.share_decoder:
            tgt = torch.cat([tgt_action, tgt_token], axis=0)  # H1+H2 B D
            seq_len = tgt.shape[0]
            bs = tgt.shape[1]
            # Token positions are never flagged as padding: a fully-masked
            # attention row would make softmax emit NaNs.
            is_pad_token_zero = torch.zeros_like(is_pad_token).bool()  # HARD CODE
            if is_pad is None:
                is_pad = torch.zeros(bs, self.chunksize).bool().to(src.device)
            tgt_key_padding_mask = torch.cat(
                [is_pad, is_pad_token_zero], axis=1
            )  # HARD CODE avoid nan when all token is pad

            seq_len = tgt.shape[0]
            tgt_mask = torch.zeros(seq_len, seq_len).to(tgt.device)
            # Attention layout between actions [0, chunksize) and tokens
            # [chunksize, seq_len). TODO design mask
            if self.attention_type == "v1":
                tgt_mask[0 : self.chunksize, self.chunksize :] = float(
                    "-inf"
                )  # action prediction cannot attend to token prediction
                tgt_mask[self.chunksize :, self.predict_frame : self.chunksize] = float(
                    "-inf"
                )  # token prediction ignores action steps beyond predict_frame
            elif self.attention_type == "v2":
                tgt_mask[0 : self.chunksize, self.chunksize :] = float(
                    "-inf"
                )  # action prediction cannot attend to token prediction
                tgt_mask[self.chunksize :, 0 : self.chunksize] = float(
                    "-inf"
                )  # token prediction cannot attend to action prediction
            elif self.attention_type == "v3":
                tgt_mask[0 : self.chunksize, self.chunksize :] = float(
                    "-inf"
                )  # action prediction cannot attend to token prediction

            hs = self.decoder(
                tgt,
                memory,
                tgt_mask,
                memory_key_padding_mask=mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                pos=pos_embed,
                query_pos=query_embed_all,
            ).transpose(1, 2)[
                0
            ]  # (1, H, B, D) -> (B, H, D)
            hs_action = hs[:, : self.chunksize]  # B H D
            hs_token = hs[:, self.chunksize :]  # B H D
        else:
            hs_action = self.decoder(
                tgt_action,
                memory,
                memory_key_padding_mask=mask,
                tgt_key_padding_mask=is_pad,
                pos=pos_embed,
                query_pos=query_embed,
            ).transpose(1, 2)[0]
            hs_token = self.decoder_token(
                tgt_token,
                memory,
                memory_key_padding_mask=mask,  # no tgt_key_padding_mask: is_pad_token may be None here
                pos=pos_embed,
                query_pos=query_embed_token,
            ).transpose(1, 2)[0]
        # hs_action: (B, chunksize, D); hs_token: (B, T'*N*H'*W', D)
        return hs_action, hs_token
+
+
class Transformer_diffusion(nn.Module):
    """Diffusion transformer denoising an action chunk and a JPEG-feature chunk
    with one shared decoder.

    The target sequence is ``[actions | jpeg tokens]``; a block-diagonal
    self-attention mask keeps the two streams from attending to each other.
    The last ``num_jpeg`` target positions are the JPEG stream.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        jpeg_dim=400,
        num_jpeg=80,
    ):
        """jpeg_dim: per-token JPEG feature width; num_jpeg: JPEG stream length."""
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )
        # MLP lifting the sinusoidal diffusion-timestep embedding to d_model.
        self.time_embed = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        # Noisy-action embedder; input dim 14 is the hard-coded action dim.
        self.action_embed = nn.Sequential(
            nn.Linear(14, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        self.jpeg_embed = nn.Sequential(
            nn.Linear(jpeg_dim, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        # Here _reset_parameters runs after all submodules exist, so every
        # >1-dim parameter (embedders included) gets Xavier init.
        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead
        self.num_jpeg = num_jpeg

    def _reset_parameters(self):
        """Xavier-init every >1-dim parameter."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src,
        mask,
        query_embed,
        pos_embed,
        latent_input=None,
        proprio_input=None,
        additional_pos_embed=None,
        noisy_actions=None,
        noisy_jpegs=None,
        denoise_steps=None,
    ):
        """Denoise actions + JPEG features; returns hs of shape (1, B, H, D).

        src is either a (B, C, H, W) feature map (flattened here) or an
        already-flattened (S, B, C) sequence. The timestep embedding is added
        to the encoder memory (not concatenated).
        """
        # TODO flatten only when input has H and W
        if len(src.shape) == 4:  # has H and W
            # flatten NxCxHxW to HWxNxC
            bs, c, h, w = src.shape
            src = src.flatten(2).permute(2, 0, 1)
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
            # mask = mask.flatten(1)

            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
                1, bs, 1
            )  # seq, bs, dim
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

            # latent_input and proprio_input are stacked as two extra tokens
            # in front of the visual sequence.
            addition_input = torch.stack([latent_input, proprio_input], axis=0)
            src = torch.cat([addition_input, src], axis=0)
        else:
            assert len(src.shape) == 3
            # flatten NxHWxC to HWxN
        tgt_action = self.action_embed(noisy_actions).permute(1, 0, 2)  # (H1, B, D)
        tgt_jpeg = self.jpeg_embed(noisy_jpegs).permute(1, 0, 2)  # (H2, B, D)
        tgt = torch.cat([tgt_action, tgt_jpeg], axis=0)  # (H1+H2, B, D)

        denoise_embed = get_timestep_embedding(denoise_steps, self.d_model)  # B -> B D
        denoise_embed = self.time_embed(denoise_embed)
        memory = self.encoder(
            src, src_key_padding_mask=mask, pos=pos_embed
        )  # TODO make mask to focus on self-time-step
        # Additive timestep conditioning on the whole memory.
        memory = memory + denoise_embed.unsqueeze(0)
        seq_len = tgt.shape[0]
        tgt_mask = torch.zeros(seq_len, seq_len).to(tgt.device)
        # Block-diagonal mask: actions and JPEG tokens cannot see each other.
        tgt_mask[: -self.num_jpeg, -self.num_jpeg :] = float(
            "-inf"
        )
        tgt_mask[-self.num_jpeg :, : -self.num_jpeg] = float("-inf")
        hs = self.decoder(
            tgt,
            memory,
            tgt_mask,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            query_pos=query_embed,
        )  # (1, H, B, D)
        hs = hs.transpose(1, 2)  # (1, B, H, D)
        return hs
+
+
class Transformer_diffusion_seperate(nn.Module):
    """Diffusion transformer with *separate* decoders for the action stream and
    the JPEG-feature stream (weight-shared when ``share_decoder=True``).

    The diffusion-timestep embedding conditions only the action decoder's
    memory; the JPEG decoder reads the raw encoder memory.

    NOTE: the class name's spelling ("seperate") and the method name
    ``inference_only_actin`` are preserved for caller compatibility.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        share_decoder=False,
    ):
        """Build the encoder plus the action and JPEG decoders."""
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder_action = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )
        # Either alias (shared weights) or deep-copy an independent JPEG decoder.
        if share_decoder:
            self.decoder_jpeg = self.decoder_action
        else:
            self.decoder_jpeg = copy.deepcopy(self.decoder_action)

        # MLP lifting the sinusoidal diffusion-timestep embedding to d_model.
        self.time_embed = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        # Noisy-action embedder; input dim 14 is the hard-coded action dim.
        self.action_embed = nn.Sequential(
            nn.Linear(14, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-init every >1-dim parameter."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src,
        mask,
        query_embed_action=None,
        query_embed_jpeg=None,
        pos_embed=None,
        latent_input=None,
        proprio_input=None,
        additional_pos_embed=None,
        noisy_actions=None,
        denoise_steps=None,
    ):
        """Denoise actions and decode JPEG features from zero-initialized queries.

        Returns:
            (hs_action, hs_jpeg): each (B, H, D) — the last decoder layer's
            output with the leading layer dim indexed away.
        """
        # TODO flatten only when input has H and W
        if len(src.shape) == 4:  # has H and W
            # flatten NxCxHxW to HWxNxC
            bs, c, h, w = src.shape
            src = src.flatten(2).permute(2, 0, 1)
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
            query_embed_action = query_embed_action.unsqueeze(1).repeat(1, bs, 1)
            query_embed_jpeg = query_embed_jpeg.unsqueeze(1).repeat(1, bs, 1)

            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
                1, bs, 1
            )  # seq, bs, dim
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

            # latent + proprio are prepended as two extra sequence tokens.
            addition_input = torch.stack([latent_input, proprio_input], axis=0)
            src = torch.cat([addition_input, src], axis=0)
        else:
            assert len(src.shape) == 3

        memory = self.encoder(
            src, src_key_padding_mask=mask, pos=pos_embed
        )  # TODO make mask to focus on self-time-step
        denoise_embed = get_timestep_embedding(denoise_steps, self.d_model)  # B -> B D
        denoise_embed = self.time_embed(denoise_embed)
        # Timestep conditioning applies only to the action decoder's memory.
        memory_action = memory + denoise_embed.unsqueeze(0)

        tgt_action = self.action_embed(noisy_actions).permute(1, 0, 2)  # (H, B, D)
        hs_action = self.decoder_action(
            tgt_action,
            memory_action,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            query_pos=query_embed_action,
        )
        hs_action = hs_action.transpose(1, 2)  # (1, B, H, D)
        # BUG FIX: was ``torch.zeros_like(query_embed_jpeg).cuda()`` — the
        # explicit .cuda() crashed CPU runs and ignored the tensor's actual
        # device; zeros_like already allocates on query_embed_jpeg's device.
        tgt_jpeg = torch.zeros_like(query_embed_jpeg)
        hs_jpeg = self.decoder_jpeg(
            tgt_jpeg,
            memory,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            query_pos=query_embed_jpeg,
        )
        hs_jpeg = hs_jpeg.transpose(1, 2)  # (1, B, H, D)

        return hs_action[0], hs_jpeg[0]

    def inference_only_actin(
        self,
        src,
        mask,
        query_embed_action=None,
        pos_embed=None,
        latent_input=None,
        proprio_input=None,
        additional_pos_embed=None,
        noisy_actions=None,
        denoise_steps=None,
    ):
        """Action-only denoising path: skips the JPEG decoder entirely.

        NOTE: unlike forward(), the returned hs_action keeps the decoder's raw
        output layout (leading layer dim, untransposed); callers handle this
        asymmetry. Returns (hs_action, None).
        """
        if len(src.shape) == 4:  # has H and W
            # flatten NxCxHxW to HWxNxC
            bs, c, h, w = src.shape
            src = src.flatten(2).permute(2, 0, 1)
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
            query_embed_action = query_embed_action.unsqueeze(1).repeat(1, bs, 1)
            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
                1, bs, 1
            )  # seq, bs, dim
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

            addition_input = torch.stack([latent_input, proprio_input], axis=0)
            src = torch.cat([addition_input, src], axis=0)
        else:
            assert len(src.shape) == 3

        memory = self.encoder(
            src, src_key_padding_mask=mask, pos=pos_embed
        )
        denoise_embed = get_timestep_embedding(denoise_steps, self.d_model)  # B -> B D
        denoise_embed = self.time_embed(denoise_embed)
        memory_action = memory + denoise_embed.unsqueeze(0)

        tgt_action = self.action_embed(noisy_actions).permute(1, 0, 2)  # (H, B, D)

        hs_action = self.decoder_action(
            tgt_action,
            memory_action,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            query_pos=query_embed_action,
        )
        return hs_action, None
+
+
class TransformerEncoder(nn.Module):
    """A stack of identical encoder layers with an optional final norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        # Each entry is an independent deep copy of the prototype layer.
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Run ``src`` through every layer in order.

        The attention mask, key-padding mask, and positional encoding are
        shared across all layers.
        """
        out = src
        for layer in self.layers:
            out = layer(
                out,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )
        return out if self.norm is None else self.norm(out)
+
class TransformerEncoder_AdLN(nn.Module):
    """Encoder stack whose layers receive an extra conditioning input
    (AdaLN-style), returning the result with a leading singleton dim."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        condition,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Apply every layer with the shared ``condition`` tensor.

        Returns the final activations with shape (1, T, B, D) — the leading
        singleton matches the decoder's stacked-output convention.
        """
        x = src
        for layer in self.layers:
            x = layer(
                x,
                condition,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )
        if self.norm is not None:
            x = self.norm(x)
        return x.unsqueeze(0)
+
+
class TransformerDecoder(nn.Module):
    """Decoder layer stack; optionally returns the normed output of every layer."""

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Run ``tgt`` through every layer, cross-attending to ``memory``.

        Returns a (num_layers, ...) stack of per-layer normed outputs when
        ``return_intermediate`` is set; otherwise the final output with a
        leading singleton dimension.
        """
        x = tgt
        per_layer = []
        for layer in self.layers:
            x = layer(
                x,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                per_layer.append(self.norm(x))

        if self.norm is not None:
            x = self.norm(x)
            if self.return_intermediate:
                # The last stacked entry must be the (single) final norm.
                per_layer[-1] = x

        if self.return_intermediate:
            return torch.stack(per_layer)
        return x.unsqueeze(0)
+
class TransformerDecoder_alter(nn.Module):
    """Decoder stack that alternates its cross-attention source between two
    memories: even-indexed layers attend to ``memory``/``pos``, odd-indexed
    layers to ``memory_alter``/``pos_alter``."""

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        memory_alter,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        pos_alter: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Run ``tgt`` through the layers, alternating the memory per layer.

        Returns a stacked (num_layers, ...) tensor when ``return_intermediate``
        is set; otherwise the final output with a leading singleton dim.
        """
        output = tgt
        memory_list = [memory, memory_alter]
        pos_list = [pos, pos_alter]
        intermediate = []
        # BUG FIX: the original set ``num_layer = 0`` and computed
        # ``index = num_layer % 2`` but never incremented it, so every layer
        # attended to ``memory`` and ``memory_alter`` was silently ignored.
        for layer_idx, layer in enumerate(self.layers):
            index = layer_idx % 2
            output = layer(
                output,
                memory_list[index],
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos_list[index],
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)
+
+
class TransformerEncoderLayer(nn.Module):
    """Standard transformer encoder layer (self-attention + feed-forward),
    supporting both post-norm and pre-norm residual arrangements."""

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add a positional encoding to ``tensor`` when one is provided."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Post-norm variant: sublayer -> residual add -> LayerNorm."""
        qk = self.with_pos_embed(src, pos)  # positions go into Q/K only
        attn_out = self.self_attn(
            qk, qk, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        x = self.norm1(src + self.dropout1(attn_out))
        ffn_out = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.norm2(x + self.dropout2(ffn_out))

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Pre-norm variant: LayerNorm -> sublayer -> residual add."""
        normed = self.norm1(src)
        qk = self.with_pos_embed(normed, pos)
        attn_out = self.self_attn(
            qk, qk, value=normed, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        x = src + self.dropout1(attn_out)
        normed2 = self.norm2(x)
        ffn_out = self.linear2(self.dropout(self.activation(self.linear1(normed2))))
        return x + self.dropout2(ffn_out)

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Dispatch to the pre-norm or post-norm path per ``normalize_before``."""
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(src, src_mask, src_key_padding_mask, pos)
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)
+
+class TransformerEncoderLayer_AdLN(nn.Module):
+    """Pre-norm transformer encoder layer with adaptive LayerNorm (adaLN).
+
+    A conditioning vector (B, D) produces per-sublayer shift/scale/gate
+    parameters (6*D total) that modulate the normalized activations, as in
+    DiT-style adaLN conditioning.  Only the pre-norm path is implemented;
+    ``normalize_before`` is stored but ``forward`` always uses it.
+    """
+
+    def __init__(
+        self,
+        d_model,
+        nhead,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=True,
+    ):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of the feed-forward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(d_model, 6 * d_model, bias=True)
+        )  # necessary for adaLN: maps condition -> 6 modulation vectors
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        # Additive positional embedding; identity when pos is None.
+        return tensor if pos is None else tensor + pos
+
+    def forward_pre(
+        self,
+        src,  # B T D
+        condition,  # B D
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
+        # Split the 6*D conditioning output into shift/scale/gate for the
+        # attention sublayer and for the MLP sublayer.  Each is (B, D).
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(condition).chunk(6, dim=1)  # B D
+
+        src2 = self.norm1(src)
+        src2 = modulate(src2, shift_msa, scale_msa)
+        q = k = self.with_pos_embed(src2, pos)
+        src2 = self.self_attn(
+            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+        )[0]
+        # Gated residual: the condition can scale the sublayer contribution.
+        src = src + self.dropout1(gate_msa.unsqueeze(0) * src2)
+
+        src2 = self.norm2(src)
+        src2 = modulate(src2, shift_mlp, scale_mlp)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+        src = src + self.dropout2(gate_mlp.unsqueeze(0) * src2)
+        return src
+
+    def forward(
+        self,
+        src,
+        condition,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
+        # Only the pre-norm variant exists for adaLN.
+        return self.forward_pre(src, condition, src_mask, src_key_padding_mask, pos)
+
+
+
+class TransformerDecoderLayer(nn.Module):
+    """DETR-style transformer decoder layer.
+
+    Self-attention over the target queries, cross-attention from queries to
+    the encoder memory, then a position-wise feed-forward network.  Residual
+    connections and LayerNorm are applied post-sublayer by default, or
+    pre-sublayer when ``normalize_before=True``.
+    """
+
+    def __init__(
+        self,
+        d_model,
+        nhead,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+    ):
+        super().__init__()
+        # Self-attention among the target queries.
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Cross-attention: queries attend to the encoder memory.
+        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of the feed-forward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        # Additive positional embedding; identity when pos is None.
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
+        # Post-norm path: sublayer -> dropout -> residual add -> LayerNorm.
+        q = k = self.with_pos_embed(tgt, query_pos)
+        tgt2 = self.self_attn(
+            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+        )[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+        # Cross-attention: query_pos on the queries, pos on the memory keys.
+        tgt2 = self.multihead_attn(
+            query=self.with_pos_embed(tgt, query_pos),
+            key=self.with_pos_embed(memory, pos),
+            value=memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = tgt + self.dropout3(tgt2)
+        tgt = self.norm3(tgt)
+        return tgt
+
+    def forward_pre(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
+        # Pre-norm path: LayerNorm -> sublayer -> dropout -> residual add.
+        tgt2 = self.norm1(tgt)
+        q = k = self.with_pos_embed(tgt2, query_pos)
+        tgt2 = self.self_attn(
+            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+        )[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt2 = self.norm2(tgt)
+        tgt2 = self.multihead_attn(
+            query=self.with_pos_embed(tgt2, query_pos),
+            key=self.with_pos_embed(memory, pos),
+            value=memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt2 = self.norm3(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout3(tgt2)
+        return tgt
+
+    def forward(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
+        # Dispatch to the pre-norm or post-norm variant.
+        if self.normalize_before:
+            return self.forward_pre(
+                tgt,
+                memory,
+                tgt_mask,
+                memory_mask,
+                tgt_key_padding_mask,
+                memory_key_padding_mask,
+                pos,
+                query_pos,
+            )
+        return self.forward_post(
+            tgt,
+            memory,
+            tgt_mask,
+            memory_mask,
+            tgt_key_padding_mask,
+            memory_key_padding_mask,
+            pos,
+            query_pos,
+        )
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def build_transformer(args):
+ return Transformer(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ nhead=args.nheads,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.enc_layers,
+ num_decoder_layers=args.dec_layers,
+ normalize_before=args.pre_norm,
+ return_intermediate_dec=True,
+ )
+
+
+def build_transformer_denoise(args):
+    """Build the denoising transformer selected by ``args.condition_type``.
+
+    "adaLN" conditioning uses Transformer_Denoise_AdLN; anything else falls
+    back to Transformer_Denoise.
+    NOTE(review): the adaLN branch does not forward ``args.causal_mask`` —
+    confirm Transformer_Denoise_AdLN has no such option.
+    """
+    print(f"Using {args.condition_type} for condition")
+    print(f'enc_layers: {args.enc_layers}, dec_layers: {args.dec_layers}')
+    if args.condition_type == "adaLN":
+        return Transformer_Denoise_AdLN(
+            d_model=args.hidden_dim,
+            dropout=args.dropout,
+            nhead=args.nheads,
+            dim_feedforward=args.dim_feedforward,
+            num_encoder_layers=args.enc_layers,
+            num_decoder_layers=args.dec_layers,
+            normalize_before=args.pre_norm,
+            return_intermediate_dec=True,
+        )
+    return Transformer_Denoise(
+        d_model=args.hidden_dim,
+        dropout=args.dropout,
+        nhead=args.nheads,
+        dim_feedforward=args.dim_feedforward,
+        num_encoder_layers=args.enc_layers,
+        num_decoder_layers=args.dec_layers,
+        normalize_before=args.pre_norm,
+        return_intermediate_dec=True,
+        causal_mask=args.causal_mask,
+    )
+
+def build_transformer_denoise_tactile(args):
+ return Transformer_Denoise_Tactile(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ nhead=args.nheads,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.enc_layers,
+ num_decoder_layers=args.dec_layers,
+ normalize_before=args.pre_norm,
+ return_intermediate_dec=True,
+ causal_mask=args.causal_mask,
+ )
+
+
+def build_transformer_diffusion_prediction(args):
+ return Transformer_diffusion_prediction(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ nhead=args.nheads,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.enc_layers,
+ num_decoder_layers=args.dec_layers,
+ normalize_before=args.pre_norm,
+ return_intermediate_dec=True,
+ num_queries=args.num_queries,
+ share_decoder=args.share_decoder,
+ patch_size=args.patch_size,
+ diffusion_timestep_type=args.diffusion_timestep_type,
+ attention_type=args.attention_type,
+ predict_frame=args.predict_frame,
+ predict_only_last=args.predict_only_last,
+ token_dim=args.token_dim,
+ )
+
+
+def build_transformer_diffusion_prediction_with_dual_visual_token(args):
+ return Transformer_diffusion_prediction_with_dual_visual_token(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ nhead=args.nheads,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.enc_layers,
+ num_decoder_layers=args.dec_layers,
+ normalize_before=args.pre_norm,
+ return_intermediate_dec=True,
+ num_queries=args.num_queries,
+ share_decoder=args.share_decoder,
+ patch_size=args.patch_size,
+ diffusion_timestep_type=args.diffusion_timestep_type,
+ attention_type=args.attention_type,
+ predict_frame=args.predict_frame,
+ predict_only_last=args.predict_only_last,
+ )
+
+
+def build_transformer_diffusion(args):
+    """Construct the JPEG-token diffusion transformer from parsed ``args``."""
+    return Transformer_diffusion(
+        d_model=args.hidden_dim,
+        dropout=args.dropout,
+        nhead=args.nheads,
+        dim_feedforward=args.dim_feedforward,
+        num_encoder_layers=args.enc_layers,
+        num_decoder_layers=args.dec_layers,
+        normalize_before=args.pre_norm,
+        jpeg_dim=args.jpeg_dim,  # hard-coded coupling to the JPEG tokenizer
+        num_jpeg=args.jpeg_token_num * args.predict_frame,  # tokens per frame x frames
+        return_intermediate_dec=True,
+    )
+
+
+def build_transformer_diffusion_pixel_prediction(args):
+ return Transformer_diffusion_prediction_pixel(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ nhead=args.nheads,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.enc_layers,
+ num_decoder_layers=args.dec_layers,
+ normalize_before=args.pre_norm,
+ return_intermediate_dec=True,
+ num_queries=args.num_queries,
+ share_decoder=args.share_decoder,
+ patch_size=args.patch_size,
+ diffusion_timestep_type=args.diffusion_timestep_type,
+ attention_type=args.attention_type,
+ predict_frame=args.predict_frame,
+ predict_only_last=args.predict_only_last,
+ )
+
+
+def build_transformer_diffusion_seperate(args):
+ return Transformer_diffusion_seperate(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ nhead=args.nheads,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.enc_layers,
+ num_decoder_layers=args.dec_layers,
+ normalize_before=args.pre_norm,
+ share_decoder=args.share_decoder,
+ return_intermediate_dec=True,
+ )
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
+
+
+def get_timestep_embedding(
+    timesteps: torch.Tensor,
+    embedding_dim: int,
+    flip_sin_to_cos=False,
+    downscale_freq_shift=1,
+    scale=1,
+    max_period=10000,
+):
+    """
+    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param embedding_dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+    half_dim = embedding_dim // 2
+    # Log-spaced frequency exponents from 0 down to -log(max_period).
+    exponent = -math.log(max_period) * torch.arange(
+        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+    )
+    exponent = exponent / (half_dim - downscale_freq_shift)
+
+    emb = torch.exp(exponent)
+    # Outer product: one row of phases per timestep.
+    emb = timesteps[:, None].float() * emb[None, :]
+
+    # scale embeddings
+    emb = scale * emb
+
+    # concat sine and cosine embeddings
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+    # flip sine and cosine embeddings (cos first instead of sin first)
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+    # zero pad the last channel when embedding_dim is odd
+    if embedding_dim % 2 == 1:
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+
+    return emb
+
+class Tactile_ConvBlock(nn.Module):
+    """Conv feature extractor for one tactile image: (B, in, H, W) -> (B, out, 4, 4)."""
+    def __init__(self, in_channels, out_channels):
+        super(Tactile_ConvBlock, self).__init__()
+        # Aggressive strided convs shrink the large tactile image quickly.
+        self.conv1 = nn.Conv2d(in_channels, out_channels//2, kernel_size=3, stride=3)
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.conv2 = nn.Conv2d(out_channels//2, out_channels, kernel_size=5, stride=5)
+        # Fixed 4x4 spatial output regardless of input resolution.
+        self.pool2 = nn.AdaptiveAvgPool2d((4, 4))
+        self.activation = nn.SiLU()
+        self.bn1 = nn.BatchNorm2d(out_channels//2)  # BatchNorm (not LayerNorm) for conv1 output
+        self.bn2 = nn.BatchNorm2d(out_channels)  # BatchNorm for conv2 output
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.activation(x)
+        x = self.pool(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        # NOTE(review): no activation after bn2 — confirm this is intentional.
+        x = self.pool2(x)
+
+        return x
+
+class Tactile_Encoder(nn.Module):
+    """Encode four tactile sensor images (ll/lr/rl/rr) into a (B, D, 4, H, W) map."""
+    def __init__(self,tactile_dim,dropout):
+        super().__init__()
+        # One independent conv branch per tactile pad.
+        self.conv_ll = Tactile_ConvBlock(3, tactile_dim)
+        self.conv_lr = Tactile_ConvBlock(3, tactile_dim)
+        self.conv_rl = Tactile_ConvBlock(3, tactile_dim)
+        self.conv_rr = Tactile_ConvBlock(3, tactile_dim)
+        self.feature_extractor = nn.ModuleList([
+            self.conv_ll,
+            self.conv_lr,
+            self.conv_rl,
+            self.conv_rr
+        ])
+        # 1x1 projection on each branch's features.  Attribute keeps the
+        # historical "tacile" spelling — renaming would break saved state dicts.
+        self.input_proj_tacile = nn.Conv2d(
+            tactile_dim,
+            tactile_dim,
+            kernel_size=1,
+        )
+
+        # NOTE(review): an attention-pooling variant (learned query tokens +
+        # MultiheadAttention) was removed; ``dropout`` is currently unused.
+        self.tactile_dim = tactile_dim
+
+    def forward(self,tactile_data):
+        # tactile_data: B x T x 4 x C x H x W; only the first history step is used.
+        tactile_data = tactile_data[:,0]
+        B = tactile_data.shape[0]
+        tactile_feature_list = []
+        for i in range(tactile_data.shape[1]):
+            # Each pad goes through its own conv branch, then the shared 1x1 proj.
+            tactile_feature = self.feature_extractor[i](tactile_data[:,i])
+            tactile_feature = self.input_proj_tacile(tactile_feature)
+            tactile_feature_list.append(tactile_feature)
+        tactile_features_raw = torch.stack(tactile_feature_list, dim=2)  # B C 4 H W
+        return tactile_features_raw  # b d 4 h w
+
+if __name__ == "__main__":
+    # Smoke test for Tactile_Encoder (requires a CUDA device).
+    tactile_data = torch.randn(2, 1, 4, 3, 960, 960).cuda()
+    tactile_encoder = Tactile_Encoder(512, 0.1).cuda()
+    tactile_feature = tactile_encoder(tactile_data)
+    print(tactile_feature.shape)  # B D 4 4 4 after AdaptiveAvgPool2d((4, 4))
+    total_params = sum(p.numel() for p in tactile_encoder.parameters() if p.requires_grad)
+    total_params_in_million = total_params / 1e6
+    print(f"Total trainable parameters: {total_params_in_million:.2f} M")
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/models/vision_transformer.py b/ACT_DP_multitask/detr/models/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1b718ea1bb1e616f29437db33d9140273c5f806
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/vision_transformer.py
@@ -0,0 +1,397 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Mostly copy-paste from timm library.
+https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+"""
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+import numpy as np
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+def get_2d_sincos_pos_embed_v2(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size[0], dtype=np.float32)
+ grid_w = np.arange(grid_size[1], dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+# from utils import trunc_normal_
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    # type: (Tensor, float, float, float, float) -> Tensor
+    """Public wrapper: in-place truncated-normal init of ``tensor`` within [a, b]."""
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+    """Two-layer MLP with activation and dropout, as used inside ViT blocks."""
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        # Default to a square MLP when hidden/out sizes are not given.
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        # The same dropout module is applied after each linear layer.
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    """Multi-head self-attention that also returns the attention weights."""
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        # Default scale is 1/sqrt(head_dim) unless explicitly overridden.
+        self.scale = qk_scale or head_dim ** -0.5
+
+        # Single projection producing q, k and v in one matmul.
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        # (B, N, 3C) -> (3, B, heads, N, head_dim)
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        # Unlike timm, the post-softmax attention map is returned as well.
+        return x, attn
+
+
+class Block(nn.Module):
+    """Standard ViT block: pre-norm attention + pre-norm MLP, each wrapped
+    with a residual connection and optional DropPath."""
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x, return_attention=False):
+        y, attn = self.attn(self.norm1(x))
+        # When probing, return the attention map instead of the features.
+        if return_attention:
+            return attn
+        x = x + self.drop_path(y)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+
+    Splits an image into non-overlapping patches with a strided conv and
+    flattens them into a (B, N, embed_dim) token sequence.
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        num_patches = (img_size // patch_size) * (img_size // patch_size)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        # kernel == stride: one conv output per patch.
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        # (B, C, H, W) -> (B, embed_dim, H/p, W/p) -> (B, N, embed_dim)
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
+
+
+class VisionTransformer(nn.Module):
+    """ Vision Transformer
+
+    DINO-style ViT backbone: patch embedding + [CLS] token + learned
+    positional embedding, a stack of pre-norm Blocks, and a final LayerNorm.
+    ``forward`` returns (cls_features, patch_features).
+    """
+    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                 drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
+        super().__init__()
+        self.num_features = self.embed_dim = embed_dim
+
+        self.patch_embed = PatchEmbed(
+            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        # +1 slot for the [CLS] token's positional embedding.
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+            for i in range(depth)])
+        self.norm = norm_layer(embed_dim)
+
+        # Classifier head (identity when num_classes == 0, the DINO default)
+        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        # Truncated-normal linear weights, zero biases, unit LayerNorm gains.
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def interpolate_pos_encoding(self, x, w, h):
+        """Bicubically resample the patch positional embeddings when the
+        input resolution differs from the training resolution."""
+        npatch = x.shape[1] - 1
+        N = self.pos_embed.shape[1] - 1
+        if npatch == N and w == h:
+            return self.pos_embed
+        class_pos_embed = self.pos_embed[:, 0]
+        patch_pos_embed = self.pos_embed[:, 1:]
+        dim = x.shape[-1]
+        w0 = w // self.patch_embed.patch_size
+        h0 = h // self.patch_embed.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        w0, h0 = w0 + 0.1, h0 + 0.1
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            mode='bicubic',
+        )
+        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+    def prepare_tokens(self, x):
+        B, nc, w, h = x.shape
+        x = self.patch_embed(x)  # patch linear embedding
+
+        # add the [CLS] token to the embed patch tokens
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        # add positional encoding to each token
+        x = x + self.interpolate_pos_encoding(x, w, h)
+
+        return self.pos_drop(x)
+
+    def forward(self, x):
+        x = self.prepare_tokens(x)
+        for blk in self.blocks:
+            x = blk(x)
+        x = self.norm(x)
+        # Returns ([CLS] features, patch-token features).
+        return x[:, 0], x[:, 1:]
+
+    def get_last_selfattention(self, x):
+        x = self.prepare_tokens(x)
+        for i, blk in enumerate(self.blocks):
+            if i < len(self.blocks) - 1:
+                x = blk(x)
+            else:
+                # return attention of the last block
+                return blk(x, return_attention=True)
+
+    def get_intermediate_layers(self, x, n=1):
+        x = self.prepare_tokens(x)
+        # we return the output tokens from the `n` last blocks
+        output = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if len(self.blocks) - i <= n:
+                output.append(self.norm(x))
+        return output
+
+
+def vit_tiny(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ return model
+
+
+def vit_small(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ return model
+
+
+def vit_base(patch_size=16, **kwargs):
+ model = VisionTransformer(
+ patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ return model
+
+
+class DINOHead(nn.Module):
+    """DINO projection head: MLP -> L2-normalize -> weight-normalized linear."""
+    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
+        super().__init__()
+        nlayers = max(nlayers, 1)
+        if nlayers == 1:
+            self.mlp = nn.Linear(in_dim, bottleneck_dim)
+        else:
+            # Linear(+BN)+GELU stack ending in a bottleneck projection.
+            layers = [nn.Linear(in_dim, hidden_dim)]
+            if use_bn:
+                layers.append(nn.BatchNorm1d(hidden_dim))
+            layers.append(nn.GELU())
+            for _ in range(nlayers - 2):
+                layers.append(nn.Linear(hidden_dim, hidden_dim))
+                if use_bn:
+                    layers.append(nn.BatchNorm1d(hidden_dim))
+                layers.append(nn.GELU())
+            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
+            self.mlp = nn.Sequential(*layers)
+        self.apply(self._init_weights)
+        # Weight-normalized output layer; its magnitude (weight_g) is fixed at 1
+        # and frozen when norm_last_layer is True.
+        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+        self.last_layer.weight_g.data.fill_(1)
+        if norm_last_layer:
+            self.last_layer.weight_g.requires_grad = False
+
+    def _init_weights(self, m):
+        # Truncated-normal weights and zero biases for all linear layers.
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.mlp(x)
+        # Project onto the unit hypersphere before the prototype layer.
+        x = nn.functional.normalize(x, dim=-1, p=2)
+        x = self.last_layer(x)
+        return x
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/setup.py b/ACT_DP_multitask/detr/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..55d18c0db74e27a687b6d6e0fb236e1cbb801f20
--- /dev/null
+++ b/ACT_DP_multitask/detr/setup.py
@@ -0,0 +1,10 @@
+from distutils.core import setup
+from setuptools import find_packages
+
+setup(
+ name='detr',
+ version='0.0.0',
+ packages=find_packages(),
+ license='MIT License',
+ long_description=open('README.md').read(),
+)
\ No newline at end of file
diff --git a/ACT_DP_multitask/detr/util/__init__.py b/ACT_DP_multitask/detr/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..168f9979a4623806934b0ff1102ac166704e7dec
--- /dev/null
+++ b/ACT_DP_multitask/detr/util/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
diff --git a/ACT_DP_multitask/detr/util/__pycache__/__init__.cpython-310.pyc b/ACT_DP_multitask/detr/util/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2986ac876cdf748c0006ae0c4fde5b5f75d5b06b
Binary files /dev/null and b/ACT_DP_multitask/detr/util/__pycache__/__init__.cpython-310.pyc differ
diff --git a/ACT_DP_multitask/detr/util/__pycache__/__init__.cpython-37.pyc b/ACT_DP_multitask/detr/util/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26e2fbb6146ceee5b9f484f1cc99f7386957e270
Binary files /dev/null and b/ACT_DP_multitask/detr/util/__pycache__/__init__.cpython-37.pyc differ
diff --git a/ACT_DP_multitask/detr/util/__pycache__/misc.cpython-310.pyc b/ACT_DP_multitask/detr/util/__pycache__/misc.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8cf42af5f3bf9b6fd1b97f8425bd254978913183
Binary files /dev/null and b/ACT_DP_multitask/detr/util/__pycache__/misc.cpython-310.pyc differ
diff --git a/ACT_DP_multitask/detr/util/__pycache__/misc.cpython-37.pyc b/ACT_DP_multitask/detr/util/__pycache__/misc.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68dad05c8123660ccb4dee85cd343ba35ae4199a
Binary files /dev/null and b/ACT_DP_multitask/detr/util/__pycache__/misc.cpython-37.pyc differ
diff --git a/ACT_DP_multitask/detr/util/box_ops.py b/ACT_DP_multitask/detr/util/box_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c088e5bacc88ff7217fc971f5db889f5bb45b39
--- /dev/null
+++ b/ACT_DP_multitask/detr/util/box_ops.py
@@ -0,0 +1,88 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Utilities for bounding box manipulation and GIoU.
+"""
+import torch
+from torchvision.ops.boxes import box_area
+
+
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to corner (x0, y0, x1, y1) format."""
    cx, cy, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=-1)
+
+
def box_xyxy_to_cxcywh(x):
    """Convert boxes from corner (x0, y0, x1, y1) to (cx, cy, w, h) format."""
    x0, y0, x1, y1 = x.unbind(-1)
    parts = ((x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0)
    return torch.stack(parts, dim=-1)
+
+
+# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    """Pairwise IoU between two sets of xyxy boxes.

    Returns (iou, union), each of shape [N, M] — the union is exposed so that
    generalized_box_iou can reuse it.
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])      # [N,M,2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Negative extents mean no overlap; clamp them away before multiplying.
    extent = (bottom_right - top_left).clamp(min=0)  # [N,M,2]
    intersection = extent[..., 0] * extent[..., 1]   # [N,M]

    union = area1[:, None] + area2 - intersection
    return intersection / union, union
+
+
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # Degenerate boxes (x1 < x0 or y1 < y0) produce inf/nan downstream,
    # so reject them up front.
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    # Smallest axis-aligned box enclosing each pair.
    hull_tl = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    hull_br = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    hull_extent = (hull_br - hull_tl).clamp(min=0)  # [N,M,2]
    hull_area = hull_extent[..., 0] * hull_extent[..., 1]

    return iou - (hull_area - union) / hull_area
+
+
def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks

    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.

    Returns a [N, 4] tensors, with the boxes in xyxy format
    """
    if masks.numel() == 0:
        # No masks: empty box tensor on the same device.
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    # Build the coordinate grids on the masks' device; previously they were
    # always allocated on CPU, which crashed for CUDA masks.
    y = torch.arange(0, h, dtype=torch.float, device=masks.device)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device)
    y, x = torch.meshgrid(y, x)

    # Max coordinate where the mask is set; min via masking non-mask pixels
    # with a huge sentinel so they never win the min.
    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)
diff --git a/ACT_DP_multitask/detr/util/misc.py b/ACT_DP_multitask/detr/util/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfa9fb5b8f9e44c98e42aa9bb7275f6fa151472d
--- /dev/null
+++ b/ACT_DP_multitask/detr/util/misc.py
@@ -0,0 +1,468 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+import os
+import subprocess
+import time
+from collections import defaultdict, deque
+import datetime
+import pickle
+from packaging import version
+from typing import Optional, List
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+# needed due to empty tensor bug in pytorch and torchvision 0.5
+import torchvision
+if version.parse(torchvision.__version__) < version.parse('0.7'):
+ from torchvision.ops import _new_empty_tensor
+ from torchvision.ops.misc import _output_size
+
+
class SmoothedValue(object):
    """Track a stream of scalar values and expose smoothed statistics over a
    sliding window, plus the global average of everything seen so far."""

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        # The window stores the raw value; the running totals weight it by n.
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(packed)
        totals = packed.tolist()
        self.count = int(totals[0])
        self.total = totals[1]

    @property
    def median(self):
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
+
+
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        # Single process: nothing to gather.
        return [data]

    # serialized to a Tensor
    # NOTE(review): torch.ByteStorage.from_buffer is deprecated in recent
    # PyTorch releases — confirm against the pinned torch version.
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        # Trim each rank's padding back off before unpickling.
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
+
+
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        # Nothing to reduce in a single-process run.
        return input_dict
    with torch.no_grad():
        # Sort keys so every rank stacks the values in the same order.
        keys = sorted(input_dict.keys())
        stacked = torch.stack([input_dict[k] for k in keys], dim=0)
        dist.all_reduce(stacked)
        if average:
            stacked /= world_size
        return dict(zip(keys, stacked))
+
+
class MetricLogger(object):
    """Collects named SmoothedValue meters and prints periodic progress lines
    (metrics, iteration/data timing, ETA, and peak CUDA memory)."""

    def __init__(self, delimiter="\t"):
        # Unknown metric names lazily get a fresh SmoothedValue on first update.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        # Record one scalar per keyword; tensors are converted via .item().
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Allow attribute-style access to meters (e.g. logger.loss); only
        # called when normal lookup fails, so real attributes win.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        # One "name: meter" entry per tracked metric, joined by the delimiter.
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        # Reduce each meter's count/total across ranks (deques stay local).
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from `iterable`, printing progress every `print_freq`
        iterations (and on the final one). Requires len(iterable)."""
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Zero-pad the iteration counter to the width of len(iterable).
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent fetching the item (data loading) vs. full iteration:
            # the consumer's work happens between yield and the next loop turn.
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                # ETA extrapolated from the global average iteration time.
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
+
+
def get_sha():
    """Return a one-line summary of the git state of this source tree:
    commit sha, dirty/clean status, and branch. Falls back to 'N/A' values
    when git (or the repository) is unavailable."""
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()

    sha, diff, branch = 'N/A', "clean", 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        # Refresh the index so diff-index below reflects reality.
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = "has uncommited changes" if _run(['git', 'diff-index', 'HEAD']) else "clean"
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        # Not a git checkout, or git missing — keep the defaults.
        pass
    return f"sha: {sha}, status: {diff}, branch: {branch}"
+
+
def collate_fn(batch):
    """Collate (image, target, ...) samples: images are padded into a single
    NestedTensor, the remaining fields are passed through as tuples."""
    transposed = list(zip(*batch))
    transposed[0] = nested_tensor_from_tensor_list(transposed[0])
    return tuple(transposed)
+
+
+def _max_by_axis(the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+
class NestedTensor(object):
    """A padded batch tensor together with its padding mask (True = padding).
    The mask may be None when no padding information is tracked."""

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor with both parts moved to `device`."""
        moved = self.tensors.to(device)
        moved_mask = None if self.mask is None else self.mask.to(device)
        return NestedTensor(moved, moved_mask)

    def decompose(self):
        """Split back into the (tensors, mask) pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)
+
+
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Pad a list of CHW image tensors to a common size and return a
    NestedTensor whose mask is True over padded pixels, False over valid ones."""
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        # Zero-filled batch; mask starts all-True (padding) and is cleared
        # over each image's valid region below.
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            # Copy each image into the top-left corner of its padded slot.
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)
+
+
+# _onnx_nested_tensor_from_tensor_list() is an implementation of
+# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    """ONNX-traceable variant of nested_tensor_from_tensor_list: pads with
    F.pad instead of in-place slice assignment, which ONNX export does not
    support. Only meaningful under tracing, where shapes are tensors."""
    max_size = []
    for i in range(tensor_list[0].dim()):
        # Under tracing, img.shape[i] is a 0-d tensor, so stack/max work here.
        max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        # Per-dimension padding needed to reach the batch-wide max size.
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        # Mask: 0 over the valid region, padded out with 1 (True = padding).
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)
+
+
def setup_for_distributed(is_master):
    """Silence `print` on non-master ranks.

    Replaces the builtin so only the master process prints; any rank can still
    force output by passing `force=True` to print().
    """
    import builtins
    original_print = builtins.print

    def gated_print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if is_master or forced:
            original_print(*args, **kwargs)

    builtins.print = gated_print
+
+
def is_dist_avail_and_initialized():
    """True only when torch.distributed is both compiled in and initialized."""
    return dist.is_available() and dist.is_initialized()
+
+
def get_world_size():
    """World size of the current process group, or 1 when not distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1
+
+
def get_rank():
    """Rank of this process in the group, or 0 when not distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
+
+
def is_main_process():
    """True on rank 0 (also true when not running distributed)."""
    return get_rank() == 0
+
+
def save_on_master(*args, **kwargs):
    """torch.save that is a no-op on every rank except the main one, so only
    one process writes the checkpoint."""
    if is_main_process():
        torch.save(*args, **kwargs)
+
+
def init_distributed_mode(args):
    """Initialize torch.distributed from launcher environment variables.

    Mutates `args` in place: sets rank, gpu, distributed, dist_backend (and
    world_size on the torchrun path). Falls back to single-process mode when
    no launcher variables are present.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # torchrun / torch.distributed.launch style environment.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # SLURM environment.
        # NOTE(review): args.world_size is NOT set on this path — it must
        # already exist on `args`; confirm against the SLURM launch script.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    # args.dist_url is expected to be provided by the caller (e.g. env://).
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    # After this, only rank 0 prints (unless force=True is passed).
    setup_for_distributed(args.rank == 0)
+
+
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: logits/scores of shape [batch, num_classes].
        target: integer class labels of shape [batch].
        topk: tuple of k values to report.
    Returns:
        list of 0-dim tensors, one accuracy percentage per k.
    """
    if target.numel() == 0:
        # Empty batch: report 0% rather than dividing by zero.
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # .reshape(-1) instead of .view(-1): `correct[:k]` is a slice of a
        # transposed (non-contiguous) tensor, so .view raises for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
+
+
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
    This will eventually be supported natively by PyTorch, and this
    class can go away.
    """
    if version.parse(torchvision.__version__) < version.parse('0.7'):
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
            )

        # Empty batch: compute the output shape by hand and return an empty
        # tensor (old torch could not interpolate empty inputs).
        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    else:
        # Fix: torchvision.ops.misc.interpolate was removed in torchvision
        # >= 0.10, so calling it raised AttributeError. Modern PyTorch handles
        # empty batches natively — delegate straight to the functional op.
        return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
diff --git a/ACT_DP_multitask/detr/util/plot_utils.py b/ACT_DP_multitask/detr/util/plot_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f24bed0d3fe4624aeb231ddd02633f2e58e4bff
--- /dev/null
+++ b/ACT_DP_multitask/detr/util/plot_utils.py
@@ -0,0 +1,107 @@
+"""
+Plotting utilities to visualize training logs.
+"""
+import torch
+import pandas as pd
+import numpy as np
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+from pathlib import Path, PurePath
+
+
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
    '''
    Function to plot specific fields from training log(s). Plots both training and test results.

    :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
              - fields = which results to plot from each log file - plots both training and test for each field.
              - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
              - log_name = optional, name of log file if different than default 'log.txt'.

    :: Outputs - matplotlib plots of results in fields, color coded for each log file.
               - solid lines are training results, dashed lines are test results.

    '''
    func_name = "plot_utils.py::plot_logs"

    # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
    # convert single Path to list to avoid 'not iterable' error

    if not isinstance(logs, list):
        if isinstance(logs, PurePath):
            logs = [logs]
            print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
        else:
            raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
            Expect list[Path] or single Path obj, received {type(logs)}")

    # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
    for i, dir in enumerate(logs):
        if not isinstance(dir, PurePath):
            raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
        if not dir.exists():
            raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
        # verify log_name exists
        fn = Path(dir / log_name)
        if not fn.exists():
            # Missing log is not an error — just warn and bail out quietly.
            print(f"-> missing {log_name}.  Have you gotten to Epoch 1 in training?")
            print(f"--> full path of missing log file: {fn}")
            return

    # load log file(s) and plot
    # Each log line is a JSON record, one epoch per line.
    dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]

    fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))

    # One color per log directory; one subplot column per field.
    for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
        for j, field in enumerate(fields):
            if field == 'mAP':
                # mAP comes from the COCO eval array: column 1 is AP@50.
                # NOTE(review): assumes test_coco_eval_bbox rows are the COCO
                # stats vector — confirm against the training loop's logging.
                coco_eval = pd.DataFrame(
                    np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
                ).ewm(com=ewm_col).mean()
                axs[j].plot(coco_eval, c=color)
            else:
                # Solid = train, dashed = test, smoothed with an EWM window.
                df.interpolate().ewm(com=ewm_col).mean().plot(
                    y=[f'train_{field}', f'test_{field}'],
                    ax=axs[j],
                    color=[color] * 2,
                    style=['-', '--']
                )
    for ax, field in zip(axs, fields):
        ax.legend([Path(p).name for p in logs])
        ax.set_title(field)
+
+
def plot_precision_recall(files, naming_scheme='iter'):
    """Plot precision/recall and score/recall curves from saved COCO eval
    checkpoint files. Returns the (fig, axs) pair for further styling."""
    if naming_scheme == 'exp_id':
        # name becomes exp_id
        names = [f.parts[-3] for f in files]
    elif naming_scheme == 'iter':
        names = [f.stem for f in files]
    else:
        raise ValueError(f'not supported {naming_scheme}')
    fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
    for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
        data = torch.load(f)
        # precision is n_iou, n_points, n_cat, n_area, max_det
        precision = data['precision']
        recall = data['params'].recThrs
        scores = data['scores']
        # take precision for all classes, all areas and 100 detections
        # (IoU index 0 = 0.50, area index 0 = "all", last max_det column).
        precision = precision[0, :, :, 0, -1].mean(1)
        scores = scores[0, :, :, 0, -1].mean(1)
        prec = precision.mean()
        rec = data['recall'][0, :, 0, -1].mean()
        print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
              f'score={scores.mean():0.3f}, ' +
              f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
              )
        axs[0].plot(recall, precision, c=color)
        axs[1].plot(recall, scores, c=color)

    axs[0].set_title('Precision / Recall')
    axs[0].legend(names)
    axs[1].set_title('Scores / Recall')
    axs[1].legend(names)
    return fig, axs
diff --git a/ACT_DP_multitask/policy_anyrobot.py b/ACT_DP_multitask/policy_anyrobot.py
new file mode 100644
index 0000000000000000000000000000000000000000..464f4e13d2bd3317c2b7068e44e89941ed73f46d
--- /dev/null
+++ b/ACT_DP_multitask/policy_anyrobot.py
@@ -0,0 +1,1496 @@
+import math
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import torchvision.transforms as transforms
+from diffusers import DDIMScheduler, DDPMScheduler
+from detr.main import *
+import numpy as np
+from collections import deque
+from utils import normalize_data, tensor2numpy, kl_divergence, RandomShiftsAug
+import os
+
+# from cosmos_tokenizer.video_lib import CausalVideoTokenizer
+# from cosmos_tokenizer.image_lib import ImageTokenizer
+
+
def get_tokenizer(model_name):
    """Load a frozen Cosmos tokenizer encoder/decoder pair from the local
    pretrained_ckpts directory.

    Args:
        model_name: checkpoint directory name; the character at index 18 is
            expected to be "I" (image) or "V" (video) —
            TODO confirm against the actual checkpoint naming scheme.
    Returns:
        (encoder, decoder) with all parameters frozen.
    Raises:
        ValueError: if the tokenizer type cannot be inferred from the name
            (previously this fell through and crashed with UnboundLocalError).
    NOTE(review): requires the cosmos_tokenizer imports at the top of this
    file, which are currently commented out — confirm they are re-enabled
    before using this helper.
    """
    print(f"Loading tokenizer {model_name}")
    current_dir = os.path.dirname(__file__)
    checkpoint_enc = (
        f"{current_dir}/Cosmos-Tokenizer/pretrained_ckpts/{model_name}/encoder.jit"
    )
    checkpoint_dec = (
        f"{current_dir}/Cosmos-Tokenizer/pretrained_ckpts/{model_name}/decoder.jit"
    )
    model_type = model_name[18]  # I or V
    if model_type == "I":
        encoder = ImageTokenizer(checkpoint_enc=checkpoint_enc)
        decoder = ImageTokenizer(checkpoint_dec=checkpoint_dec)
    elif model_type == "V":
        encoder = CausalVideoTokenizer(checkpoint_enc=checkpoint_enc)
        decoder = CausalVideoTokenizer(checkpoint_dec=checkpoint_dec)
    else:
        raise ValueError(
            f"Cannot infer tokenizer type (I/V) from model name {model_name!r}"
        )

    # Freeze both halves — the tokenizer is used as a fixed feature codec.
    for param in encoder.parameters():  # frozen
        param.requires_grad = False
    for param in decoder.parameters():
        param.requires_grad = False
    return encoder, decoder
+
+
class ACTPolicy(nn.Module):
    """ACT (Action Chunking with Transformers) policy wrapping the CVAE model
    built by detr.main, with an optional segmentation head and helpers for
    RoboTwin-style rollout."""

    def __init__(self, args_override):
        super().__init__()
        model, optimizer = build_ACT_model_and_optimizer(args_override)
        self.model = model  # CVAE decoder
        self.optimizer = optimizer
        # Weight on the KL term of the CVAE loss.
        self.kl_weight = args_override["kl_weight"]
        print(f"KL Weight {self.kl_weight}")
        self.args = args_override

    def __call__(
        self, qpos, image, actions=None, is_pad=None, is_training=False, instances=None
    ):
        """Training (actions given): return a loss dict with l1/kl (and
        segmentation losses when enabled). Inference (actions=None): return
        the predicted action chunk a_hat."""
        env_state = None
        # ImageNet normalization; images are expected already scaled to [0,1].
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

        if actions is not None:  # training time
            # image = self.aug(image) if is_training else image # disbale aug
            image = normalize(image)
            # Clip targets/padding to the model's action-chunk length.
            actions = actions[:, : self.model.num_queries]
            is_pad = is_pad[:, : self.model.num_queries]

            if self.args["segmentation"]:
                # ACT with segmentation
                a_hat, is_pad_hat, (mu, logvar), (mask_classes, outputs_seg_masks) = (
                    self.model(qpos, image, env_state, actions, is_pad)
                )
            else:
                # Vanilla ACT
                a_hat, is_pad_hat, (mu, logvar) = self.model(
                    qpos, image, env_state, actions, is_pad
                )

            total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
            loss_dict = dict()
            all_l1 = F.l1_loss(actions, a_hat, reduction="none")

            # Mask out padded timesteps before averaging.
            l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
            loss_dict["l1"] = l1
            loss_dict["kl"] = total_kld[0]

            if self.args["segmentation"]:
                # NOTE(review): self.criterion and self.segloss_weight are not
                # assigned anywhere in this class — confirm they are injected
                # externally before the segmentation path is exercised.
                targets = self.prepare_targets(instances, image)
                losses_seg = self.criterion(
                    {"pred_logits": mask_classes, "pred_masks": outputs_seg_masks},
                    targets,
                )

                for k in list(losses_seg.keys()):
                    if k in self.criterion.weight_dict:
                        losses_seg[k] *= (
                            self.criterion.weight_dict[k] * self.segloss_weight
                        )
                    else:
                        # remove this loss if not specified in `weight_dict`
                        losses_seg.pop(k)
                loss_dict.update(losses_seg)
                loss_dict["loss"] = (
                    loss_dict["l1"]
                    + loss_dict["kl"] * self.kl_weight
                    + loss_dict["loss_ce"]
                    + loss_dict["loss_mask"]
                    + loss_dict["loss_dice"]
                )
            else:
                loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
            return loss_dict
        else:  # inference time
            image = normalize(image)
            if self.args["segmentation"]:
                a_hat, is_pad_hat, (mu, logvar), (mask_classes, outputs_seg_masks) = (
                    self.model(qpos, image, env_state, actions, is_pad)
                )
                for mask_cls_result, mask_pred_result in zip(
                    mask_classes, outputs_seg_masks
                ):
                    # Render an RGB segmentation overlay for visualization.
                    sem_seg = torch.zeros(
                        (mask_pred_result.shape[1], mask_pred_result.shape[2], 3),
                        device=mask_pred_result.device,
                    )
                    # Keep only queries whose class-0 probability exceeds 0.5.
                    keep = mask_cls_result.softmax(-1)[:, 0] > 0.5
                    for ii, mask in enumerate(mask_pred_result):
                        if keep[ii]:
                            # NOTE(review): self.colors is never set in this
                            # class — confirm it is provided before inference
                            # with segmentation enabled.
                            sem_seg[mask.sigmoid() > 0.5, :] = torch.tensor(
                                self.colors[ii],
                                device=mask_pred_result.device,
                                dtype=sem_seg.dtype,
                            )
                    # Stash the last rendered mask for external consumers.
                    self.seg = sem_seg.cpu().numpy()
                    # sem_seg = self.semantic_inference(mask_cls_result, mask_pred_result)
                    # import matplotlib.pyplot as plt
                    # plt.subplot(122)
                    # plt.imshow(sem_seg.cpu().numpy()/255)
                    # plt.savefig('seg.png')
                    # import pdb; pdb.set_trace()
            else:
                a_hat, _, (_, _) = self.model(
                    qpos, image, env_state
                )  # no action, sample from prior

            return a_hat

    def configure_optimizers(self):
        # Optimizer is built alongside the model in __init__.
        return self.optimizer

    def prepare_targets(self, targets, image):
        """Convert per-image instance annotations into the list of
        {labels, masks} dicts the segmentation criterion expects."""
        # h, w = images.tensor.shape[-2:]
        new_targets = []
        for targets_per_image in targets:
            # pad gt
            targets_per_image = targets_per_image.to(image.device)
            gt_masks = targets_per_image.gt_masks
            # padded_masks = torch.zeros((gt_masks.shape[0], h, w), dtype=gt_masks.dtype, device=gt_masks.device)
            # padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
            new_targets.append(
                {
                    "labels": targets_per_image.gt_classes,
                    "masks": gt_masks,
                }
            )
        return new_targets

    def semantic_inference(self, mask_cls, mask_pred):
        """Soft semantic map: class probabilities weighted by mask sigmoids,
        combined per pixel via einsum."""
        mask_cls = F.softmax(mask_cls, dim=-1)
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
        return semseg

    # for robotwin
    def reset_obs(self, stats, norm_type):
        # Normalization statistics and scheme used during rollout.
        self.stats = stats
        self.norm_type = norm_type

    def update_obs(self, obs):
        """Cache the latest observation (head camera image + normalized qpos)
        for the next get_action() call."""
        self.obs_image = (
            torch.from_numpy(obs["head_cam"]).unsqueeze(0).unsqueeze(0).float().cuda()
        )  # 1 1 C H W 0~1
        obs_qpos = torch.from_numpy(obs["agent_pos"]).unsqueeze(0).float().cuda()
        self.obs_qpos = normalize_data(
            obs_qpos, self.stats, "gaussian", data_type="qpos"
        )  # qpos mean std

    def get_action(self):
        """Run inference on the cached observation and un-normalize the
        resulting action chunk back to environment units."""
        a_hat = self(self.obs_qpos, self.obs_image).detach().cpu().numpy()  # B T K
        # unnormalize
        if self.norm_type == "minmax":
            a_hat = (a_hat + 1) / 2 * (
                self.stats["action_max"] - self.stats["action_min"]
            ) + self.stats["action_min"]
        elif self.norm_type == "gaussian":
            a_hat = a_hat * self.stats["action_std"] + self.stats["action_mean"]
        return a_hat[0]  # chunksize 14
+
+
class ACTDiffusionPolicy(nn.Module):
    """ACT backbone trained as a diffusion policy: actions are denoised with a
    DDPM/DDIM scheduler, with an optional CVAE KL term."""

    def __init__(self, args_override):
        super().__init__()
        model, optimizer = build_ACTDiffusion_model_and_optimizer(args_override)
        self.model = model  # CVAE decoder
        self.optimizer = optimizer
        self.kl_weight = args_override["kl_weight"]
        self.aug = RandomShiftsAug(15, 20)  # TODO acording to the task
        # for robotwin env
        self.history_steps = args_override["history_step"]
        self.obs_image = deque(maxlen=self.history_steps + 1)
        self.obs_qpos = deque(maxlen=self.history_steps + 1)
        # diffusion setup
        self.num_inference_steps = args_override["num_inference_steps"]
        self.num_queries = args_override["num_queries"]
        num_train_timesteps = args_override["num_train_timesteps"]
        prediction_type = args_override["prediction_type"]
        beta_schedule = args_override["beta_schedule"]
        # Select scheduler class first, then instantiate it.
        noise_scheduler = (
            DDIMScheduler if args_override["schedule_type"] == "DDIM" else DDPMScheduler
        )
        noise_scheduler = noise_scheduler(
            num_train_timesteps=num_train_timesteps,
            beta_schedule=beta_schedule,
            prediction_type=prediction_type,
        )

        self.noise_scheduler = noise_scheduler
        self.loss_type = args_override["loss_type"]
        # NOTE(review): the braces below make these print one-element sets —
        # probably intended as f-strings; kept as-is.
        print("num_train_timesteps", {args_override["num_train_timesteps"]})
        print("schedule_type", {args_override["schedule_type"]})
        print("beta_schedule", {args_override["beta_schedule"]})
        print("prediction_type", {args_override["prediction_type"]})
        print(f"Loss Type {self.loss_type}")

    def train_model(self, qpos, image, actions, is_pad=None,is_training=True, task_emb=None):
        """
        qpos: B his+1 14
        image: B his+1 N_view 3 H W

        One diffusion training step: add noise at a random timestep, predict
        the scheduler's target (noise / sample / v), and return a loss dict.
        """
        env_state = None
        # Per-trajectory noise and a uniformly random diffusion timestep.
        noise = torch.randn_like(actions).to(actions.device)
        bsz = actions.shape[0]
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bsz,),
            device=actions.device,
        )
        # print('action device', actions.device, 'noise device', noise.device, 'timesteps device', timesteps.device, )
        noisy_actions = self.noise_scheduler.add_noise(actions, noise, timesteps)

        pred, is_pad_hat, [mu, logvar] = self.model(
            qpos, image, env_state, noisy_actions, is_pad, denoise_steps=timesteps, is_training=is_training, task_emb=task_emb
        )

        pred_type = self.noise_scheduler.config.prediction_type

        # Choose the regression target matching the scheduler parameterization.
        if pred_type == "epsilon":
            target = noise
        elif pred_type == "sample":
            target = actions
        elif pred_type == "v_prediction":
            # https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
            # https://github.com/huggingface/diffusers/blob/v0.11.1-patch/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
            # sigma = self.noise_scheduler.sigmas[timesteps]
            # alpha_t, sigma_t = self.noise_scheduler._sigma_to_alpha_sigma_t(sigma)
            # NOTE(review): nn.Module has no `.device` attribute by default —
            # confirm self.device is assigned elsewhere before using
            # v_prediction, otherwise this raises AttributeError.
            self.noise_scheduler.alpha_t = self.noise_scheduler.alpha_t.to(self.device)
            self.noise_scheduler.sigma_t = self.noise_scheduler.sigma_t.to(self.device)
            alpha_t, sigma_t = (
                self.noise_scheduler.alpha_t[timesteps],
                self.noise_scheduler.sigma_t[timesteps],
            )
            # Broadcast per-sample coefficients over (time, action-dim).
            alpha_t = alpha_t.unsqueeze(-1).unsqueeze(-1)
            sigma_t = sigma_t.unsqueeze(-1).unsqueeze(-1)
            v_t = alpha_t * noise - sigma_t * actions
            target = v_t
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss_dict = {}
        if self.loss_type == "l2":
            loss = F.mse_loss(pred, target, reduction="none")
        elif self.loss_type == "l1":
            loss = F.l1_loss(pred, target, reduction="none")
        # Padded timesteps are masked out of the mean.
        diffusion_loss = (loss * ~is_pad.unsqueeze(-1)).mean()
        diffusion_loss_name = pred_type + "_diffusion_loss_" + self.loss_type
        loss_dict[diffusion_loss_name] = diffusion_loss

        if mu is not None and logvar is not None:  # for CVAE module
            total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
            loss_dict["kl"] = total_kld[0]
            loss_dict["loss"] = (
                loss_dict[diffusion_loss_name] + loss_dict["kl"] * self.kl_weight
            )
        else:
            loss_dict["loss"] = loss_dict[diffusion_loss_name]
        return loss_dict

    # ===================inferece ===============
    def conditional_sample(self, qpos, image, task_emb=None):
        """
        diffusion process to generate actions

        Starts from Gaussian noise of shape (B, num_queries, 14) and iterates
        the scheduler's reverse steps, conditioning on qpos/image/task_emb.
        """
        if len(image.shape) == 5:  # B N C H W
            # Add a history dimension of length 1 for the model's interface.
            qpos = qpos.unsqueeze(1)
            image = image.unsqueeze(1)
        env_state = None
        model = self.model
        scheduler = self.noise_scheduler
        batch = image.shape[0]
        action_shape = (batch, self.num_queries, 14)
        actions = torch.randn(action_shape, device=qpos.device, dtype=qpos.dtype)
        scheduler.set_timesteps(self.num_inference_steps)
        for t in scheduler.timesteps:
            # print('diffusion timestep', t)
            timesteps = torch.full((batch,), t, device=qpos.device, dtype=torch.long)
            model_output, is_pad_hat, [mu, logvar] = model(
                qpos,
                image,
                env_state,
                actions,
                None,
                denoise_steps=timesteps,
                is_training=False,
                task_emb=task_emb,
            )
            actions = scheduler.step(model_output, t, actions).prev_sample
        return actions

    def __call__(self, qpos, image, actions=None, is_pad=None, is_training=True, task_emb=None):
        # qpos: B D
        # image: B Num_view C H W
        # actions: B T K
        # Training when actions is given; otherwise run the sampling loop.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

        if actions is not None:  # training time
            # image = self.aug(image) if is_training else image
            image = normalize(image)
            # Clip targets/padding to the model's action-chunk length.
            actions = actions[:, : self.model.num_queries]
            is_pad = is_pad[:, : self.model.num_queries]
            loss_dict = self.train_model(qpos, image, actions, is_pad, is_training,task_emb)
            return loss_dict
        else:  # inference time
            image = normalize(image)
            a_hat = self.conditional_sample(qpos, image,task_emb)
            return a_hat

    def configure_optimizers(self):
        # Optimizer is built alongside the model in __init__.
        return self.optimizer

    def reset_obs(self, stats, norm_type):
        # Normalization statistics and scheme used during rollout.
        self.stats = stats
        self.norm_type = norm_type


    def get_action(self, qpos, image, task_emb):
        """Sample one action chunk for a single observation and un-normalize
        it back to environment units. Returns an array of shape (T, K)."""
        # normalize qpos
        qpos = qpos.unsqueeze(0)  # B D
        image = image.unsqueeze(0)  # B N C H W
        task_emb = task_emb.unsqueeze(0)  # B D ?
        qpos = normalize_data(qpos, self.stats, "gaussian", data_type="qpos")
        with torch.no_grad():
            a_hat = self(qpos, image, is_pad=None, is_training=False, task_emb=task_emb).detach().cpu().numpy()
        # unnormalize action
        if self.norm_type == "minmax":
            a_hat = (a_hat + 1) / 2 * (
                self.stats["action_max"] - self.stats["action_min"]
            ) + self.stats["action_min"]
        elif self.norm_type == "gaussian":
            a_hat = a_hat * self.stats["action_std"] + self.stats["action_mean"]
        # detach and convert to numpy
        a_hat = a_hat[0]
        return a_hat  # 50 14
+
+
+
class ACT_Flow_Matching(nn.Module):
    """ACT policy trained with conditional flow matching (rectified flow).

    Training regresses the velocity field u_t = noise - actions at a
    Beta-sampled time t; inference integrates the learned field from t=1
    (pure noise) down to t=0 (actions) with a constant Euler step.
    """

    def __init__(self, args_override):
        super().__init__()
        model, optimizer = build_ACTDiffusion_model_and_optimizer(args_override)
        self.model = model
        self.optimizer = optimizer
        if "sim" in args_override["task_name"]:  # for aloha env
            self.aug = RandomShiftsAug(15, 20)  # TODO tune according to the task
        else:
            self.aug = RandomShiftsAug(8, 10)  # for robotwin env
        self.history_steps = args_override["history_step"]
        self.obs_image = deque(maxlen=self.history_steps + 1)
        self.obs_qpos = deque(maxlen=self.history_steps + 1)
        self.num_queries = args_override["num_queries"]
        # number of Euler integration steps at inference time
        self.num_inference_steps = args_override["num_inference_steps"]
        # Beta(1.5, 1) timestep sampler (refer to openpi0)
        self.noise_scheduler = torch.distributions.Beta(1.5, 1)
        self.loss_type = args_override["loss_type"]
        # Bugfix: train_model reads self.kl_weight for the CVAE branch, but it
        # was never initialized here (AttributeError whenever mu/logvar were
        # not None). Default to 0.0 so configs without the key keep working.
        self.kl_weight = args_override.get("kl_weight", 0.0)

    def train_model(self, qpos, image, actions, is_pad=None):
        """Compute the flow-matching loss for one batch.

        qpos: B his+1 14
        image: B his+1 N_view 3 H W
        actions: B T K ground-truth action chunk
        is_pad: B T mask, True where the action step is padding

        Returns a dict with the flow loss, an optional KL term, and "loss".
        """
        env_state = None
        noise = torch.randn_like(actions).to(actions.device)
        bsz = actions.shape[0]
        # t ~ Beta(1.5, 1), rescaled into (0.001, 1.0]; refer to openpi0
        timesteps = self.noise_scheduler.sample((bsz,)).to(actions.device) * 0.999 + 0.001
        timesteps_expand = timesteps[..., None, None]  # B 1 1
        # linear interpolation between data (t=0) and noise (t=1)
        noisy_actions = timesteps_expand * noise + (1 - timesteps_expand) * actions

        pred, is_pad_hat, [mu, logvar] = self.model(
            qpos, image, env_state, noisy_actions, is_pad, denoise_steps=timesteps
        )

        target = noise - actions  # ut, the conditional velocity field
        loss_dict = {}
        if self.loss_type == "l2":
            loss = F.mse_loss(pred, target, reduction="none")
        elif self.loss_type == "l1":
            loss = F.l1_loss(pred, target, reduction="none")
        else:
            # Bugfix: previously fell through with `loss` undefined (NameError).
            raise ValueError(f"Unsupported loss type {self.loss_type}")
        flow_loss = (loss * ~is_pad.unsqueeze(-1)).mean()
        pred_type = 'v_pred'
        flow_loss_name = pred_type + "_flow_loss_" + self.loss_type
        loss_dict[flow_loss_name] = flow_loss

        if mu is not None and logvar is not None:  # for CVAE module
            total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
            loss_dict["kl"] = total_kld[0]
            loss_dict["loss"] = (
                loss_dict[flow_loss_name] + loss_dict["kl"] * self.kl_weight
            )
        else:
            loss_dict["loss"] = loss_dict[flow_loss_name]
        return loss_dict

    # =================== inference ===============
    def conditional_sample(self, qpos, image, is_pad):
        """Integrate the learned velocity field from noise (t=1) to actions (t=0).

        Fixed-step Euler integration with num_inference_steps steps; the result
        is clipped to the normalized action range [-1, 1].
        """
        if len(image.shape) == 5:  # B N C H W -> add history dim
            qpos = qpos.unsqueeze(1)
            image = image.unsqueeze(1)
        env_state = None
        model = self.model
        batch = image.shape[0]
        action_shape = (batch, self.num_queries, 14)
        actions = torch.randn(action_shape, device=qpos.device, dtype=qpos.dtype)
        dt = -1.0 / self.num_inference_steps  # constant (negative) step size
        timesteps = torch.ones(batch, device=qpos.device)
        # loop runs exactly num_inference_steps times (t: 1 -> ~0)
        while timesteps[0].item() >= -dt / 2:
            model_output, is_pad_hat, [mu, logvar] = model(
                qpos,
                image,
                env_state,
                actions,
                is_pad,
                denoise_steps=timesteps,
                is_training=False,
            )
            actions += model_output * dt
            timesteps += dt
        return actions.clip(min=-1, max=1)

    def __call__(self, qpos, image, actions=None, is_pad=None, is_training=True):
        """Training: return loss dict. Inference (actions is None): return actions.

        qpos: B D; image: B Num_view C H W in [0, 1]; actions: B T K
        """
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

        if actions is not None:  # training time
            image = normalize(image)
            actions = actions[:, : self.model.num_queries]
            is_pad = is_pad[:, : self.model.num_queries]
            loss_dict = self.train_model(qpos, image, actions, is_pad)
            return loss_dict
        else:  # inference time
            image = normalize(image)
            a_hat = self.conditional_sample(qpos, image, is_pad)
            return a_hat

    def configure_optimizers(self):
        """Return the optimizer created by the model builder."""
        return self.optimizer
+
+
class ACTDiffusionPolicy_Tactile(nn.Module):
    """ACT diffusion policy that additionally conditions on tactile input."""

    def __init__(self, args_override):
        super().__init__()
        model, optimizer = build_ACTDiffusion_tactile_model_and_optimizer(args_override)
        self.model = model  # CVAE decoder
        self.optimizer = optimizer
        print(args_override.keys())
        self.aug = RandomShiftsAug(8, 10)  # for robotwin env
        self.history_steps = 0
        self.obs_image = deque(maxlen=self.history_steps + 1)
        self.obs_qpos = deque(maxlen=self.history_steps + 1)
        self.obs_tactile = deque(maxlen=self.history_steps + 1)
        # diffusion setup
        self.num_inference_steps = args_override["num_inference_steps"]
        self.num_queries = args_override["num_queries"]
        num_train_timesteps = args_override["num_train_timesteps"]
        prediction_type = args_override["prediction_type"]
        beta_schedule = args_override["beta_schedule"]
        scheduler_cls = (
            DDIMScheduler if args_override["schedule_type"] == "DDIM" else DDPMScheduler
        )
        self.noise_scheduler = scheduler_cls(
            num_train_timesteps=num_train_timesteps,
            beta_schedule=beta_schedule,
            prediction_type=prediction_type,
        )
        self.loss_type = args_override["loss_type"]
        # Bugfix: these prints previously wrapped each value in a set literal
        # ({...}), printing e.g. "{100}" instead of "100".
        print("num_train_timesteps", num_train_timesteps)
        print("schedule_type", args_override["schedule_type"])
        print("beta_schedule", beta_schedule)
        print("prediction_type", prediction_type)
        print(f"Loss Type {self.loss_type}")

    def train_model(self, qpos, image, tactile, actions, is_pad=None):
        """Compute the denoising diffusion loss for one batch.

        qpos: B his+1 14
        image: B his+1 N_view 3 H W
        tactile: tactile observation, forwarded to the model as-is
        actions: B T K; is_pad: B T padding mask (True = padding)
        """
        env_state = None
        noise = torch.randn_like(actions).to(actions.device)
        bsz = actions.shape[0]
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bsz,),
            device=actions.device,
        )
        noisy_actions = self.noise_scheduler.add_noise(actions, noise, timesteps)

        pred, is_pad_hat = self.model(
            qpos, image, tactile, env_state, noisy_actions, is_pad, denoise_steps=timesteps
        )

        pred_type = self.noise_scheduler.config.prediction_type

        if pred_type == "epsilon":
            target = noise
        elif pred_type == "sample":
            target = actions
        elif pred_type == "v_prediction":
            # https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
            # https://github.com/huggingface/diffusers/blob/v0.11.1-patch/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
            # Bugfix: this class never defines self.device (plain nn.Module has
            # none), so the old `.to(self.device)` raised AttributeError; move
            # the scheduler tables to the batch's device instead.
            self.noise_scheduler.alpha_t = self.noise_scheduler.alpha_t.to(actions.device)
            self.noise_scheduler.sigma_t = self.noise_scheduler.sigma_t.to(actions.device)
            alpha_t, sigma_t = (
                self.noise_scheduler.alpha_t[timesteps],
                self.noise_scheduler.sigma_t[timesteps],
            )
            alpha_t = alpha_t.unsqueeze(-1).unsqueeze(-1)
            sigma_t = sigma_t.unsqueeze(-1).unsqueeze(-1)
            target = alpha_t * noise - sigma_t * actions  # v-prediction target
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss_dict = {}
        if self.loss_type == "l2":
            loss = F.mse_loss(pred, target, reduction="none")
        elif self.loss_type == "l1":
            loss = F.l1_loss(pred, target, reduction="none")
        else:
            # Bugfix: previously fell through with `loss` undefined (NameError).
            raise ValueError(f"Unsupported loss type {self.loss_type}")
        diffusion_loss = (loss * ~is_pad.unsqueeze(-1)).mean()
        diffusion_loss_name = pred_type + "_diffusion_loss_" + self.loss_type
        loss_dict[diffusion_loss_name] = diffusion_loss

        loss_dict["loss"] = loss_dict[diffusion_loss_name]
        return loss_dict

    # =================== inference ===============
    def conditional_sample(self, qpos, image, tactile, is_pad):
        """Reverse diffusion from Gaussian noise to an action chunk."""
        if len(image.shape) == 5:  # B N C H W -> add history dim
            qpos = qpos.unsqueeze(1)
            image = image.unsqueeze(1)
            tactile = tactile.unsqueeze(1)
        env_state = None
        model = self.model
        scheduler = self.noise_scheduler
        batch = image.shape[0]
        action_shape = (batch, self.num_queries, 14)
        actions = torch.randn(action_shape, device=qpos.device, dtype=qpos.dtype)
        scheduler.set_timesteps(self.num_inference_steps)
        for t in scheduler.timesteps:
            timesteps = torch.full((batch,), t, device=qpos.device, dtype=torch.long)
            # NOTE(review): train_model unpacks two values from self.model but
            # this unpacks three — confirm the model's eval-mode return really
            # is (pred, is_pad_hat, [mu, logvar]).
            model_output, is_pad_hat, [mu, logvar] = model(
                qpos,
                image,
                tactile,
                env_state,
                actions,
                is_pad,
                denoise_steps=timesteps,
                is_training=False,
            )
            actions = scheduler.step(model_output, t, actions).prev_sample
        return actions

    def __call__(self, qpos, image, tactile, actions=None, is_pad=None, is_training=True):
        """Training: return loss dict. Inference (actions is None): return actions.

        qpos: B D; image: B Num_view C H W in [0, 1]; actions: B T K
        """
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

        if actions is not None:  # training time
            image = self.aug(image) if is_training else image
            image = normalize(image)
            actions = actions[:, : self.model.num_queries]
            is_pad = is_pad[:, : self.model.num_queries]
            loss_dict = self.train_model(qpos, image, tactile, actions, is_pad)
            return loss_dict
        else:  # inference time
            image = normalize(image)
            a_hat = self.conditional_sample(qpos, image, tactile, is_pad)
            return a_hat

    def configure_optimizers(self):
        """Return the optimizer created by the model builder."""
        return self.optimizer
+
+
+## use visual tokenization
+"""
+Input:
+ A unified dataset for all the datasets
+ image_data [0~1]: history_steps+1 Num_view C H W
+ qpos_data [normalized]: history_steps+1 D
+ action_data [raw]: chunk_size D
+ is_pad: chunk_size
+ future_imgs_data [0~1]: predict_frame Num_view C H W
+ is_pad_img : predict_frame
+"""
+
+
class ACTPolicyDiffusion_with_Token_Prediction(nn.Module):
    """ACT diffusion policy that jointly denoises the action chunk and a
    sequence of future visual tokens produced by a pretrained tokenizer.

    Training supervises both the action-diffusion target and the predicted
    future-image tokens; inference denoises only the actions (token
    prediction is skipped, see conditional_sample).
    """

    def __init__(self, args_override):
        # args_override: flat dict of hyperparameters; keys read below.
        super().__init__()
        model = build_diffusion_tp_model_and_optimizer(args_override)
        self.model = model  # decoder
        self.camera_num = len(args_override["camera_names"])
        self.kl_weight = args_override["kl_weight"]
        print(f"KL Weight {self.kl_weight}")
        # memory buffer
        self.history_steps = args_override["history_step"]
        self.obs_image = deque(maxlen=self.history_steps + 1)
        self.obs_qpos = deque(maxlen=self.history_steps + 1)
        # visual tokenization
        if "sim" in args_override["task_name"]:  # for aloha env
            self.aug = RandomShiftsAug(15, 20)
        else:
            self.aug = RandomShiftsAug(8, 10)  # for robotwin env
        # tokenizer model and shape
        # NOTE(review): slicing chars [17:19] of the tokenizer name assumes a
        # fixed naming scheme yielding "DV"/"DI"/"CV"/"CI" — confirm upstream.
        self.tokenizer_model_type = args_override["tokenizer_model_name"][17:19]  # VI
        self.tokenizer_enc, self.tokenizer_dec = get_tokenizer(
            args_override["tokenizer_model_name"]
        )
        self.token_dim = args_override["token_dim"]
        self.num_temporal_token = self.model.num_temporal_token
        # token grid size = image size / (downsample * tokenizer spatial rate * resize)
        self.token_h = (
            args_override["image_height"]
            // args_override["image_downsample_rate"]
            // args_override["tokenizer_model_spatial_rate"]
            // args_override["resize_rate"]
        )
        self.token_w = (
            args_override["image_width"]
            // args_override["image_downsample_rate"]
            // args_override["tokenizer_model_spatial_rate"]
            // args_override["resize_rate"]
        )
        print(
            "token shape",
            "token_h",
            self.token_h,
            "token_w",
            self.token_w,
            "token_dim",
            self.token_dim,
        )
        # video prediction hyperparameters
        self.temporal_compression = args_override[
            "tokenizer_model_temporal_rate"
        ]  # temporal compression
        self.predict_only_last = args_override["predict_only_last"]
        self.prediction_weight = args_override["prediction_weight"]
        self.imitate_weight = args_override["imitate_weight"]
        self.predict_frame = args_override["predict_frame"]
        self.temporal_downsample_rate = args_override[
            "temporal_downsample_rate"
        ]  # uniformly sample
        self.resize_rate = args_override["resize_rate"]
        print("tokenizer_model_type", self.tokenizer_model_type)
        print("predict_frame", self.predict_frame)
        print("prediction_weight", self.prediction_weight)
        print("imitate_weight", self.imitate_weight)

        # diffusion hyperparameters
        self.num_inference_steps = args_override["num_inference_steps"]
        self.num_queries = args_override["num_queries"]
        num_train_timesteps = args_override["num_train_timesteps"]
        prediction_type = args_override["prediction_type"]
        beta_schedule = args_override["beta_schedule"]
        noise_scheduler = (
            DDIMScheduler if args_override["schedule_type"] == "DDIM" else DDPMScheduler
        )
        noise_scheduler = noise_scheduler(
            num_train_timesteps=num_train_timesteps,
            beta_schedule=beta_schedule,
            prediction_type=prediction_type,
        )
        self.noise_scheduler = noise_scheduler
        pred_type = self.noise_scheduler.config.prediction_type
        self.loss_type = args_override["loss_type"]
        self.diffusion_loss_name = pred_type + "_diffusion_loss_" + self.loss_type

        # NOTE(review): the {…} below are set literals, so these print e.g.
        # "{100}" rather than "100" — cosmetic only.
        print("num_train_timesteps", {args_override["num_train_timesteps"]})
        print("schedule_type", {args_override["schedule_type"]})
        print("beta_schedule", {args_override["beta_schedule"]})
        print("prediction_type", {args_override["prediction_type"]})
        print(f"Loss Type {self.loss_type}")

        # ImageNet normalization for the ResNet-style visual encoder
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.vis_idx = 0

    def train_model(self, qpos, image, env_state, actions, is_pad, is_image_pad=None):
        """One training step: add noise to actions and future image tokens,
        run the model on the noisy versions, and assemble the combined loss.

        Returns (loss_dict, (current_image_tokens, target_token, pred_token))
        so the caller can optionally visualize reconstructions.
        """
        # qpos: B T' D; T' = history_steps+1
        # image: B T+1 N C H W ,T = history_steps+1+predict_frame 0~1
        # actions: B H D, H = chunk_size
        # is_pad: B H, is_valid
        # is_image_pad: B predict_frame
        env_state = None
        bsz = actions.shape[0]
        is_tokens_pad = torch.ones(
            bsz, self.num_temporal_token, device=actions.device, dtype=torch.bool
        )  # B T/t length after temporal compression
        if self.predict_only_last:
            is_tokens_pad = is_image_pad
        else:
            valid_is_tokens_pad = is_image_pad[
                :, :: self.temporal_compression
            ]  # avoid meaningless token
            is_tokens_pad[:, : valid_is_tokens_pad.shape[-1]] = valid_is_tokens_pad

        # one random diffusion timestep per sample, shared by actions & tokens
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bsz,),
            device=actions.device,
        )

        # image tokenization ; another image normalization for resnet
        current_image_norm = self.normalize(
            image[:, 0 : self.history_steps + 1]  # image[:, 0:1]
        )  # B 1 N C H W TODO image[:, self.history_steps,self.history_steps+1]
        # TOOD only compress current and future frames image[:, self.history_steps:,:]
        image_tokens = self.get_visual_token(
            image[
                :, self.history_steps :, 0:1, :, :: self.resize_rate, :: self.resize_rate
            ]  # resize future frame
        )  # B T+1 N C H W
        # image_tokens = self.get_visual_token(
        #     image[..., :: self.resize_rate, :: self.resize_rate] # resize future frame
        # ) # B T/t + 1 N D H' W' including history_steps+1+predict_frame
        current_image_tokens = image_tokens[:, 0:1]  # B 1 N D H' W' useless
        future_image_tokens = image_tokens[:, 1:]  # B T N D H' W'
        self.image_tokens_shape = future_image_tokens.shape  # B T N D H' W'
        # TODO can use differnent noise_scheduler for image token & actions
        actions_noise = torch.randn_like(actions).to(actions.device)
        token_noise = torch.randn_like(future_image_tokens).to(
            future_image_tokens.device
        )
        noisy_actions = self.noise_scheduler.add_noise(
            actions, actions_noise, timesteps
        )
        noise_tokens = self.noise_scheduler.add_noise(
            future_image_tokens, token_noise, timesteps
        )  # future image token
        # use detr-diffusion model to predict actions & image tokens
        a_hat, is_pad_hat, pred_token, (mu, logvar) = self.model(
            qpos,  # B his+1 D
            (current_image_norm, current_image_tokens),  # B 1 N C H W
            env_state,
            actions,  # B T D
            is_pad,
            noisy_actions,  # B T D
            noise_tokens,  # B T' N_view D H' W'
            is_tokens_pad,
            denoise_steps=timesteps,
        )

        # prediction type
        pred_type = self.noise_scheduler.config.prediction_type

        if pred_type == "epsilon":
            target_action = actions_noise
            target_token = token_noise
        elif pred_type == "sample":
            target_action = actions
            target_token = future_image_tokens
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        # calculate diffusion loss
        loss_dict = {}
        if self.loss_type == "l2":
            loss = F.mse_loss(a_hat, target_action, reduction="none")
        elif self.loss_type == "l1":
            loss = F.l1_loss(a_hat, target_action, reduction="none")
        diffusion_loss = (loss * ~is_pad.unsqueeze(-1)).mean()
        loss_dict[self.diffusion_loss_name] = diffusion_loss
        # just vis diffusion l2 loss
        diffusion_l2 = F.mse_loss(a_hat, target_action, reduction="none")
        diffusion_l2 = (diffusion_l2 * ~is_pad.unsqueeze(-1)).mean().detach()
        loss_dict["diffusion_l2"] = diffusion_l2

        # token-prediction loss, masked by the (broadcast) token padding mask
        tokens_loss = F.mse_loss(
            pred_token, target_token, reduction="none"
        )  # B T N D H' W'
        is_tokens_pad = (
            is_tokens_pad.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        )  # B T N D H' W'
        tokens_loss = (tokens_loss * ~is_tokens_pad).mean()
        loss_dict["loss_prediction"] = tokens_loss

        # total = imitation + (optional KL) + token prediction, each weighted
        if mu is not None and logvar is not None:
            total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
            loss_dict["kl"] = total_kld[0]
            loss_dict["loss"] = (
                loss_dict[self.diffusion_loss_name] * self.imitate_weight
                + loss_dict["kl"] * self.kl_weight
                + loss_dict["loss_prediction"] * self.prediction_weight
            )
        else:
            loss_dict["loss"] = (
                loss_dict[self.diffusion_loss_name] * self.imitate_weight
                + loss_dict["loss_prediction"] * self.prediction_weight
            )

        return loss_dict, (current_image_tokens, target_token, pred_token)

    def conditional_sample(self, qpos, image):
        """Reverse-diffuse an action chunk from noise; token prediction is
        disabled at inference (tokens stays None throughout the loop).
        """
        # qpos: B 1 D or B D
        # image: B 1 N C H W or B N C H W
        if len(image.shape) == 5:  # B N C H W single frame
            qpos = qpos.unsqueeze(1)
            image = image.unsqueeze(1)
        env_state = None
        model = self.model
        scheduler = self.noise_scheduler
        scheduler.set_timesteps(self.num_inference_steps)
        # process image observation
        current_image_norm = self.normalize(image[:, 0 : self.history_steps + 1])  # B 1 N C H W
        # initial noise action & token
        batch = image.shape[0]
        action_shape = (batch, self.num_queries, 14)
        actions = torch.randn(action_shape, device=qpos.device, dtype=qpos.dtype)
        tokens = None  # TODO discard token prediction while evaluation
        for t in scheduler.timesteps:
            timesteps = torch.full((batch,), t, device=qpos.device, dtype=torch.long)
            model_action_output, is_pad_hat, model_token_output, (mu, logvar) = model(
                qpos,
                (current_image_norm, None),
                env_state,
                None,
                None,
                actions,
                tokens,
                None,
                denoise_steps=timesteps,
            )
            actions = scheduler.step(model_action_output, t, actions).prev_sample
        return actions, tokens, mu, logvar

    def __call__(
        self,
        qpos,
        image,
        actions=None,
        is_pad=None,
        future_imgs=None,
        is_pad_img=None,
        is_training=True,
        save_rec=False,
    ):
        """Training (actions given): return loss dict, optionally with a
        side-by-side reconstruction video when save_rec is True.
        Inference: return the denoised action chunk (B H D).
        """
        env_state = None
        if actions is not None:  # training time
            # print(image.shape, future_imgs.shape) B 1 N C H W & B H N C H W
            all_image = torch.cat(
                [image, future_imgs], dim=1
            )  # B H+1+T' N C H W same resize maybe just use image_sample_size
            all_image = self.aug(all_image) if is_training else all_image
            loss_dict, (current_image_tokens, target_token, pred_token) = (
                self.train_model(
                    qpos, all_image, env_state, actions, is_pad, is_pad_img
                )
            )  # B H

            if save_rec == True:  # show the visulization result
                # a_hat, pred_token, _, _ = self.conditional_sample(qpos, image) # generate bug?
                # tokens_loss = F.mse_loss(pred_token, target_token)
                # print('conditional tokens_loss', tokens_loss) # not predict from noise
                # NOTE(review): the divisor 20 looks like a hard-coded
                # predict_frame — confirm it matches args["predict_frame"].
                print("is_image_pad rate ", is_pad_img[0].sum() / 20)
                raw_videos = (
                    all_image[:, :, 0].permute(0, 2, 1, 3, 4) * 2 - 1
                )  # B T C H W -> B C T H W -1,1
                rec_videos = self.generate_video_by_codes(
                    current_image_tokens, pred_token
                )
                rec_gt_videos = self.generate_video_by_codes(
                    current_image_tokens, target_token
                )
                error_videos = (rec_gt_videos - rec_videos).clip(-1, 1)

                raw_videos = tensor2numpy(raw_videos)[0]  # C T H W
                rec_videos = tensor2numpy(rec_videos)[0]  # T H W C
                rec_gt_videos = tensor2numpy(rec_gt_videos)[0]  # T H W C
                error_videos = tensor2numpy(error_videos)[0]  # T H W C
                vis_video = np.concatenate(
                    [raw_videos, rec_gt_videos, rec_videos, error_videos], axis=1
                )  # T H*N W C
                return loss_dict, vis_video
            else:
                return loss_dict

        else:  # inference time
            qpos = qpos  # B 1 D or B D
            image = image  # B 1 N C H W or B N C H W
            # print(image.shape, image.max(), image.min())
            a_hat, pred_token, _, _ = self.conditional_sample(qpos, image)
            # print('prediction action', a_hat.shape)
            return a_hat  # B H D

    # visual tokenization generate tokens & reconstruct video
    def get_visual_token(self, input_tensor):
        """Encode images to latent tokens with the pretrained tokenizer.

        Handles four tokenizer flavors (DV/DI/CV/CI: discrete/continuous x
        video/image); tokens are detached so the tokenizer is never trained.
        """
        # input_tensor: B T N C H W range [0,1] -> [-1,1]
        input_tensor = input_tensor * 2 - 1  # B T N C H W
        input_tensor = input_tensor.permute(0, 2, 3, 1, 4, 5)  # B N C T H W
        horizon = input_tensor.shape[3]
        C, H, W = input_tensor.shape[2], input_tensor.shape[4], input_tensor.shape[5]
        self.num_view = input_tensor.shape[1]
        codes_list = []
        # refer to Cosmos-Tokenizer/cosmos_tokenizer/video_lib.py Line 103
        for view_idx in range(
            self.num_view
        ):  # deal with each view for video tokenization B C T H W
            if self.tokenizer_model_type == "DV":  # encoder for video tokenization
                (indices, codes) = self.tokenizer_enc._enc_model(
                    input_tensor[:, view_idx]  # B C T H W
                )[
                    :-1
                ]  # B D T' H' W'
            elif self.tokenizer_model_type == "DI":  # encoder for image tokenization
                input_tensor_image = (
                    input_tensor[:, view_idx].permute(0, 2, 1, 3, 4).view(-1, C, H, W)
                )  # B C T H W -> B*T C H W
                (indices, codes) = self.tokenizer_enc._enc_model(input_tensor_image)[
                    :-1
                ]  # B*T D H' W'
                codes = codes.view(
                    -1, horizon, codes.shape[1], codes.shape[2], codes.shape[3]
                ).permute(
                    0, 2, 1, 3, 4
                )  # B T+1 D H' W' -> B D T+1 H' W'
            elif (
                self.tokenizer_model_type == "CV"
            ):  # encoder for video tokenization TODO check
                (codes,) = self.tokenizer_enc._enc_model(input_tensor[:, view_idx])[
                    :-1
                ]  # B 16 T' H' W'
                codes = codes / 15.0  # nomalize to [-1,1] HardCode
                # TODO codes should normlize to [-1,1]
            elif self.tokenizer_model_type == "CI":
                input_tensor_image = (
                    input_tensor[:, view_idx].permute(0, 2, 1, 3, 4).view(-1, C, H, W)
                )  # B C T H W -> B*T C H W
                (codes,) = self.tokenizer_enc._enc_model(input_tensor_image)[
                    :-1
                ]  # B*T D H' W'
                codes = codes.view(
                    -1, horizon, codes.shape[1], codes.shape[2], codes.shape[3]
                ).permute(
                    0, 2, 1, 3, 4
                )  # B T+1 D H' W' -> B D T+1 H' W'
                codes = codes / 15.0  # nomalize to [-1,1] HardCode
            codes_list.append(codes.detach())  # Important
        codes = torch.stack(codes_list, dim=1)  # B N D T' H' W' -> # B T+1 N D H' W'
        codes = codes.permute(0, 3, 1, 2, 4, 5).float()  # B T+1 N D H' W'
        return codes

    def generate_video_by_codes(self, current_codes, pred_codes):
        """Decode (current + predicted) latent tokens back to a video clip.

        Only the first camera view is decoded; supports DV and CV tokenizers.
        """
        # Input: B T'+1 D H' W' simgle view video tokens
        # Output: B C T+1 H W range [-1,1]
        codes = torch.cat([current_codes, pred_codes], dim=1)[
            :, :, 0
        ]  # B T+1 N D H' W' -> B T'+1 D H' W'# HARD CODE
        codes = codes.permute(0, 2, 1, 3, 4).to(dtype=torch.bfloat16)  # B D T'+1 H' W'
        # suitable for DV only
        if self.tokenizer_model_type == "DV":
            h = self.tokenizer_dec._dec_model.post_quant_conv(
                codes
            )  # problem shoudl fixed TODO
        elif self.tokenizer_model_type == "CV":
            h = codes * 15  # unnormalize to original
        reconstructed_videos = self.tokenizer_dec._dec_model.decoder(
            h
        ).detach()  # B C T+1 H W -1,1

        return reconstructed_videos

    # For ROBOTWIN
    def reset_obs(self, stats=None, norm_type="minmax"):
        """Clear the observation buffers and store (de)normalization info."""
        self.obs_image.clear()
        self.obs_qpos.clear()
        self.stats = stats
        self.norm_type = norm_type

    def update_obs(self, obs):
        """Push one environment observation (head camera + qpos) into the
        history deques; the first call pre-fills the whole history window.
        """
        image_data = (
            torch.from_numpy(obs["head_cam"]).unsqueeze(0).unsqueeze(0).float().cuda()
        )  # B 1 C H W 0~1
        obs_qpos = torch.from_numpy(obs["agent_pos"]).unsqueeze(0).float().cuda()  # B D
        qpos_data = normalize_data(
            obs_qpos, self.stats, "gaussian", data_type="qpos"
        )  # qpos mean std

        if len(self.obs_image) == 0:
            for _ in range(self.history_steps + 1):
                self.obs_image.append(image_data)  # B T N C H W
                self.obs_qpos.append(qpos_data)
        else:
            self.obs_image.append(image_data)
            self.obs_qpos.append(qpos_data)

    def get_action(self):
        """Run the policy on the buffered history and return one unnormalized
        action chunk as a numpy array of shape (chunk_size, 14).
        """
        stacked_obs_image = torch.stack(
            list(self.obs_image), dim=1
        )  # 1 n+1 1 3 H W raw
        stacked_obs_qpos = torch.stack(list(self.obs_qpos), dim=1)  # 1 n+1 14
        a_hat = (
            self(stacked_obs_qpos, stacked_obs_image).detach().cpu().numpy()
        )  # 1 chunksize 14
        if self.norm_type == "minmax":
            a_hat = (a_hat + 1) / 2 * (
                self.stats["action_max"] - self.stats["action_min"]
            ) + self.stats["action_min"]
        elif self.norm_type == "gaussian":
            a_hat = a_hat * self.stats["action_std"] + self.stats["action_mean"]
        return a_hat[0]  # chunksize 14
+
+
class ACTPolicyDiffusion_with_Pixel_Prediction(nn.Module):
    """ACT diffusion policy that denoises the action chunk together with
    future camera frames predicted directly in (downsampled) pixel space.
    """

    def __init__(self, args_override):
        super().__init__()
        model = build_diffusion_pp_model_and_optimizer(args_override)  # TODO
        self.model = model  # CVAE decoder
        self.kl_weight = args_override["kl_weight"]
        print(f"KL Weight {self.kl_weight}")

        # memory buffer
        self.history_steps = args_override["history_step"]
        self.obs_image = deque(maxlen=self.history_steps + 1)
        self.obs_qpos = deque(maxlen=self.history_steps + 1)
        # self.obs_depth = deque(maxlen=self.history_steps+1)
        # image augmentation
        if "sim" in args_override["task_name"]:
            self.aug = RandomShiftsAug(15, 20)
        else:
            # Bugfix: this assignment used to run unconditionally, silently
            # overwriting the sim augmentation above (dead branch); sibling
            # policy classes in this file use if/else here.
            self.aug = RandomShiftsAug(8, 10)
        self.predict_only_last = args_override["predict_only_last"]
        self.prediction_weight = args_override["prediction_weight"]
        self.imitate_weight = args_override["imitate_weight"]
        self.predict_frame = args_override["predict_frame"]
        self.temporal_downsample_rate = args_override["temporal_downsample_rate"]
        self.resize_rate = args_override["resize_rate"]
        self.image_height = args_override["image_height"]
        self.image_width = args_override["image_width"]
        # shape of the predicted future frames: T N C H W
        self.future_images_shape = (
            args_override["predict_frame"] // args_override["temporal_downsample_rate"],
            len(args_override["camera_names"]),
            3,
            self.image_height // self.resize_rate,
            self.image_width // self.resize_rate,
        )
        print("predict_frame", self.predict_frame)
        print("prediction_weight", self.prediction_weight)
        print("imitate_weight", self.imitate_weight)
        # diffusion step
        self.num_inference_steps = args_override["num_inference_steps"]
        self.num_queries = args_override["num_queries"]
        num_train_timesteps = args_override["num_train_timesteps"]
        prediction_type = args_override["prediction_type"]
        beta_schedule = args_override["beta_schedule"]
        scheduler_cls = (
            DDIMScheduler if args_override["schedule_type"] == "DDIM" else DDPMScheduler
        )
        self.noise_scheduler = scheduler_cls(
            num_train_timesteps=num_train_timesteps,
            beta_schedule=beta_schedule,
            prediction_type=prediction_type,
        )
        pred_type = self.noise_scheduler.config.prediction_type
        self.loss_type = args_override["loss_type"]
        self.diffusion_loss_name = pred_type + "_diffusion_loss_" + self.loss_type

        # Bugfix: these prints previously wrapped each value in a set literal
        # ({...}), printing e.g. "{100}" instead of "100".
        print("num_train_timesteps", num_train_timesteps)
        print("schedule_type", args_override["schedule_type"])
        print("beta_schedule", beta_schedule)
        print("prediction_type", prediction_type)
        print(f"Loss Type {self.loss_type}")
        # discard resnet
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.vis_idx = 0

    def train_model(self, qpos, image, env_state, actions, is_pad, is_image_pad=None):
        """One training step: add noise to actions and future frames, run the
        model on the noisy versions, and assemble the combined loss dict.

        qpos: B T' D; T' = history_steps+1
        image: B T+1 N C H W, T = history_steps+1+predict_frame, range [0, 1]
        actions: B H D, H = chunk_size; is_pad: B H
        is_image_pad: B predict_frame

        Returns (loss_dict, (current frame, target frames, predicted frames)).
        """
        env_state = None
        bsz = actions.shape[0]

        # one random diffusion timestep per sample, shared by actions & pixels
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bsz,),
            device=actions.device,
        )
        current_image_norm = self.normalize(image[:, 0:1])
        future_images = (
            image[:, 1:, :, :, :: self.resize_rate, :: self.resize_rate] * 2 - 1
        )  # B T N C H W scale to [-1,1]
        # diffusion process
        actions_noise = torch.randn_like(actions).to(actions.device)
        pixel_noise = torch.randn_like(future_images).to(future_images.device)
        noisy_actions = self.noise_scheduler.add_noise(
            actions, actions_noise, timesteps
        )
        noise_pixel = self.noise_scheduler.add_noise(
            future_images, pixel_noise, timesteps
        )  # noisy future frames
        # predict clean data (or noise, depending on prediction_type)
        a_hat, is_pad_hat, pred_images, (mu, logvar) = self.model(
            qpos,
            (current_image_norm, None),
            env_state,
            actions,
            is_pad,
            noisy_actions,
            noise_pixel,
            is_image_pad,
            denoise_steps=timesteps,
        )

        # prediction type
        pred_type = self.noise_scheduler.config.prediction_type

        if pred_type == "epsilon":
            target_action = actions_noise
            target_images = pixel_noise
        elif pred_type == "sample":
            target_action = actions
            target_images = future_images
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        # calculate diffusion loss
        loss_dict = {}
        if self.loss_type == "l2":
            loss = F.mse_loss(a_hat, target_action, reduction="none")
        elif self.loss_type == "l1":
            loss = F.l1_loss(a_hat, target_action, reduction="none")
        else:
            # Bugfix: previously fell through with `loss` undefined (NameError).
            raise ValueError(f"Unsupported loss type {self.loss_type}")
        diffusion_loss = (loss * ~is_pad.unsqueeze(-1)).mean()
        loss_dict[self.diffusion_loss_name] = diffusion_loss
        # log the plain l2 for monitoring regardless of loss_type
        diffusion_l2 = F.mse_loss(a_hat, target_action, reduction="none")
        diffusion_l2 = (diffusion_l2 * ~is_pad.unsqueeze(-1)).mean().detach()
        loss_dict["diffusion_l2"] = diffusion_l2

        # pixel-prediction loss, masked by the (broadcast) frame padding mask
        pixel_loss = F.mse_loss(
            pred_images, target_images, reduction="none"
        )  # B T N C H W
        is_image_pad = (
            is_image_pad.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        )  # B T N C H W
        pixel_loss = (pixel_loss * ~is_image_pad).mean()
        loss_dict["loss_prediction_pixel"] = pixel_loss

        # total = imitation + (optional KL) + pixel prediction, each weighted
        if mu is not None and logvar is not None:
            total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
            loss_dict["kl"] = total_kld[0]
            loss_dict["loss"] = (
                loss_dict[self.diffusion_loss_name] * self.imitate_weight
                + loss_dict["kl"] * self.kl_weight
                + loss_dict["loss_prediction_pixel"] * self.prediction_weight
            )
        else:
            loss_dict["loss"] = (
                loss_dict[self.diffusion_loss_name] * self.imitate_weight
                + loss_dict["loss_prediction_pixel"] * self.prediction_weight
            )

        return loss_dict, (image[:, 0:1], target_images, pred_images)

    def conditional_sample(self, qpos, image):
        """Jointly denoise the action chunk and future frames from noise.

        qpos: B 1 D; image: B 1 N C H W (a missing time dim is added).
        Returns (actions, pixels, mu, logvar).
        """
        if len(image.shape) == 5:
            qpos = qpos.unsqueeze(1)
            image = image.unsqueeze(1)
        env_state = None
        model = self.model
        scheduler = self.noise_scheduler
        scheduler.set_timesteps(self.num_inference_steps)
        # process image observation
        current_image_norm = self.normalize(image[:, 0:1])  # B 1 N C H W
        batch = image.shape[0]
        action_shape = (batch, self.num_queries, 14)
        future_images_shape = (batch, *self.future_images_shape)  # B T N 3 H' W'
        actions = torch.randn(action_shape, device=qpos.device, dtype=qpos.dtype)
        pixels = torch.randn(future_images_shape, device=qpos.device, dtype=qpos.dtype)
        # denoise both modalities in lockstep
        for t in scheduler.timesteps:
            timesteps = torch.full((batch,), t, device=qpos.device, dtype=torch.long)
            model_action_output, is_pad_hat, model_pixel_output, (mu, logvar) = model(
                qpos,
                (current_image_norm, None),
                env_state,
                None,
                None,
                actions,
                pixels,
                None,
                denoise_steps=timesteps,
            )
            actions = scheduler.step(model_action_output, t, actions).prev_sample
            pixels = scheduler.step(model_pixel_output, t, pixels).prev_sample
        return actions, pixels, mu, logvar

    def __call__(
        self,
        qpos,
        image,
        actions=None,
        is_pad=None,
        future_imgs=None,
        is_pad_img=None,
        is_training=True,
    ):
        """Training (actions given): return the loss dict.
        Inference: return the denoised action chunk of shape B H D.
        """
        env_state = None
        if actions is not None:  # training time
            all_image = torch.cat(
                [image, future_imgs], dim=1
            )  # B T N C H W same resize maybe just use image_sample_size
            all_image = self.aug(all_image) if is_training else all_image
            loss_dict, (current_image, target_images, pred_images) = self.train_model(
                qpos, all_image, env_state, actions, is_pad, is_pad_img
            )  # B H
            return loss_dict
        else:
            qpos = qpos  # B 1 D
            image = image  # B 1 N C H W 0~1
            a_hat, pred_images, _, _ = self.conditional_sample(qpos, image)
            return a_hat  # B H D
+
+
# discard
class ACTPolicy_NextFrame(nn.Module):
    """ACT (CVAE) policy with an auxiliary next-frame reconstruction head.

    Marked "discard" by the original authors; kept for reference. Training
    returns a loss dict (l1 + weighted kl + weighted next-frame MSE);
    inference returns the action chunk and stashes the decoded next-frame
    image on ``self.next_frame``.
    """

    def __init__(self, args_override):
        # args_override: config dict; must contain "kl_weight".
        super().__init__()
        model, optimizer = build_ACT_NF_model_and_optimizer(args_override)
        self.model = model  # CVAE decoder
        self.optimizer = optimizer
        self.kl_weight = args_override["kl_weight"]
        self.nextframe_weight = 1
        print(f"KL Weight {self.kl_weight}")

    def __call__(self, qpos, image, actions=None, is_pad=None):
        env_state = None
        # ImageNet normalization, applied to the model's *input* frames only.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        if actions is not None:  # training time
            curr_image = image[:, 0:1]
            next_image = image[:, 1:]
            curr_image = normalize(curr_image)
            image = torch.cat(
                [curr_image, next_image], dim=1
            )  # B T C H W -- normalize the current image, not the future
            # target, so the reconstruction loss stays in [0, 1] space.

            actions = actions[:, : self.model.num_queries]
            is_pad = is_pad[:, : self.model.num_queries]

            a_hat, is_pad_hat, (mu, logvar), (obs_preds, obs_targets) = self.model(
                qpos, image, env_state, actions, is_pad
            )
            total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
            loss_dict = dict()
            all_l1 = F.l1_loss(actions, a_hat, reduction="none")
            # Mask out padded (beyond-episode) timesteps before averaging.
            l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
            # Next-frame head outputs logits; sigmoid maps them to [0, 1].
            obs_loss = ((obs_preds.sigmoid() - obs_targets) ** 2).mean()
            loss_dict["l1"] = l1
            loss_dict["kl"] = total_kld[0]
            loss_dict["next_frame"] = obs_loss
            loss_dict["loss"] = (
                loss_dict["l1"]
                + loss_dict["kl"] * self.kl_weight
                + loss_dict["next_frame"] * self.nextframe_weight
            )
            return loss_dict
        else:  # inference time
            image = normalize(image)
            a_hat, _, (_, _), (obs_preds, obs_targets) = self.model(
                qpos, image, env_state
            )  # no action, sample from prior

            # Next-frame prediction: un-patchify the ViT-style output.
            # obs_preds is (bs, num_patches, patch_size*patch_size*3) logits.
            bs = a_hat.shape[0]
            patch_size = 16
            # image_size = 224
            img_h, img_w = self.model.img_h, self.model.img_w
            ph, pw = img_h // patch_size, img_w // patch_size
            nf_pred = obs_preds.sigmoid().reshape(
                shape=(bs, ph, pw, patch_size, patch_size, 3)
            )
            # (bs, 3, ph, patch, pw, patch) -> fold patches back into an image.
            nf_pred = nf_pred.permute(0, 5, 1, 3, 2, 4)
            nf_pred = nf_pred.reshape(shape=(bs, 3, img_h, img_w))

            # Debug visualization kept for reference:
            # import matplotlib.pyplot as plt
            # plt.imshow(nf_pred[0].cpu().numpy().transpose(1,2,0))
            # plt.savefig('tmp.png')
            # plt.clf()
            # plt.close()
            # self.next_frames.append({'next_frame':nf_pred[0].cpu().numpy().transpose(1,2,0)})
            # Stash the first sample's predicted frame as HWC numpy for callers.
            self.next_frame = nf_pred[0].cpu().numpy().transpose(1, 2, 0)
            return a_hat

    def configure_optimizers(self):
        """Return the optimizer created in __init__ (Lightning-style hook)."""
        return self.optimizer
+
+
class CNNMLPPolicy(nn.Module):
    """Baseline CNN+MLP policy: predicts a single action step from qpos and images."""

    def __init__(self, args_override):
        super().__init__()
        model, optimizer = build_CNNMLP_model_and_optimizer(args_override)
        self.model = model  # decoder
        self.optimizer = optimizer

    def __call__(self, qpos, image, actions=None, is_pad=None):
        env_state = None  # TODO
        # ImageNet normalization of the camera images.
        image = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )(image)
        if actions is None:  # inference time
            return self.model(qpos, image, env_state)  # no action, sample from prior
        # Training time: this baseline supervises only the first action of the chunk.
        first_action = actions[:, 0]
        a_hat = self.model(qpos, image, env_state, first_action)
        mse = F.mse_loss(first_action, a_hat)
        return {"mse": mse, "loss": mse}

    def configure_optimizers(self):
        """Return the optimizer created in __init__ (Lightning-style hook)."""
        return self.optimizer
+
+
+if __name__ == "__main__":
+
+
+ import json
+ config_path = '/attached/remote-home2/xhl/8_kaust_pj/RoboTwin/policy/ACT-DP-TP-Policy/checkpoints/put_bottles_dustbin/single_20_2_300_600/act_dp/policy_config.json'
+ config = json.load(open(config_path, 'r'))
+
+ # policy = ACT_Flow_Matching(config)
+
+ # qpos = torch.randn(1,3, 14).to('cuda')
+ # image = torch.rand(1, 3,1, 3, 480, 640).to('cuda')
+ # actions = torch.randn(1, 20, 14).to('cuda')
+ # is_pad = torch.zeros(1, 20).bool().to('cuda')
+ # loss_dict = policy(qpos, image)
+ # print(loss_dict)
+
+ config['condition_type'] = "adaLN"
+ policy = ACTDiffusionPolicy(config)
+
+ qpos = torch.randn(1,3, 14).to('cuda')
+ image = torch.rand(1, 3,1, 3, 480, 640).to('cuda')
+ actions = torch.randn(1, 20, 14).to('cuda')
+ is_pad = torch.zeros(1, 20).bool().to('cuda')
+ loss_dict = policy(qpos, image)
+ print(loss_dict)
+
+
+ # from cosmos_tokenizer.networks import TokenizerConfigs
+ # from cosmos_tokenizer.utils import (
+ # get_filepaths,
+ # get_output_filepath,
+ # read_video,
+ # resize_video,
+ # write_video,
+ # )
+ # # from cosmos_tokenizer.video_lib import CausalVideoTokenizer
+ # from cosmos_tokenizer.image_lib import ImageTokenizer
+
+ # tokenizer_type = 'DV'
+ # spatial_compression = 16
+ # temporal_compression = 8
+ # mode = 'torch'
+ # temporal_window = 17
+ # dtype = 'bfloat16'
+ # device = 'cuda'
+ # model_name = "Cosmos-Tokenizer-DV8x16x16"
+ # checkpoint_enc = f'Cosmos-Tokenizer/pretrained_ckpts/{model_name}/encoder.jit'
+ # checkpoint_dec = f'Cosmos-Tokenizer/pretrained_ckpts/{model_name}/decoder.jit'
+ # # load model
+ # tokenizer_config = TokenizerConfigs[tokenizer_type].value
+ # tokenizer_config.update(dict(spatial_compression=spatial_compression))
+ # tokenizer_config.update(dict(temporal_compression=temporal_compression))
+ # autoencoder = CausalVideoTokenizer(
+ # checkpoint= None,
+ # checkpoint_enc=checkpoint_enc,
+ # checkpoint_dec=checkpoint_dec,
+ # tokenizer_config=tokenizer_config,
+ # device=device,
+ # dtype=dtype,
+ # )
+ # # T C HW
+ # input_tensor = read_video('Cosmos-Tokenizer/test_robot_data/sim_transfer_cube_scripted/episode_0.mp4')
+ # print(input_tensor.dtype) # uint8 255
+
+ # batch_video = np.array(input_tensor[49:98:3,::2,::2])[np.newaxis, ...] # B T H W C
+ # output_video = autoencoder(batch_video, temporal_window=temporal_window)[0] # T H W C np.unit8 255
+
+ # autoencoder.get_latent_codes(batch_video)# B C T' H' W' # B 6 T' H' W'
+
+ # encoder = CausalVideoTokenizer(checkpoint_enc=checkpoint_enc)
+ # decoder = CausalVideoTokenizer(checkpoint_dec=checkpoint_dec)
+ # batch_video_tensor = torch.from_numpy(batch_video).cuda().permute(0,4,1,2,3) / 127.5 - 1 # B C T H W -1,1
+ # print(batch_video_tensor.shape)
+ # output_index, output_code = encoder._enc_model(batch_video_tensor)[:-1] # Input B C T H W -1,1
+ # print(output_index.shape, output_code.shape)
+ # h = decoder._dec_model.post_quant_conv(output_code) # input shape B 6 T' H' W', output shape B 16 T' H' W'
+ # reconstructed_tensor = decoder._dec_model.decoder(h).detach() # B C T H W -1,1
+ # rec_video = tensor2numpy(reconstructed_tensor)[0] # T H W C
+
+ # print('diff', (rec_video - output_video).mean())
+ # print('rec diff', (rec_video - batch_video).mean())
+ # print('out diff', (output_video - batch_video).mean())
+
+ # vis_path = 'vis_0.mp4'
+ # vis_video = np.concatenate([batch_video[0], output_video, rec_video], axis=2) # T H W*2 C
+ # media.write_video(vis_path, vis_video, fps=3)
+
+ # model_name = "Cosmos-Tokenizer-DI16x16"
+ # encoder = ImageTokenizer(checkpoint_enc=f'Cosmos-Tokenizer/pretrained_ckpts/{model_name}/encoder.jit')
+ # decoder = ImageTokenizer(checkpoint_dec=f'Cosmos-Tokenizer/pretrained_ckpts/{model_name}/decoder.jit')
+ # input_tensor = torch.randn(32, 3, 480, 640).to('cuda').to(torch.bfloat16) # [B, C, T, H, W]
+ # (indices, codes) = encoder.encode(input_tensor)
+ # print(codes.shape)
+ # print(codes.max(), codes.min())
+ # reconstructed_tensor = decoder.decode(indices) # input index
+ # print(reconstructed_tensor.shape)
+ # model_name = "Cosmos-Tokenizer-CV4x8x8"
+ # encoder = CausalVideoTokenizer(
+ # checkpoint_enc=f"Cosmos-Tokenizer/pretrained_ckpts/{model_name}/encoder.jit"
+ # )
+ # decoder = CausalVideoTokenizer(
+ # checkpoint_dec=f"Cosmos-Tokenizer/pretrained_ckpts/{model_name}/decoder.jit"
+ # )
+ # # B N C T H W
+ # input_tensor = (
+ # -torch.ones(16, 5, 3, 21, 480, 640).to("cuda").to(torch.bfloat16)
+ # ) # [B, C, T, H, W]
+ # for view_idx in range(5):
+ # (codes,) = encoder._enc_model(input_tensor[:, view_idx])[:-1] # B 16 T' H' W'
+ # codes = codes.detach()
+ # print(codes.shape)
+ # print(codes.max(), codes.min())
diff --git a/ACT_DP_multitask/requirements.txt b/ACT_DP_multitask/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bd53ea17411770a66ab8affd00ece502cde36d3c
--- /dev/null
+++ b/ACT_DP_multitask/requirements.txt
@@ -0,0 +1,160 @@
+absl-py==2.3.0
+ackermann-msgs==2.0.2
+action-msgs==1.2.1
+action-tutorials-interfaces==0.20.5
+actionlib-msgs==4.8.0
+aiohappyeyeballs==2.6.1
+ament-cmake-test==1.3.11
+ament-package==0.14.0
+angles==1.15.0
+annotated-types==0.7.0
+argcomplete==3.6.2
+asttokens==3.0.0
+async-timeout==5.0.1
+attrs==25.3.0
+boto==2.49.0
+builtin-interfaces==1.2.1
+cachetools==5.5.2
+catkin-pkg
+certifi==2022.12.7
+charset-normalizer==2.1.1
+click==8.2.1
+composition-interfaces==1.2.1
+control-msgs==4.8.0
+controller-manager==2.50.0
+controller-manager-msgs==2.50.0
+crcmod==1.7
+cv-bridge==3.2.1
+cycler==0.12.1
+decorator==5.2.1
+diagnostic-msgs==4.8.0
+docstring_parser==0.16
+einops==0.8.1
+etils==1.12.2
+example-interfaces==0.9.3
+executing==2.2.0
+fasteners==0.19
+filelock==3.13.1
+flatbuffers==25.2.10
+fonttools==4.58.1
+frozenlist==1.6.0
+fsspec==2024.6.1
+gast==0.6.0
+geometry-msgs==4.8.0
+grpcio==1.72.1
+hf-xet==1.1.2
+hjson==3.1.0
+idna==3.4
+image-geometry==3.2.1
+immutabledict==4.2.1
+importlib_resources==6.5.2
+interactive-markers==2.3.2
+keras==2.15.0
+kiwisolver==1.4.8
+laser-geometry==2.4.0
+libclang==18.1.1
+lifecycle-msgs==1.2.1
+logging-demo==0.20.5
+map-msgs==2.1.0
+Markdown==3.8
+MarkupSafe==3.0.2
+message-filters==4.3.7
+monotonic==1.6
+mpmath==1.3.0
+msgpack==1.1.0
+nav-msgs==4.8.0
+networkx==3.3
+ninja==1.11.1.4
+numpy==1.26.4
+nvidia-ml-py==12.575.51
+oauthlib==3.2.2
+opt_einsum==3.4.0
+osrf-pycommon==2.1.6
+packaging==25.0
+parso==0.8.4
+pcl-msgs==1.0.0
+pendulum-msgs==0.20.5
+pillow==10.2.0
+piper_msgs==0.0.0
+platformdirs==4.3.8
+propcache==0.3.1
+protobuf==4.21.12
+psutil==7.0.0
+ptyprocess==0.7.0
+pure_eval==0.2.3
+py-cpuinfo==9.0.0
+pyarrow==20.0.0
+pyasn1==0.6.1
+pycparser==2.22
+Pygments==2.19.1
+pyparsing==3.2.3
+pyrealsense2==2.55.1.6486
+python-qt-binding==1.1.2
+pytz==2025.2
+PyYAML==6.0.1
+qt-dotgraph==2.2.4
+qt-gui==2.2.4
+qt-gui-cpp==2.2.4
+qt-gui-py-common==2.2.4
+rcl-interfaces==1.2.1
+rclpy==3.3.16
+rcutils==5.1.6
+regex==2024.11.6
+resource-retriever==3.1.3
+retry-decorator==1.1.1
+rmw-dds-common==1.6.0
+ros2cli==0.18.12
+rosbag2-interfaces==0.15.14
+rosbag2-py==0.15.14
+rosgraph-msgs==1.2.1
+rosidl-adapter==3.1.6
+rosidl-cli==3.1.6
+rosidl-cmake==3.1.6
+rosidl-generator-c==3.1.6
+rosidl-generator-cpp==3.1.6
+rosidl-generator-py==0.14.4
+rosidl-parser==3.1.6
+rosidl-runtime-py==0.9.3
+rosidl-typesupport-c==2.0.2
+rosidl-typesupport-cpp==2.0.2
+rosidl-typesupport-fastrtps-c==2.2.2
+rosidl-typesupport-fastrtps-cpp==2.2.2
+rosidl-typesupport-introspection-c==3.1.6
+rosidl-typesupport-introspection-cpp==3.1.6
+rpyutils==0.2.1
+rqt-py-common==1.1.7
+safetensors==0.5.3
+sensor-msgs==4.8.0
+sentencepiece==0.2.0
+setproctitle==1.3.6
+setuptools==78.1.1
+shape-msgs==4.8.0
+six==1.17.0
+smmap==5.0.2
+statistics-msgs==1.2.1
+std-msgs==4.8.0
+std-srvs==4.8.0
+stereo-msgs==4.8.0
+tensorboard-data-server==0.7.2
+tensorflow-estimator==2.15.0
+tensorflow-io-gcs-filesystem==0.37.1
+termcolor==3.1.0
+tf2-geometry-msgs==0.25.12
+tf2-kdl==0.25.12
+tf2-msgs==0.25.12
+tf2-py==0.25.12
+toml==0.10.2
+tqdm==4.67.1
+traitlets==5.14.3
+trajectory-msgs==4.8.0
+turtlesim==1.4.2
+typeguard==2.13.3
+typing_extensions==4.13.2
+unique-identifier-msgs==2.2.1
+urllib3==1.26.13
+visualization-msgs==4.8.0
+wcwidth==0.2.13
+wheel==0.45.1
+wrapt==1.14.1
+xacro==2.0.13
+zipp==3.22.0
diff --git a/ACT_DP_multitask/t5_encoder.py b/ACT_DP_multitask/t5_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f931abeff5edd00b2a15ee0c1c29e13ebdcf2e7c
--- /dev/null
+++ b/ACT_DP_multitask/t5_encoder.py
@@ -0,0 +1,110 @@
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
class T5Embedder:
    """Frozen T5 text encoder producing per-token embeddings for instructions.

    Wraps a HuggingFace ``T5EncoderModel`` + tokenizer, optionally offloading
    the upper encoder blocks to disk when GPU memory is scarce.
    """
    # available_models = ["google/t5-v1_1-xxl"]

    def __init__(
        self,
        device,
        from_pretrained=None,
        *,
        cache_dir=None,
        hf_token=None,
        use_text_preprocessing=True,
        t5_model_kwargs=None,
        torch_dtype=None,
        use_offload_folder=None,
        model_max_length=120,
        local_files_only=False,
    ):
        # device: torch device string/object for the encoder.
        # from_pretrained: HF model id or local path (e.g. "google/t5-v1_1-xxl").
        # use_offload_folder: if set, encoder blocks 12+ are offloaded to disk.
        # model_max_length: tokenizer truncation length.
        # from_pretrained="google/t5-v1_1-xxl" # zijian
        self.device = torch.device(device)
        self.torch_dtype = torch_dtype or torch.bfloat16
        self.cache_dir = cache_dir

        if t5_model_kwargs is None:
            t5_model_kwargs = {
                "low_cpu_mem_usage": True,
                "torch_dtype": self.torch_dtype,
            }

        if use_offload_folder is not None:
            # Keep embeddings and the first 12 encoder blocks on `device`;
            # push the rest (blocks 12-23, final LN, dropout) to disk.
            # NOTE(review): the block split is tuned for the 24-block xxl
            # variant -- confirm for other model sizes.
            t5_model_kwargs["offload_folder"] = use_offload_folder
            t5_model_kwargs["device_map"] = {
                "shared": self.device,
                "encoder.embed_tokens": self.device,
                "encoder.block.0": self.device,
                "encoder.block.1": self.device,
                "encoder.block.2": self.device,
                "encoder.block.3": self.device,
                "encoder.block.4": self.device,
                "encoder.block.5": self.device,
                "encoder.block.6": self.device,
                "encoder.block.7": self.device,
                "encoder.block.8": self.device,
                "encoder.block.9": self.device,
                "encoder.block.10": self.device,
                "encoder.block.11": self.device,
                "encoder.block.12": "disk",
                "encoder.block.13": "disk",
                "encoder.block.14": "disk",
                "encoder.block.15": "disk",
                "encoder.block.16": "disk",
                "encoder.block.17": "disk",
                "encoder.block.18": "disk",
                "encoder.block.19": "disk",
                "encoder.block.20": "disk",
                "encoder.block.21": "disk",
                "encoder.block.22": "disk",
                "encoder.block.23": "disk",
                "encoder.final_layer_norm": "disk",
                "encoder.dropout": "disk",
            }
        else:
            # No offload: place the whole encoder on the requested device.
            t5_model_kwargs["device_map"] = {
                "shared": self.device,
                "encoder": self.device,
            }

        self.use_text_preprocessing = use_text_preprocessing
        self.hf_token = hf_token

        # assert from_pretrained in self.available_models
        self.tokenizer = AutoTokenizer.from_pretrained(
            from_pretrained,
            model_max_length=model_max_length,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
        self.model = T5EncoderModel.from_pretrained(
            from_pretrained,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            **t5_model_kwargs,
        ).eval()
        self.model_max_length = model_max_length

    def get_text_embeddings(self, texts):
        """Encode a list of strings.

        Returns:
            (embeddings, attention_mask): last_hidden_state of shape
            (batch, seq_len, hidden) detached from the graph, and the
            0/1 attention mask on the same device.
        """
        text_tokens_and_mask = self.tokenizer(
            texts,
            max_length=self.model_max_length,
            padding="longest",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        input_ids = text_tokens_and_mask["input_ids"].to(self.device)
        attention_mask = text_tokens_and_mask["attention_mask"].to(self.device)
        with torch.no_grad():
            text_encoder_embs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )["last_hidden_state"].detach()
        return text_encoder_embs, attention_mask
+
+
+if __name__ == "__main__":
+ T5Embedder(from_pretrained="google/t5-v1_1-xxl", device='cuda:7')
\ No newline at end of file
diff --git a/ACT_DP_multitask/utils.py b/ACT_DP_multitask/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3b40ca589c4da1ad9f0c00d5371957cec5ca9b7
--- /dev/null
+++ b/ACT_DP_multitask/utils.py
@@ -0,0 +1,410 @@
+import torch
+from torch.utils.data import Dataset, DataLoader
+import numpy as np
+import os
+import cv2
+import matplotlib.pyplot as plt
+import math
+from torch.nn.modules.batchnorm import _BatchNorm
+from collections import OrderedDict
+from torch.optim.lr_scheduler import LambdaLR
+import torch.nn as nn
+from torch.nn import functional as F
+import h5py
+import fnmatch
+from torchvision import transforms
+import pickle
+from tqdm import tqdm
+_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max)
+
def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
    """Plot train/validation curves for every logged scalar and save them as PNGs.

    Args:
        train_history: list of dicts of scalar values (one dict per logging step);
            values are assumed to be 0-dim tensors (``.item()`` is called) --
            TODO(review): confirm callers never log plain floats.
        validation_history: list of dicts with the same keys as train_history.
        num_epochs: total number of epochs, used to scale the x-axis.
        ckpt_dir: directory where the PNG files are written.
        seed: run seed, embedded in the output filenames.
    """
    for key in train_history[0]:
        plot_path = os.path.join(ckpt_dir, f"train_val_{key}_seed_{seed}.png")
        plt.figure()
        train_values = [summary[key].item() for summary in train_history]
        val_values = [summary[key].item() for summary in validation_history]
        plt.plot(
            np.linspace(0, num_epochs - 1, len(train_history)),
            train_values,
            label="train",
        )
        plt.plot(
            np.linspace(0, num_epochs - 1, len(validation_history)),
            val_values,
            label="validation",
        )
        plt.tight_layout()
        plt.legend()
        plt.title(key)
        plt.savefig(plot_path)
        # Fix: close each figure after saving. The original left every figure
        # open, leaking memory and triggering matplotlib's "more than 20
        # figures" RuntimeWarning when many keys/epochs are plotted.
        plt.close()
    print(f"Saved plots to {ckpt_dir}")
+
+
def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray:
    """Convert a batched image tensor to uint8 numpy images.

    Args:
        input_tensor: BxCx... tensor; values in [-1, 1] when range_min == -1,
            otherwise already in [0, 1].
        range_min: lower bound of the input value range (-1 or 0).
    Returns:
        uint8 array with the channel axis moved last (Bx...xC), range [0, 255].
    """
    scale = float(torch.iinfo(torch.uint8).max)  # 255.0
    if range_min == -1:
        # Map [-1, 1] -> [0, 1] before quantizing.
        input_tensor = (input_tensor.float() + 1.0) / 2.0
    # Move channels (axis 1) to the last axis: (0, 2, ..., ndim-1, 1).
    axes = (0,) + tuple(range(2, input_tensor.ndim)) + (1,)
    clipped = input_tensor.clamp(0, 1).cpu().numpy().transpose(axes)
    # The +0.5 rounds to nearest instead of truncating on the uint8 cast.
    return (clipped * scale + 0.5).astype(np.uint8)
+
+
def kl_divergence(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)), element-wise, reduced three ways.

    Returns:
        total_kld: shape (1,), KL summed over latent dims, averaged over batch.
        dimension_wise_kld: shape (D,), KL averaged over the batch per dim.
        mean_kld: shape (1,), KL averaged over both batch and dims.
    """
    assert mu.size(0) != 0
    # Legacy conv outputs may arrive as B x D x 1 x 1; flatten to B x D.
    if mu.data.ndimension() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.data.ndimension() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    return (
        klds.sum(1).mean(0, True),   # total
        klds.mean(0),                # per-dimension
        klds.mean(1).mean(0, True),  # mean
    )
+
+
class RandomShiftsAug(nn.Module):
    """Random pixel-shift augmentation.

    Pads the image with edge replication, then resamples an original-size
    window at a random integer offset with ``grid_sample``. All frames and
    views of one sample receive the SAME shift.
    """

    def __init__(self, pad_h, pad_w):
        # pad_h / pad_w: maximum shift in pixels along height / width.
        super().__init__()
        self.pad_h = pad_h
        self.pad_w = pad_w
        print(f"RandomShiftsAug: pad_h {pad_h}, pad_w {pad_w}")

    def forward(self, x):
        # x: (n, T, M, C, H, W); leading dims are folded into the channel axis
        # so one sampled grid shifts every frame/view of a sample identically.
        orignal_shape = x.shape
        n, h, w = x.shape[0], x.shape[-2], x.shape[-1]  # n,T,M,C,H,W
        x = x.view(n, -1, h, w)  # n,T*M*C,H,W
        padding = (
            self.pad_w,
            self.pad_w,
            self.pad_h,
            self.pad_h,
        )  # left, right, top, bottom padding
        x = F.pad(x, padding, mode="replicate")

        h_pad, w_pad = h + 2 * self.pad_h, w + 2 * self.pad_w
        # Half-texel offsets so grid points land on pixel centers of the
        # padded image (align_corners=False convention).
        eps_h = 1.0 / h_pad
        eps_w = 1.0 / w_pad

        # Normalized coordinates of the top-left h x w window of the padded image.
        arange_h = torch.linspace(
            -1.0 + eps_h, 1.0 - eps_h, h_pad, device=x.device, dtype=x.dtype
        )[:h]
        arange_w = torch.linspace(
            -1.0 + eps_w, 1.0 - eps_w, w_pad, device=x.device, dtype=x.dtype
        )[:w]

        arange_h = arange_h.unsqueeze(1).repeat(1, w).unsqueeze(2)  # h w 1
        arange_w = arange_w.unsqueeze(1).repeat(1, h).unsqueeze(2)  # w h 1

        # print(arange_h.shape, arange_w.shape)
        base_grid = torch.cat([arange_w.transpose(1, 0), arange_h], dim=2)  # [H, W, 2]
        base_grid = base_grid.unsqueeze(0).repeat(
            n, 1, 1, 1
        )  # Repeat for batch [B, H, W, 2]

        # Per-sample integer shift in [0, 2*pad], converted to normalized
        # grid units below. NOTE(review): randint is called with a floating
        # dtype and then .float() -- works, but looks unintentional.
        shift_h = torch.randint(
            0, 2 * self.pad_h + 1, size=(n, 1, 1, 1), device=x.device, dtype=x.dtype
        ).float()
        shift_w = torch.randint(
            0, 2 * self.pad_w + 1, size=(n, 1, 1, 1), device=x.device, dtype=x.dtype
        ).float()
        shift_h *= 2.0 / h_pad
        shift_w *= 2.0 / w_pad

        grid = base_grid + torch.cat([shift_w, shift_h], dim=3)
        x = F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
        return x.view(orignal_shape)
+
+
def get_norm_stats(state, action):
    """Compute normalization statistics over the first (time) axis.

    Args:
        state: array-like of qpos rows, shape (T, D).
        action: array-like of action rows, shape (T, D).
    Returns:
        dict of numpy arrays: action mean/std/max/min and qpos mean/std.
        Standard deviations are clipped below at 1e-2 so later division
        cannot blow up on near-constant dimensions.
    """
    qpos_tensor = torch.from_numpy(np.array(state))
    action_tensor = torch.from_numpy(np.array(action))

    # Action statistics (mean/std for gaussian norm, min/max for minmax norm).
    act_mean = action_tensor.mean(dim=[0], keepdim=True)
    act_std = torch.clip(action_tensor.std(dim=[0], keepdim=True), 1e-2, np.inf)
    act_max = torch.amax(action_tensor, dim=[0], keepdim=True)
    act_min = torch.amin(action_tensor, dim=[0], keepdim=True)

    # Proprioception statistics.
    qpos_mean = qpos_tensor.mean(dim=[0], keepdim=True)
    qpos_std = torch.clip(qpos_tensor.std(dim=[0], keepdim=True), 1e-2, np.inf)

    return {
        "action_mean": act_mean.numpy().squeeze(),
        "action_std": act_std.numpy().squeeze(),
        "action_max": act_max.numpy().squeeze(),
        "action_min": act_min.numpy().squeeze(),
        "qpos_mean": qpos_mean.numpy().squeeze(),
        "qpos_std": qpos_std.numpy().squeeze(),
    }
+
class EpisodicDataset_Unified_Multiview(Dataset):
    """Dataset over a list of HDF5 episode files with multi-view JPEG images.

    Each __getitem__ samples one random timestep from one episode and returns
    (image, qpos, action_chunk, is_pad, instruction) tensors.
    """

    def __init__(self, data_path_list, camera_names, chunk_size, stats, img_aug=False):
        # Fix: the original called
        # ``super(EpisodicDataset_Unified_Multiview).__init__()`` which builds
        # an *unbound* super object and never invokes Dataset.__init__.
        super().__init__()
        self.data_path_list = data_path_list
        self.camera_names = camera_names
        self.chunk_size = chunk_size
        self.norm_stats = stats
        self.img_aug = img_aug
        self.ColorJitter = transforms.ColorJitter(
            brightness=0.2, contrast=0.2, saturation=0.2, hue=0.01)

    def __len__(self):
        # Each "epoch" visits every episode 16 times (one random timestep each).
        return len(self.data_path_list) * 16

    def __getitem__(self, path_index):
        # Naming note: observations/qpos holds the master (commanded) joints,
        # used here as the supervision target ('action'), while the file's
        # 'action' dataset holds the puppet (measured) joints, used as the
        # proprioceptive observation -- TODO(review): confirm this swap.
        path_index = path_index % len(self.data_path_list)  # ensure index is within bounds
        example_path = self.data_path_list[path_index]
        with h5py.File(example_path, 'r') as f:
            action = f['observations']['qpos'][()]  # jointStatePosition/master
            qpos = f['action'][()]  # jointStatePosition/puppet

        # Randomly pick one precomputed language-instruction embedding (*.pt).
        parent_path = os.path.dirname(example_path)
        Instruction_path = os.path.join(parent_path, 'instructions')
        instruction_files = [f for f in os.listdir(Instruction_path) if fnmatch.fnmatch(f, '*.pt')]
        instruction_file = os.path.join(Instruction_path, np.random.choice(instruction_files))
        instruction = torch.load(instruction_file, weights_only=False)  # num_token x 4096 tensor

        # Randomly sample an episode timestep.
        episode_len = action.shape[0]
        index = np.random.randint(0, episode_len)
        obs_qpos = qpos[index:index + 1]

        # Decode and stack the JPEG-compressed camera images at that timestep.
        with h5py.File(example_path, 'r') as f:
            camera_list = []
            for camera_name in self.camera_names:
                cam_jpeg_code = f['observations']['images'][camera_name][index]
                # NOTE(review): cv2.imdecode returns BGR channel order (the
                # original comment claimed "rgb") -- confirm downstream
                # consumers expect BGR.
                cam_image = cv2.imdecode(np.frombuffer(cam_jpeg_code, np.uint8), cv2.IMREAD_COLOR)
                camera_list.append(cam_image)
        obs_img = np.stack(camera_list, axis=0)  # shape: (N_views, H, W, C)

        # Build the (possibly padded) future-action chunk starting at index.
        original_action_shape = (self.chunk_size, *action.shape[1:])
        gt_action = np.zeros(original_action_shape)
        action_len = min(self.chunk_size, episode_len - index)
        gt_action[:action_len] = action[
            index : index + action_len
        ]
        # Mark beyond-episode steps so the loss can mask them out.
        is_pad = np.zeros(self.chunk_size)
        is_pad[action_len:] = 1

        # Construct observation tensors.
        image_data = torch.from_numpy(obs_img).unsqueeze(0).float()  # (1, N_views, H, W, 3)
        image_data = image_data.permute(0, 1, 4, 2, 3)  # (1, N_views, 3, H, W)
        qpos_data = torch.from_numpy(obs_qpos).float()  # (1, D)
        action_data = torch.from_numpy(gt_action).float()  # (chunk_size, D)
        is_pad = torch.from_numpy(is_pad).bool()  # (chunk_size,)
        instruction_data = instruction.mean(0).float()  # (4096,) token-mean embedding
        # Normalize: images to [0, 1], proprioception z-scored with dataset stats.
        image_data = image_data / 255.0
        qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats[
            "qpos_std"
        ]
        # Color-jitter each view independently on ~25% of the samples.
        if self.img_aug and random.random() < 0.25:
            for t in range(image_data.shape[0]):
                for i in range(image_data.shape[1]):
                    image_data[t, i] = self.ColorJitter(image_data[t, i])
        return image_data, qpos_data.float(), action_data, is_pad, instruction_data
+
def load_data_unified(
    data_dir='/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/RDT/training_data/rdt_real_multitask',
    camera_names=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
    batch_size_train=32,
    chunk_size=100,
    img_aug=False,
    fintune=False,
):
    """Discover HDF5 episodes under data_dir and build the training dataloader.

    Args:
        data_dir: root directory searched recursively for *.hdf5 episodes.
        camera_names: HDF5 image dataset names to load per timestep.
            NOTE(review): mutable default list -- harmless here (never
            mutated) but worth cleaning up.
        batch_size_train: dataloader batch size.
        chunk_size: action-chunk length per sample.
        img_aug: enable color-jitter augmentation in the dataset.
        fintune: (sic) when True, reuse normalization stats pickled during
            pretraining instead of recomputing them from data_dir.

    Returns:
        (train_dataloader, None, None, stats) -- the two Nones keep the
        return arity of sibling loaders that also produce val loaders.
    """
    HDF5_file_path = []
    for root, _, files in os.walk(data_dir, followlinks=True):
        for filename in files:
            if filename.endswith('.hdf5'):
                HDF5_file_path.append(os.path.join(root, filename))
    print(f"Loading data from {data_dir} with {len(HDF5_file_path)} episodes and batch size {batch_size_train}")

    state_list = []
    action_list = []
    # Naming note: the file's observations/qpos is treated as the action
    # target and 'action' as the state -- mirrors the dataset class above.
    # qpos = np.concatenate((root['/arm/jointStatePosition/masterLeft'][()], root['/arm/jointStatePosition/masterRight'][()]), axis=-1)
    # actions = np.concatenate((root['/arm/jointStatePosition/puppetLeft']
    for p in tqdm(HDF5_file_path, desc="Data statics collection"):
        with h5py.File(p, 'r') as f:
            action = f['observations']['qpos'][()]
            qpos = f['action'][()]
        state_list.append(qpos)
        action_list.append(action)
    states = np.concatenate(state_list, axis=0)
    actions = np.concatenate(action_list, axis=0)

    if fintune:
        # Reuse the pretraining stats so finetuned and pretrained policies
        # share the same normalization. NOTE(review): hard-coded local path.
        pretrain_stats_path = '/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/ACT_DP_multitask/checkpoints/real_pretrain_50_2000/act_dp/dataset_stats.pkl'
        with open(pretrain_stats_path, 'rb') as f:
            stats = pickle.load(f)
        print(f"Loaded stats from {pretrain_stats_path}")
    else:
        stats = get_norm_stats(states, actions)

    for key, value in stats.items():
        print(f"{key}: {value}")

    train_dataset = EpisodicDataset_Unified_Multiview(
        data_path_list=HDF5_file_path,
        camera_names=camera_names,
        chunk_size=chunk_size,
        stats=stats,
        img_aug=img_aug,
    )

    traind_data_loader = DataLoader(
        train_dataset,
        batch_size=batch_size_train,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )

    return traind_data_loader,None,None, stats
+
def compute_dict_mean(epoch_dicts):
    """Average a list of dicts key-wise.

    Every dict must contain (at least) the keys of the first one; values may
    be numbers or tensors supporting + and /.
    """
    count = len(epoch_dicts)
    return {
        key: sum(d[key] for d in epoch_dicts) / count
        for key in epoch_dicts[0]
    }
+
+
def detach_dict(d):
    """Return a copy of *d* with every tensor value detached from autograd."""
    return {key: value.detach() for key, value in d.items()}
+
+
import random


def set_seed(seed):
    """Seed python, numpy and torch (CPU + all CUDA devices) for reproducibility."""
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
+
def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
):
    """Linear warmup from 0 to the optimizer's lr, then cosine decay to 0.

    Args:
        optimizer ([~torch.optim.Optimizer]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (int):
            The number of linear warmup steps.
        num_training_steps (int):
            The total number of training steps.
        num_cycles (float, *optional*, defaults to 0.5):
            The number of cosine waves (0.5 = a single half-cosine from the
            max value down to 0).
        last_epoch (int, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        torch.optim.lr_scheduler.LambdaLR with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # Linear ramp 0 -> 1 over the warmup phase.
            return float(current_step) / float(max(1, num_warmup_steps))
        decay_steps = max(1, num_training_steps - num_warmup_steps)
        progress = float(current_step - num_warmup_steps) / float(decay_steps)
        cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
        # Clamp at 0 so extra cycles / overrun steps never go negative.
        return max(0.0, 0.5 * (1.0 + cosine))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
def get_constant_schedule(optimizer, last_epoch: int = -1) -> LambdaLR:
    """Constant learning-rate schedule (multiplier fixed at 1).

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with a constant multiplier.
    """

    def lr_lambda(_current_step):
        return 1

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
def normalize_data(action_data, stats, norm_type, data_type="action"):
    """Normalize a tensor with precomputed dataset statistics.

    Args:
        action_data: tensor to normalize (any device).
        stats: dict of numpy stats (see get_norm_stats), keyed by
            f"{data_type}_min" / "_max" / "_mean" / "_std".
        norm_type: "minmax" (scale to [-1, 1]) or "gaussian" (z-score).
        data_type: stats key prefix, e.g. "action" or "qpos".

    Raises:
        ValueError: on an unknown norm_type. (The original silently returned
            the input UNnormalized, which hides configuration typos.)
    """
    if norm_type == "minmax":
        data_max = torch.from_numpy(stats[data_type + "_max"]).float().to(action_data.device)
        data_min = torch.from_numpy(stats[data_type + "_min"]).float().to(action_data.device)
        action_data = (action_data - data_min) / (data_max - data_min) * 2 - 1
    elif norm_type == "gaussian":
        data_mean = torch.from_numpy(stats[data_type + "_mean"]).float().to(action_data.device)
        data_std = torch.from_numpy(stats[data_type + "_std"]).float().to(action_data.device)
        action_data = (action_data - data_mean) / data_std
    else:
        raise ValueError(
            f"Unknown norm_type: {norm_type!r} (expected 'minmax' or 'gaussian')"
        )
    return action_data
+
+
def convert_weight(obj):
    """Strip a DataParallel/DDP 'module.' prefix from every state-dict key."""
    prefix = "module."
    stripped = OrderedDict()
    for key, value in obj.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped
+
+
+if __name__ == "__main__":
+ train_dataloader,_,_,stats = load_data_unified(
+ data_dir='/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/RDT/training_data/rdt_real_multitask',
+ camera_names=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
+ batch_size_train=32,
+ chunk_size=100,
+ img_aug=True,
+ )
+
+ for i, (image_data, qpos_data, action_data, is_pad, instruction_data) in enumerate(
+ tqdm(train_dataloader, desc="Data loading")
+ ):
+ if i == 0:
+ print(f"Batch {i}:")
+ print(f"Image data shape: {image_data.shape} {image_data.max()}")
+ print(f"Qpos data shape: {qpos_data.shape} {qpos_data.max()}" )
+ print(f"Action data shape: {action_data.shape} {action_data.max()}")
+ print(f"Is pad shape: {is_pad.shape}")
+ print(f"Instruction data shape: {instruction_data.shape}")
+
+ continue
+
\ No newline at end of file