diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..a1756f5b220600b99b6b2ea7fa92a4a8acba46a4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +ACT_DP_multitask/detr/models/mr_mg/media/model.gif filter=lfs diff=lfs merge=lfs -text diff --git a/ACT_DP_multitask/README.md b/ACT_DP_multitask/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4e37f4582871f4b531e82eb67f6dddac6b93749e --- /dev/null +++ b/ACT_DP_multitask/README.md @@ -0,0 +1,16 @@ +### Install +``` +cd policy/ACT-DP-TP +cd detr +pip install -e . +cd .. +cd Cosmos-Tokenizer +pip install -e . +#upload policy/ACT-DP-TP/Cosmos-Tokenizer/pretrained_ckpts +``` +### Command +``` +#data_dir: policy/ACT-DP-TP/data_zarr +cd policy/ACT-DP-TP +bash scripts/act_dp_tp/train.sh bottle_adjust 300 20 20 0 +``` diff --git a/ACT_DP_multitask/base.yaml b/ACT_DP_multitask/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..01b47f3ae774ad8fd58008b60fe3134be004dd05 --- /dev/null +++ b/ACT_DP_multitask/base.yaml @@ -0,0 +1,71 @@ +common: + # The number of historical images + img_history_size: 2 + # The number of future actions to predict + action_chunk_size: 64 + # The number of cameras to be used in the model + num_cameras: 3 + # Dimension for state/action, we use the same space for both state and action + # This MUST be equal to configs/state_vec.py + state_dim: 128 + + +dataset: + # We will extract the data from raw dataset + # and store them in the disk buffer by producer + # When training, we will read the data + # randomly from the buffer by consumer + # The producer will replace the data which has been + # read by the consumer with new data + + # The path to the buffer (at least 400GB) + buf_path: /path/to/buffer + # The number 
of chunks in the buffer + buf_num_chunks: 512 + # The number of samples (step rather than episode) in each chunk + buf_chunk_size: 512 + + # We will filter the episodes with length less than `epsd_len_thresh_low` + epsd_len_thresh_low: 32 + # For those more than `epsd_len_thresh_high`, + # we will randomly sample `epsd_len_thresh_high` steps each time we load the episode + # to better balance the training datasets + epsd_len_thresh_high: 2048 + # How to fit the image size + image_aspect_ratio: pad + # Maximum number of language tokens + tokenizer_max_length: 1024 + +model: + # Config for condition adpators + lang_adaptor: mlp2x_gelu + img_adaptor: mlp2x_gelu + state_adaptor: mlp3x_gelu + lang_token_dim: 4096 + img_token_dim: 1152 + # Dim of action or proprioception vector + # A `state` refers to an action or a proprioception vector + state_token_dim: 128 + # Config for RDT structure + rdt: + # 1B: num_head 32 hidden_size 2048 + hidden_size: 2048 + depth: 28 + num_heads: 32 + cond_pos_embed_type: multimodal + # For noise scheduler + noise_scheduler: + type: ddpm + num_train_timesteps: 1000 + num_inference_timesteps: 5 + beta_schedule: squaredcos_cap_v2 # Critical choice + prediction_type: sample + clip_sample: False + # For EMA (params averaging) + # We do not use EMA currently + ema: + update_after_step: 0 + inv_gamma: 1.0 + power: 0.75 + min_value: 0.0 + max_value: 0.9999 diff --git a/ACT_DP_multitask/detr/LICENSE b/ACT_DP_multitask/detr/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b1395e94b016dd1b95b4c7e3ed493e1d0b342917 --- /dev/null +++ b/ACT_DP_multitask/detr/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020 - present, Facebook, Inc + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/ACT_DP_multitask/detr/README.md b/ACT_DP_multitask/detr/README.md new file mode 100644 index 0000000000000000000000000000000000000000..500b1b8d01108f8ff99b2c505a58cdd43a546fee --- /dev/null +++ b/ACT_DP_multitask/detr/README.md @@ -0,0 +1,9 @@ +This part of the codebase is modified from DETR https://github.com/facebookresearch/detr under APACHE 2.0. 
+ + @article{Carion2020EndtoEndOD, + title={End-to-End Object Detection with Transformers}, + author={Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko}, + journal={ArXiv}, + year={2020}, + volume={abs/2005.12872} + } \ No newline at end of file diff --git a/ACT_DP_multitask/detr/__pycache__/main.cpython-310.pyc b/ACT_DP_multitask/detr/__pycache__/main.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c35be3fc0b1fd6b108e52648f6223818f7959dc Binary files /dev/null and b/ACT_DP_multitask/detr/__pycache__/main.cpython-310.pyc differ diff --git a/ACT_DP_multitask/detr/__pycache__/main.cpython-37.pyc b/ACT_DP_multitask/detr/__pycache__/main.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..598253e19babf5a1f2807b842afe1fc88f3332bc Binary files /dev/null and b/ACT_DP_multitask/detr/__pycache__/main.cpython-37.pyc differ diff --git a/ACT_DP_multitask/detr/detr.egg-info/PKG-INFO b/ACT_DP_multitask/detr/detr.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..595dd87f3aa7838e4a3dc6787e6b130830496b7c --- /dev/null +++ b/ACT_DP_multitask/detr/detr.egg-info/PKG-INFO @@ -0,0 +1,17 @@ +Metadata-Version: 2.2 +Name: detr +Version: 0.0.0 +License: MIT License +License-File: LICENSE +Dynamic: description +Dynamic: license + +This part of the codebase is modified from DETR https://github.com/facebookresearch/detr under APACHE 2.0. 
+ + @article{Carion2020EndtoEndOD, + title={End-to-End Object Detection with Transformers}, + author={Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko}, + journal={ArXiv}, + year={2020}, + volume={abs/2005.12872} + } diff --git a/ACT_DP_multitask/detr/detr.egg-info/SOURCES.txt b/ACT_DP_multitask/detr/detr.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..effae1d8c053e3533179a4afa121adc37953335b --- /dev/null +++ b/ACT_DP_multitask/detr/detr.egg-info/SOURCES.txt @@ -0,0 +1,37 @@ +LICENSE +README.md +setup.py +detr.egg-info/PKG-INFO +detr.egg-info/SOURCES.txt +detr.egg-info/dependency_links.txt +detr.egg-info/top_level.txt +models/__init__.py +models/backbone.py +models/detr_vae.py +models/detr_vae_nfp.py +models/position_encoding.py +models/transformer.py +models/vision_transformer.py +models/mask_former/__init__.py +models/mask_former/config.py +models/mask_former/mask_former_model.py +models/mask_former/test_time_augmentation.py +models/mask_former/modeling/__init__.py +models/mask_former/modeling/criterion.py +models/mask_former/modeling/matcher.py +models/mask_former/modeling/backbone/__init__.py +models/mask_former/modeling/backbone/swin.py +models/mask_former/modeling/heads/__init__.py +models/mask_former/modeling/heads/mask_former_head.py +models/mask_former/modeling/heads/per_pixel_baseline.py +models/mask_former/modeling/heads/pixel_decoder.py +models/mask_former/modeling/transformer/__init__.py +models/mask_former/modeling/transformer/position_encoding.py +models/mask_former/modeling/transformer/transformer.py +models/mask_former/modeling/transformer/transformer_predictor.py +models/mask_former/utils/__init__.py +models/mask_former/utils/misc.py +util/__init__.py +util/box_ops.py +util/misc.py +util/plot_utils.py \ No newline at end of file diff --git a/ACT_DP_multitask/detr/detr.egg-info/dependency_links.txt 
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""CLI argument definitions and model/optimizer factories for the ACT/DETR policies.

Every public ``build_*`` factory below follows the same recipe:

1. parse the default CLI arguments (:func:`get_args_parser`),
2. overwrite them with the entries of the ``args_override`` dict,
3. construct the requested model variant and move it to CUDA,
4. (for most variants) attach an AdamW optimizer that uses a separate,
   typically smaller, learning rate for the backbone parameters.

The shared steps live in the private ``_override_args`` / ``_build_optimizer``
/ ``_build`` helpers so that each public factory is a thin wrapper.  The
public names, signatures, and return shapes are unchanged.
"""
import argparse
from pathlib import Path
import os
import numpy as np
import torch
from .models import *

import IPython

e = IPython.embed


def get_args_parser():
    """Build the argparse parser shared by every model factory below.

    Many defaults are placeholders: they are overridden at build time through
    the ``args_override`` dict accepted by the ``build_*`` helpers (marked
    "will be overridden"), or exist only so that callers reusing the
    ``imitate_episodes`` argument set do not error out.
    """
    parser = argparse.ArgumentParser("Set transformer detector", add_help=False)
    # NOTE(review): the "fintune" typo is kept deliberately — this default must
    # match the on-disk checkpoint directory name.
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="policy/ACT_DP_multitask/checkpoints/real_fintune_50_2000/act_dp",
    )
    parser.add_argument("--eval_ckpts", default=0, type=int, help="eval_ckpts")
    parser.add_argument("--eval_video_log", action="store_true")
    parser.add_argument("--action_interval", default=1, type=int)
    parser.add_argument("--lr", default=1e-4, type=float)  # will be overridden
    parser.add_argument("--lr_backbone", default=1e-5, type=float)  # will be overridden
    parser.add_argument("--lr_schedule_type", default="constant", type=str, help="lr_schedule_type")
    parser.add_argument("--num_episodes", type=int, help="num_epochs", default=0, required=False)
    parser.add_argument("--batch_size", default=2, type=int)  # not used
    parser.add_argument("--samples_per_epoch", default=1, type=int, help="samples_per_epoch", required=False)
    parser.add_argument("--weight_decay", default=1e-4, type=float)
    parser.add_argument("--epochs", default=300, type=int)  # not used
    parser.add_argument("--lr_drop", default=200, type=int)  # not used
    parser.add_argument("--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm")  # not used

    # Diffusion-policy options (scheduler, conditioning, losses).
    parser.add_argument("--norm_type", default="meanstd", type=str, help="norm_type")
    parser.add_argument("--num_train_steps", default=50, type=int, help="num_train_steps")
    parser.add_argument("--num_inference_steps", default=10, type=int, help="num_inference_steps")
    parser.add_argument("--schedule_type", default="DDIM", type=str, help="scheduler_type")
    parser.add_argument("--imitate_weight", default=1, type=int, help="imitate Weight", required=False)
    parser.add_argument("--prediction_type", default="sample", type=str, help="prediction_type")
    parser.add_argument("--beta_schedule", default="squaredcos_cap_v2", type=str, help="prediction_type")
    parser.add_argument(
        "--diffusion_timestep_type",
        default="cat",
        type=str,
        help="diffusion_timestep_type, cat or add, how to combine timestep",
    )
    parser.add_argument(
        "--condition_type",
        default="cross_attention",
        type=str,
        help="diffusion_condition_type, cross_attention or adaLN, how to combine observation condition",
    )
    parser.add_argument("--attention_type", default="v0", help="decoder attention type")
    parser.add_argument("--causal_mask", action="store_true", help="use causal mask for diffusion")
    parser.add_argument("--loss_type", default="l2", type=str, help="loss_type")
    parser.add_argument("--disable_vae_latent", action="store_true", help="Use VAE latent space by default")
    parser.add_argument("--disable_resnet", action="store_true", help="Use resnet to encode obs image by default")
    parser.add_argument("--disable_scale", action="store_true", help="scale model up")
    parser.add_argument("--inference_num_queries", default=0, type=int, help="inference_num_queries", required=False)  # predict_frame
    parser.add_argument("--disable_resize", action="store_true", help="if resize jpeg image")
    parser.add_argument("--share_decoder", action="store_true", help="jpeg and action share decoder")
    parser.add_argument("--resize_rate", default=1, type=int, help="resize rate for pixel prediction", required=False)
    parser.add_argument("--image_downsample_rate", default=1, type=int, help="image_downsample_rate", required=False)
    parser.add_argument("--temporal_downsample_rate", default=1, type=int, help="temporal_downsample_rate", required=False)

    # Evaluation / input-modality options.
    parser.add_argument("--test_num", default=50, type=int, help="test_num")
    parser.add_argument("--save_episode", action="store_true")
    parser.add_argument("--depth_mode", default="None", type=str, help="use depth/depth+coordinate/None. ALL/Single/None")
    parser.add_argument("--pc_mode", default="pc_camera", type=str, help="pc_world/pc_camera")
    parser.add_argument("--disable_multi_view", action="store_true", help="Use multi-view rgb images")

    # * Backbone
    parser.add_argument("--backbone", default="resnet18", type=str, help="Name of the convolutional backbone to use")  # will be overridden
    parser.add_argument(
        "--dilation",
        action="store_true",
        help="If true, we replace stride with dilation in the last convolutional block (DC5)",
    )
    parser.add_argument(
        "--position_embedding",
        default="sine",
        type=str,
        choices=("sine", "learned"),
        help="Type of positional embedding to use on top of the image features",
    )
    parser.add_argument("--camera_names", default=[], type=list, help="A list of camera names")  # will be overridden

    # * Transformer
    parser.add_argument("--enc_layers", default=4, type=int, help="Number of encoding layers in the transformer")  # will be overridden
    parser.add_argument("--dec_layers", default=6, type=int, help="Number of decoding layers in the transformer")  # will be overridden
    parser.add_argument(
        "--dim_feedforward",
        default=2048,
        type=int,  # will be overridden
        help="Intermediate size of the feedforward layers in the transformer blocks",
    )
    parser.add_argument("--hidden_dim", default=256, type=int, help="Size of the embeddings (dimension of the transformer)")  # will be overridden
    parser.add_argument("--dropout", default=0.1, type=float, help="Dropout applied in the transformer")
    parser.add_argument("--nheads", default=8, type=int, help="Number of attention heads inside the transformer's attentions")  # will be overridden
    parser.add_argument("--num_queries", default=400, type=int, help="Number of query slots")  # will be overridden
    parser.add_argument("--pre_norm", action="store_true")

    # * Segmentation
    parser.add_argument("--masks", action="store_true", help="Train segmentation head if the flag is provided")

    # Repeat args from imitate_episodes just to avoid errors; not used here.
    parser.add_argument("--eval", action="store_true")
    parser.add_argument("--onscreen_render", action="store_true")
    parser.add_argument("--ckpt_dir", action="store", type=str, help="ckpt_dir", required=False)
    parser.add_argument("--policy_class", action="store", type=str, help="policy_class, capitalize", required=False)
    parser.add_argument("--task_name", action="store", type=str, help="task_name", required=False)
    parser.add_argument("--seed", action="store", type=int, help="seed", required=False)
    parser.add_argument("--num_epochs", action="store", type=int, help="num_epochs", required=False)
    parser.add_argument("--kl_weight", action="store", type=int, help="KL Weight", required=False)
    parser.add_argument("--save_epoch", action="store", type=int, help="save_epoch", default=500, required=False)
    parser.add_argument("--chunk_size", action="store", type=int, help="chunk_size", required=False)
    parser.add_argument("--history_step", default=0, type=int, help="history_step", required=False)
    parser.add_argument("--predict_frame", default=0, type=int, help="predict_frame", required=False)
    parser.add_argument("--image_width", default=320, type=int, help="image_width", required=False)
    parser.add_argument("--image_height", default=240, type=int, help="image_height", required=False)
    parser.add_argument("--predict_only_last", action="store_true")  # only predict the last #predict_frame frame
    parser.add_argument("--temporal_agg", action="store_true")

    # Visual tokenizer (Cosmos) options.
    parser.add_argument("--tokenizer_model_type", default="DV", type=str, help="tokenizer_model_type, DV,CV,DI,CI")
    parser.add_argument("--tokenizer_model_temporal_rate", default=8, type=int, help="tokenizer_model_temporal_rate, 4,8")
    parser.add_argument("--tokenizer_model_spatial_rate", default=16, type=int, help="tokenizer_model_spatial_rate, 8,16")
    parser.add_argument("--tokenizer_model_name", default="Cosmos-Tokenizer-DV4x8x8", type=str, help="tokenizer_model_name")
    parser.add_argument("--prediction_weight", default=1, type=float, help="pred token Weight", required=False)
    parser.add_argument("--token_dim", default=6, type=int, help="token_dim", required=False)
    parser.add_argument("--patch_size", default=5, type=int, help="patch_size", required=False)
    parser.add_argument("--token_pe_type", default="learned", type=str, help="token_pe_type", required=False)
    parser.add_argument("--nf", action="store_true")
    parser.add_argument("--pretrain", action="store_true", required=False)
    parser.add_argument("--is_wandb", action="store_true")
    parser.add_argument("--mae", action="store_true")

    # Parameters for distributed training.
    parser.add_argument("--resume", default="", type=str, metavar="PATH", help="path to latest checkpoint (default: none)")
    parser.add_argument("--world-size", default=-1, type=int, help="number of nodes for distributed training")
    parser.add_argument("--rank", default=-1, type=int, help="node rank for distributed training")
    parser.add_argument("--dist-url", default="tcp://224.66.41.62:23456", type=str, help="url used to set up distributed training")
    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
    parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.")
    parser.add_argument(
        "--multiprocessing-distributed",
        action="store_true",
        help="Use multi-processing distributed training to launch "
        "N processes per node, which has N GPUs. This is the "
        "fastest way to use PyTorch for either single node or "
        "multi node data parallel training",
    )
    parser.add_argument(
        "-j",
        "--workers",
        default=32,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 32)",
    )

    return parser


def _override_args(args_override):
    """Parse the default CLI arguments, then overwrite with *args_override*.

    Note: this reads ``sys.argv`` (same as the original per-factory code), so
    unknown CLI flags still raise a parser error before overrides are applied.
    """
    parser = argparse.ArgumentParser(
        "DETR training and evaluation script", parents=[get_args_parser()]
    )
    args = parser.parse_args()
    for k, v in args_override.items():
        setattr(args, k, v)
    return args


def _build_optimizer(model, args):
    """AdamW with a separate (``args.lr_backbone``) rate for backbone params."""
    param_dicts = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr": args.lr_backbone,
        },
    ]
    return torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)


def _build(build_fn, args_override, with_optimizer=True):
    """Shared factory body: override args, build model, move to CUDA.

    Returns ``(model, optimizer)`` by default, or just ``model`` when
    *with_optimizer* is False (matching the tp/pp variants below).
    """
    args = _override_args(args_override)
    model = build_fn(args)
    model.cuda()
    if not with_optimizer:
        return model
    return model, _build_optimizer(model, args)


def build_ACT_model_and_optimizer(args_override):
    """Build the ACT model (or its segmentation variant) plus its optimizer.

    *args_override* must contain a ``"segmentation"`` key selecting between
    ``build_ACT_Seg_model`` and ``build_ACT_model``.
    """
    args = _override_args(args_override)
    build_fn = build_ACT_Seg_model if args_override["segmentation"] else build_ACT_model
    model = build_fn(args)
    model.cuda()
    return model, _build_optimizer(model, args)


def build_ACTDiffusion_model_and_optimizer(args_override):
    """Build the ACT-Diffusion model and its AdamW optimizer."""
    return _build(build_ACTDiffusion_model, args_override)


def build_ACTDiffusion_tactile_model_and_optimizer(args_override):
    """Build the tactile ACT-Diffusion model and its AdamW optimizer."""
    return _build(build_ACTDiffusion_tactile_model, args_override)


def build_diffusion_tp_model_and_optimizer(args_override):
    """Build the token-prediction diffusion model (no optimizer is returned)."""
    return _build(build_ACTDiffusion_tp_model, args_override, with_optimizer=False)


def build_diffusion_pp_model_and_optimizer(args_override):
    """Build the pixel-prediction diffusion model (no optimizer is returned)."""
    return _build(build_ACTDiffusion_pp_model, args_override, with_optimizer=False)


# discard


def build_ACT_NF_model_and_optimizer(args_override):
    """Build the ACT-NF model and its AdamW optimizer."""
    return _build(build_ACT_NF_model, args_override)


def build_ACT_Dino_model_and_optimizer(args_override):
    """Build the DINO-backbone ACT model and its AdamW optimizer."""
    return _build(build_ACT_dino_model, args_override)


def build_ACT_jpeg_model_and_optimizer(args_override):
    """Build the JPEG-token ACT model and its AdamW optimizer."""
    return _build(build_ACT_jpeg_model, args_override)


def build_ACT_jpeg_diffusion_model_and_optimizer(args_override):
    """Build the JPEG-token ACT-Diffusion model and its AdamW optimizer."""
    return _build(build_ACT_jpeg_diffusion_model, args_override)


def build_ACT_jpeg_diffusion_seperate_model_and_optimizer(args_override):
    """Build the separate-decoder JPEG ACT-Diffusion model and its optimizer.

    NOTE(review): "seperate" spelling is kept — it is the public API name.
    """
    return _build(build_ACT_jpeg_diffusion_seperate_model, args_override)


def build_nf_diffusion_seperate_model_and_optimizer(args_override):
    """Build the separate NF-diffusion model and its AdamW optimizer."""
    return _build(build_nf_diffusion_seperate_model, args_override)
+ + +def build_CNNMLP_model_and_optimizer(args_override): + parser = argparse.ArgumentParser( + "DETR training and evaluation script", parents=[get_args_parser()] + ) + args = parser.parse_args() + + for k, v in args_override.items(): + setattr(args, k, v) + + model = build_CNNMLP_model(args) + model.cuda() + + param_dicts = [ + { + "params": [ + p + for n, p in model.named_parameters() + if "backbone" not in n and p.requires_grad + ] + }, + { + "params": [ + p + for n, p in model.named_parameters() + if "backbone" in n and p.requires_grad + ], + "lr": args.lr_backbone, + }, + ] + optimizer = torch.optim.AdamW( + param_dicts, lr=args.lr, weight_decay=args.weight_decay + ) + + return model, optimizer + diff --git a/ACT_DP_multitask/detr/models/__init__.py b/ACT_DP_multitask/detr/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ecb30ad65567274fba7157b568ac0ff042d6a2cf --- /dev/null +++ b/ACT_DP_multitask/detr/models/__init__.py @@ -0,0 +1,60 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +from .detr_vae import build as build_vae +from .detr_vae import build_seg as build_vae_seg +from .detr_vae_nfp import build as build_vae_nfp +from .detr_vae import build_cnnmlp as build_cnnmlp +from .detr_vae import build_dino as build_dino +from .detr_vae import build_jpeg as build_jpeg +from .detr_vae import build_jpeg_diffusion as build_jpeg_diffusion +from .detr_vae import build_jpeg_diffusion_seperate as build_jpeg_diffusion_seperate +from .detr_vae import build_nf_diffusion_seperate as build_nf_diffusion_seperate +from .detr_vae import build_diffusion as build_diffusion +from .detr_vae import build_diffusion_tp as build_diffusion_tp +from .detr_vae import build_diffusion_tp_with_dual_visual_token as build_diffusion_tp_with_dual_visual_token +from .detr_vae import build_diffusion_pp as build_diffusion_pp +from .detr_vae import build_diffusion_tactile as build_diffusion_tactile + +def build_ACT_model(args): + return build_vae(args) + +def build_CNNMLP_model(args): + return build_cnnmlp(args) + +def build_ACTDiffusion_model(args): + return build_diffusion(args) + +def build_ACTDiffusion_tactile_model(args): + return build_diffusion_tactile(args) + +def build_ACTDiffusion_tp_model(args): + if args.diffusion_timestep_type == 'vis_cat': # HARDCODE whether use tokenizer feature for decoder & action prediction + print('Using dual visual token for decoder and action prediction') + return build_diffusion_tp_with_dual_visual_token(args) + else: + return build_diffusion_tp(args) + +def build_ACTDiffusion_pp_model(args): + return build_diffusion_pp(args) + +# discard +def build_ACT_NF_model(args): + return build_vae_nfp(args) + +def build_ACT_Seg_model(args): + return build_vae_seg(args) + +def build_ACT_dino_model(args): + return build_dino(args) + +def build_ACT_jpeg_model(args): + return build_jpeg(args) + +def build_ACT_jpeg_diffusion_model(args): + return build_jpeg_diffusion(args) + +def build_ACT_jpeg_diffusion_seperate_model(args): + return 
build_jpeg_diffusion_seperate(args) + +def build_nf_diffusion_seperate_model(args): + return build_nf_diffusion_seperate(args) + diff --git a/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ff241e6845d924770d91e296912b754e9bb1037 Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-37.pyc b/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f15f86d73e1cfd2c04218d8ab24b5388bac2890f Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-37.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-38.pyc b/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75f1fbcf5fb261eb909e31debcfb0500be3f0857 Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3057fddd7d741d12c184edcd2ee025c9b5f420f2 Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-310.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-37.pyc b/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..def5d907456559aed4e23290080c1afaace64fb3 Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-37.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-38.pyc 
b/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9fd912bbdd116f42baa238c0ce16f455a093efc Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/backbone.cpython-38.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d163d23c7998a92d59b27ad4ff8dd5106b232f8 Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-310.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-37.pyc b/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8788a652254702f82e6b06e34ac510ddee44bc7 Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-37.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-38.pyc b/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d30aa9eb4e5fa98a78e531e94b1dc39b08073ef Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/detr_vae.cpython-38.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..144a41ad3f0275e9817bc17cf6f00e5fc58706d9 Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-310.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-37.pyc b/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1d9982b9c3a5358b5933b69137a0fcffc65d192 Binary files 
/dev/null and b/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-37.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-38.pyc b/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aad89d2ddf1fcaea00dbb7db2fd026a0685cfb45 Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/detr_vae_nfp.cpython-38.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31228f52a00e0f77bf929b0d0edaa3c87da75d12 Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-310.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-37.pyc b/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0a4a6e6597403eadb101c3d570b8f24f4eebcef Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-37.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-38.pyc b/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ed830ec82cfc73f6bd2061da7689dbc780f9d06 Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/position_encoding.cpython-38.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/resnet_film.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/resnet_film.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebb3e72a11b2f267f60b4307ff087167dd1f2ad4 Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/resnet_film.cpython-310.pyc differ diff --git 
a/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..444f11624b6fe0783733bec98d810f1627f996aa Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-310.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-37.pyc b/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3b7fd65409dbb49db23166019d58475f20f9518 Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-37.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-38.pyc b/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..515275a7f9e090bf6407253ae2f7183b6831e1c0 Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/transformer.cpython-38.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-310.pyc b/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..761b345f1415a83be930b2a1c14c974d1016921f Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-310.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-37.pyc b/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2eead987d8c11019e247f723f3e176b98340b352 Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-37.pyc differ diff --git a/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-38.pyc 
b/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0342901b46494e03dde2ec0a1425f321ca02ac42 Binary files /dev/null and b/ACT_DP_multitask/detr/models/__pycache__/vision_transformer.cpython-38.pyc differ diff --git a/ACT_DP_multitask/detr/models/backbone.py b/ACT_DP_multitask/detr/models/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..82fb532b5db2e03b93c83b7bea0b608a5f4d57eb --- /dev/null +++ b/ACT_DP_multitask/detr/models/backbone.py @@ -0,0 +1,209 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Backbone modules. +""" +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter +from typing import Dict, List +from typing import Any, Dict, List, Mapping, Optional +from ..util.misc import NestedTensor, is_main_process + +from .position_encoding import build_position_encoding +from .resnet_film import resnet18 as resnet18_film +from .resnet_film import resnet34 as resnet34_film +import IPython +e = IPython.embed + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101] + produce nans. 
+ """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): + super().__init__() + # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this? 
+ # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + # parameter.requires_grad_(False) + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + else: + return_layers = {'layer4': "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward(self, tensor): + xs = self.body(tensor) + return xs + # out: Dict[str, NestedTensor] = {} + # for name, x in xs.items(): + # m = tensor_list.mask + # assert m is not None + # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + # out[name] = NestedTensor(x, mask) + # return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool): + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm?? 
+ num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 + super().__init__(backbone, train_backbone, num_channels, return_interm_layers) + +# ==== ResNet Backbone ==== +class ResNetFilmBackbone(nn.Module): + def __init__(self, embedding_name: str, pretrained: bool = False, + film_config: Optional[Mapping[str, Any]] = None): + super().__init__() + self._pretrained = pretrained + weights = 'IMAGENET1K_V1' if pretrained else None + if embedding_name in ('resnet34_film', 'resnet34'): + backbone = resnet34_film(weights=weights, film_config=film_config, pretrained=pretrained) + embedding_dim = 512 + elif embedding_name in ('resnet18_film', 'resnet18'): + backbone = resnet18_film(weights=weights, film_config=film_config, pretrained=pretrained) + embedding_dim = 512 + else: + raise NotImplementedError + + self.resnet_film_model = backbone + self._embedding_dim = embedding_dim + self.resnet_film_model.fc = nn.Identity() + self.resnet_film_model.avgpool = nn.Identity() + + self.num_channels = self._embedding_dim + + # FiLM config + self.film_config = film_config + if film_config is not None and film_config['use']: + film_models = [] + for layer_idx, num_blocks in enumerate(self.resnet_film_model.layers): + if layer_idx in film_config['use_in_layers']: + num_planes = self.resnet_film_model.film_planes[layer_idx] + film_model_layer = nn.Linear( + film_config['task_embedding_dim'], num_blocks * 2 * num_planes) + else: + film_model_layer = None + film_models.append(film_model_layer) + + self.film_models = nn.ModuleList(film_models) + + def forward(self, x, texts: Optional[List[str]] = None, task_emb: Optional[torch.Tensor] = None, **kwargs): + film_outputs = None + if self.film_config is not None and self.film_config['use']: + film_outputs = [] + for layer_idx, num_blocks in enumerate(self.resnet_film_model.layers): + if self.film_config['use'] and self.film_models[layer_idx] is not None: + film_features = self.film_models[layer_idx](task_emb) + else: + film_features 
= None + film_outputs.append(film_features) + return self.resnet_film_model(x, film_features=film_outputs, flatten=False) + + @property + def embed_dim(self): + return self._embedding_dim + + +# class Joiner(nn.Sequential): +# def __init__(self, backbone, position_embedding): +# super().__init__(backbone, position_embedding) + +# def forward(self, tensor_list: NestedTensor, task_emb:NestedTensor): +# xs = self[0](tensor_list) +# out: List[NestedTensor] = [] +# pos = [] +# for name, x in xs.items(): +# out.append(x) +# # position encoding +# pos.append(self[1](x).to(x.dtype)) + +# return out, pos + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + + def forward(self, tensor_list: NestedTensor, task_emb: Optional[Any] = None): + if task_emb is not None: + xs = self[0](tensor_list, task_emb=task_emb) + # Make a dictionary out of the last layer outputs since we don't have IntermediateLayerGetter + xs = {'0': xs} + else: + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in xs.items(): + out.append(x) + # position encoding + pos.append(self[1](x).to(x.dtype)) + + return out, pos + +def build_backbone(args): + position_embedding = build_position_encoding(args) + train_backbone = args.lr_backbone > 0 + return_interm_layers = args.masks + backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) + model = Joiner(backbone, position_embedding) + model.num_channels = backbone.num_channels + return model + +def build_film_backbone(args): + position_embedding = build_position_encoding(args) + film_config = { + 'use': True, + 'use_in_layers': [1, 2, 3], + 'task_embedding_dim': 512, + 'film_planes': [64, 128, 256, 512], + } + backbone = ResNetFilmBackbone(args.backbone, film_config=film_config) + model = Joiner(backbone, position_embedding) + model.num_channels = backbone.num_channels + return model \ No newline at end of file diff 
--git a/ACT_DP_multitask/detr/models/detr_vae.py b/ACT_DP_multitask/detr/models/detr_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..78fbbe3913bbed57b9a29f9df9b3490a578476c2 --- /dev/null +++ b/ACT_DP_multitask/detr/models/detr_vae.py @@ -0,0 +1,3193 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR model and criterion classes. +""" +import torch +from torch import nn +from torch.autograd import Variable +from .backbone import build_backbone, build_film_backbone +from .transformer import * +from .vision_transformer import ( + Block, + get_2d_sincos_pos_embed, + get_2d_sincos_pos_embed_v2, +) +from einops import rearrange +import numpy as np + +import IPython + +e = IPython.embed + + +class FourierFeatureMapping(nn.Module): + def __init__(self, input_dim, mapping_size, scale=10.0): + """ + Args: + input_dim (int): input dimension. + mapping_size (int): Fourier Features output dimension. + scale (float): scale factor for frequencies. + """ + super(FourierFeatureMapping, self).__init__() + self.B = torch.randn((mapping_size, input_dim)) * scale + + def forward(self, x): + """ + Args: + x (Tensor): [batch_size, input_dim] + Returns: + Tensor: Fourier Features [batch_size, mapping_size * 2] + """ + x_proj = 2 * torch.pi * x @ self.B.T # [batch_size, mapping_size] + return torch.cat( + [torch.sin(x_proj), torch.cos(x_proj)], dim=-1 + ) # Concatenate sin and cos + + +class MLPWithFourierFeatures(nn.Module): + def __init__(self, input_dim, mapping_size, hidden_dim, output_dim): + """ + Args: + input_dim (int): input dimension. + mapping_size (int): Fourier Features output dimension. + hidden_dim (int): MLP layer layer dimension. + output_dim (int): MLP output dimension. 
+ """ + super(MLPWithFourierFeatures, self).__init__() + self.fourier_mapping = FourierFeatureMapping(input_dim, mapping_size) + self.mlp = nn.Sequential( + nn.Linear(mapping_size * 2, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, output_dim), + ) + + def forward(self, x): + """ + Args: + x (Tensor): [batch_size, input_dim] + Returns: + Tensor: MLP + """ + x_mapped = self.fourier_mapping(x) # Fourier Features Mapping + output = self.mlp(x_mapped) # MLP + return output + + +def get_spatial_temporal_positional_encoding(height, width, seq_len, dim): + """ + Generate spatial-temporal positional encoding for video data. + + Args: + height (int): Height of the video frames. + width (int): Width of the video frames. + seq_len (int): Number of frames in the video. + dim (int): Embedding dimension. + device (str): Device to store the tensor. + + Returns: + torch.Tensor: Positional encoding with shape (seq_len, height, width, dim). + """ + assert dim % 2 == 0, "Embedding dimension (dim) must be even." 
+ + spatial_dim = dim // 2 + temporal_dim = dim - spatial_dim + + # Temporal encoding + temporal_encoding = get_sinusoid_encoding_table(seq_len, temporal_dim)[ + 0 + ] # Shape: (seq_len, temporal_dim) + + # Spatial encoding + position_h = torch.arange(height, dtype=torch.float32).unsqueeze(1) # (height, 1) + position_w = torch.arange(width, dtype=torch.float32).unsqueeze(1) # (1, width) + div_term_h = torch.exp( + -torch.arange(0, spatial_dim, 2, dtype=torch.float32) + * (math.log(10000.0) / spatial_dim) + ).unsqueeze(0) + div_term_w = torch.exp( + -torch.arange(0, spatial_dim, 2, dtype=torch.float32) + * (math.log(10000.0) / spatial_dim) + ).unsqueeze(0) + + spatial_encoding_h = torch.sin( + position_h * div_term_h + ) # (height, spatial_dim // 2) + spatial_encoding_w = torch.cos(position_w * div_term_w) # (width, spatial_dim // 2) + + # print(spatial_encoding_h.shape, spatial_encoding_w.shape) + # Combine H and W spatial encodings + spatial_encoding_h = spatial_encoding_h.unsqueeze(1).expand( + -1, width, -1 + ) # (height, width£¬ spatial_dim // 2) + spatial_encoding_w = spatial_encoding_w.unsqueeze(0).expand( + height, -1, -1 + ) # (height, width , spatial_dim // 2) + spatial_encoding = torch.cat( + [spatial_encoding_h, spatial_encoding_w], dim=-1 + ) # (height, width , spatial_dim) + spatial_encoding = spatial_encoding.unsqueeze(0).repeat( + seq_len, 1, 1, 1 + ) # (seq_len, height, width , spatial_dim) + + # Combine spatial and temporal + temporal_encoding = ( + temporal_encoding.unsqueeze(1).unsqueeze(1).repeat(1, height, width, 1) + ) # (seq_len, height, width, temporal_dim) + # print(spatial_encoding.shape, temporal_encoding.shape) + pos_encoding = torch.cat( + [spatial_encoding, temporal_encoding], dim=-1 + ) # Combine spatial and temporal + + return pos_encoding # (seq_len, height, width, dim) + + +def get_sinusoid_encoding_table(n_position, d_hid): + def get_position_angle_vec(position): + return [ + position / np.power(10000, 2 * (hid_j // 2) / d_hid) + 
for hid_j in range(d_hid) + ] + + sinusoid_table = np.array( + [get_position_angle_vec(pos_i) for pos_i in range(n_position)] + ) + # print(sinusoid_table.shape) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +def reparametrize(mu, logvar): + std = logvar.div(2).exp() + eps = Variable(std.data.new(std.size()).normal_()) + return mu + std * eps + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + if not isinstance(pos, np.ndarray): + pos = np.array(pos, dtype=np.float64) + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_nd_sincos_pos_embed_from_grid(embed_dim, grid_sizes): # useful + """ + embed_dim: output dimension for each position + grid_sizes: the grids sizes in each dimension (K,). 
for example [T, C, P] n_his, n_view,n_patch + out: (grid_sizes[0], ..., grid_sizes[K-1], D) for example (T, C, P, D) + """ + num_sizes = len(grid_sizes) + # For grid size of 1, we do not need to add any positional embedding + num_valid_sizes = len([x for x in grid_sizes if x > 1]) + emb = np.zeros(grid_sizes + (embed_dim,)) + # Uniformly divide the embedding dimension for each grid size + dim_for_each_grid = embed_dim // num_valid_sizes + # To make it even + if dim_for_each_grid % 2 != 0: + dim_for_each_grid -= 1 + valid_size_idx = 0 + for size_idx in range(num_sizes): + grid_size = grid_sizes[size_idx] + if grid_size <= 1: + continue + pos = np.arange(grid_size) + posemb_shape = [1] * len(grid_sizes) + [dim_for_each_grid] + posemb_shape[size_idx] = -1 + emb[ + ..., + valid_size_idx + * dim_for_each_grid : (valid_size_idx + 1) + * dim_for_each_grid, + ] += get_1d_sincos_pos_embed_from_grid(dim_for_each_grid, pos).reshape( + posemb_shape + ) + valid_size_idx += 1 + return emb + + +class DETRVAE(nn.Module): + """This is the DETR module that performs object detection""" + + def __init__( + self, backbones, transformer, encoder, state_dim, num_queries, camera_names + ): + """Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 
+ """ + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.transformer = transformer + self.encoder = encoder + hidden_dim = transformer.d_model + self.action_head = nn.Linear(hidden_dim, state_dim) + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.query_embed = nn.Embedding(num_queries, hidden_dim) + if backbones is not None: + self.input_proj = nn.Conv2d( + backbones[0].num_channels, hidden_dim, kernel_size=1 + ) + self.backbones = nn.ModuleList(backbones) + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + else: + # input_dim = 14 + 7 # robot_state + env_state + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + self.input_proj_env_state = nn.Linear(7, hidden_dim) + self.pos = torch.nn.Embedding(2, hidden_dim) + self.backbones = None + + # encoder extra parameters + self.latent_dim = 32 # final size of latent z # TODO tune + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_action_proj = nn.Linear( + 14, hidden_dim + ) # project action to embedding + self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding + self.latent_proj = nn.Linear( + hidden_dim, self.latent_dim * 2 + ) # project hidden state to latent std, var + self.register_buffer( + "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim) + ) # [CLS], qpos, a_seq + + # decoder extra parameters + self.latent_out_proj = nn.Linear( + self.latent_dim, hidden_dim + ) # project latent sample to embedding + self.additional_pos_embed = nn.Embedding( + 2, hidden_dim + ) # learned position embedding for proprio and latent + + def forward(self, qpos, image, env_state, actions=None, is_pad=None): + """ + qpos: batch, qpos_dim + image: batch, num_cam, channel, height, width + env_state: None + actions: batch, seq, action_dim + """ + is_training = actions is not None # train or val + bs, _ = qpos.shape + ### Obtain latent z from action sequence + if is_training: + # project action 
sequence to embedding dim, and concat with a CLS token + action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) + qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim) + qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim) + cls_embed = self.cls_embed.weight # (1, hidden_dim) + cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat( + bs, 1, 1 + ) # (bs, 1, hidden_dim) + encoder_input = torch.cat( + [cls_embed, qpos_embed, action_embed], axis=1 + ) # (bs, seq+1, hidden_dim) + encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim) + # do not mask cls token + cls_joint_is_pad = torch.full((bs, 2), False).to( + qpos.device + ) # False: not a padding + is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1) + # obtain position embedding + pos_embed = self.pos_table.clone().detach() + pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) + # query model + encoder_output = self.encoder( + encoder_input, pos=pos_embed, src_key_padding_mask=is_pad + ) + encoder_output = encoder_output[0] # take cls output only + latent_info = self.latent_proj(encoder_output) + mu = latent_info[:, : self.latent_dim] + logvar = latent_info[:, self.latent_dim :] + latent_sample = reparametrize(mu, logvar) + latent_input = self.latent_out_proj(latent_sample) + else: + mu = logvar = None + latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to( + qpos.device + ) + latent_input = self.latent_out_proj(latent_sample) + + if self.backbones is not None: + # Image observation features and position embeddings + all_cam_features = [] + all_cam_pos = [] + for cam_id, cam_name in enumerate(self.camera_names): + features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED + features = features[0] # take the last layer feature + pos = pos[0] + all_cam_features.append(self.input_proj(features)) + all_cam_pos.append(pos) + # proprioception features + proprio_input = self.input_proj_robot_state(qpos) + # 
fold camera dimension into width dimension + src = torch.cat(all_cam_features, axis=3) + pos = torch.cat(all_cam_pos, axis=3) + hs = self.transformer( + src, + None, + self.query_embed.weight, + pos, + latent_input, + proprio_input, + self.additional_pos_embed.weight, + )[0] + else: + qpos = self.input_proj_robot_state(qpos) + env_state = self.input_proj_env_state(env_state) + transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2 + hs = self.transformer( + transformer_input, None, self.query_embed.weight, self.pos.weight + )[0] + a_hat = self.action_head(hs) + is_pad_hat = self.is_pad_head(hs) + return a_hat, is_pad_hat, [mu, logvar] + + +class DETRVAE_Denoise(nn.Module): + """This is the DETR module that performs object detection""" + + def __init__( + self, + backbones, + transformer, + encoder, + state_dim, + num_queries, + camera_names, + history_step, + disable_vae_latent, + is_multi_task=False, + use_film=False + ): + """Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 
+ """ + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.transformer = transformer + self.encoder = encoder + self.disable_vae_latent = disable_vae_latent + hidden_dim = transformer.d_model # 512 + self.hidden_dim = hidden_dim + # self.action_head = nn.Linear(hidden_dim, state_dim) + self.action_head = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, state_dim), # TODO add layers + ) + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.query_embed = nn.Embedding(num_queries, hidden_dim) + + self.multi_task = is_multi_task + self.use_film = use_film + + print('use_film', use_film) + print('is_multi_task', is_multi_task) + + if backbones is not None: + self.input_proj = nn.Conv2d( + backbones[0].num_channels * (history_step + 1) , hidden_dim, kernel_size=1 + ) + self.backbones = nn.ModuleList(backbones) + self.input_proj_robot_state = nn.Linear(14, hidden_dim) # proprioception + else: + # input_dim = 14 + 7 # robot_state + env_state + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + self.input_proj_env_state = nn.Linear(7, hidden_dim) + self.pos = torch.nn.Embedding(2, hidden_dim) + self.backbones = None + + # encoder extra parameters + self.latent_dim = 32 # final size of latent z # TODO tune + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_action_proj = nn.Linear( + 14, hidden_dim + ) # project action to embedding + self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding + self.latent_proj = nn.Linear( + hidden_dim, self.latent_dim * 2 + ) # project hidden state to latent std, var + self.register_buffer( + "pos_table", get_sinusoid_encoding_table(1 + 1 + history_step + num_queries, hidden_dim) + ) # [CLS], qpos, a_seq + self.history_step = history_step + # decoder extra parameters vae latent to decoder token space + self.latent_out_proj = nn.Linear( + self.latent_dim, hidden_dim + ) # project latent sample to 
embedding + if self.multi_task: # include proprio and latent and language token + self.additional_pos_embed = nn.Embedding( + 2 + history_step + 1, hidden_dim + ) + else: + self.additional_pos_embed = nn.Embedding( + 2 + history_step, hidden_dim + ) # learned position embedding for proprio and latent + + self.proj_text_emb = nn.Linear(4096, hidden_dim) # TODO hard code text embedding dim + # src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None, noisy_actions = None, denoise_steps=None + def forward( + self, + qpos, + image, + env_state, + actions=None, + is_pad=None, + denoise_steps=None, + is_training=True, + task_emb=None + ): + """ + qpos: batch, 1+history, qpos_dim + image: batch, 1+history, num_cam, channel, height, width + env_state: None + actions: batch, seq, action_dim noisy action + denoise_step: int, the step of denoise + """ + # is_training = actions is not None # train or val + bs = qpos.shape[0] + ### Obtain latent z from action sequence add a paramers = use_latent = True + if is_training and self.disable_vae_latent == False: + # project action sequence to embedding dim, and concat with a CLS token + action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) + qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim) + qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim) + cls_embed = self.cls_embed.weight # (1, hidden_dim) + cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat( + bs, 1, 1 + ) # (bs, 1, hidden_dim) + # vae encoder + # print(cls_embed.shape, qpos_embed.shape, action_embed.shape) + encoder_input = torch.cat( + [cls_embed, qpos_embed[:,0], action_embed], axis=1 + ) # (bs, seq+1, hidden_dim) + encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim) + # do not mask cls token + cls_joint_is_pad = torch.full((bs, 2), False).to( + qpos.device + ) # False: not a padding + is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1) + # 
obtain position embedding + pos_embed = self.pos_table.clone().detach() + pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) + # query model + encoder_output = self.encoder( + encoder_input, pos=pos_embed, src_key_padding_mask=is_pad + ) + encoder_output = encoder_output[0] # take cls output only + latent_info = self.latent_proj(encoder_output) + mu = latent_info[:, : self.latent_dim] + logvar = latent_info[:, self.latent_dim :] + latent_sample = reparametrize(mu, logvar) + latent_input = self.latent_out_proj( + latent_sample + ) # get latent input for decoder TODO do we need this? + task_emb = self.proj_text_emb(task_emb) if task_emb is not None else None + else: # if dismiss latent ,just set to zero when training or val + mu = logvar = None + latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to( + qpos.device + ) + # latent_sample = torch.randn([bs, self.latent_dim], dtype=torch.float32).to(qpos.device) + latent_input = self.latent_out_proj(latent_sample) + task_emb = self.proj_text_emb(task_emb) if task_emb is not None else None + + if self.backbones is not None: + # proprioception features + proprio_input = self.input_proj_robot_state(qpos) + # Image observation features and position embeddings B T N_view C H W + image = image.view( -1, *image.shape[2:]) # frame stack shape: batch*T', num_cam, 3, height, width + all_cam_features = [] # no multiview + for cam_id, cam_name in enumerate(self.camera_names): + if self.use_film: + features, pos = self.backbones[0](image[:, cam_id], task_emb=task_emb) # HARDCODED + else: + features, pos = self.backbones[cam_id]( + image[:, cam_id] + ) # shape: batch*T', C, H, W + features = features[0] # take the last layer feature + features = features.view( + bs, -1, *features.shape[-2:] + ) # shape: batch, T'*C, H, W + all_cam_features.append(self.input_proj(features)) + src = torch.stack(all_cam_features, axis=-3) # shape: batch,D, N_view, H, W + pos = get_nd_sincos_pos_embed_from_grid( + 
self.hidden_dim, src.shape[2:] + ) # N_view, H, W, D numpy for hw no view + pos = torch.from_numpy(pos).to(src.device).unsqueeze(0).float() + # 1 N_view, H, W, D + src = rearrange(src,"b d n_view h w -> b d h (w n_view)",) # + pos = rearrange(pos,"b n_view h w d -> b d h (w n_view)",) # will b d h*w*n_view in the following transformer + + # print('proprio_input shape', proprio_input.shape) + hs = self.transformer( + src, + None, + self.query_embed.weight, + pos, + latent_input, + proprio_input, + self.additional_pos_embed.weight, + actions, + denoise_steps, + task_emb=task_emb, + )[0] + else: + qpos = self.input_proj_robot_state(qpos) + env_state = self.input_proj_env_state(env_state) + transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2 + hs = self.transformer( + transformer_input, + None, + self.query_embed.weight, + self.pos.weight, + denoise_steps, + )[0] + a_hat = self.action_head(hs) # predict action or noise output of module + is_pad_hat = self.is_pad_head(hs) + return a_hat, is_pad_hat, [mu, logvar] + + +class DETRVAE_Denoise_Token_Prediction(nn.Module): + """This is the DETR module that performs object detection""" + + def __init__( + self, + backbones, + transformer, + encoder, + state_dim, + num_queries, + camera_names, + history_step, + predict_frame, + image_downsample_rate, + temporal_downsample_rate, + disable_vae_latent, + disable_resnet, + patch_size=5, + token_pe_type="learned", + image_height=480, + image_width=640, + tokenizer_model_temporal_rate=8, + tokenizer_model_spatial_rate=16, + resize_rate=1, + ): + """Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. 
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.future_camera_names = [camera_names[0]] # TODO attention + self.transformer = transformer + self.transformer_share_decoder = transformer.share_decoder + self.predict_only_last = transformer.predict_only_last + self.encoder = encoder + self.disable_vae_latent = disable_vae_latent + self.disable_resnet = disable_resnet + hidden_dim = transformer.d_model # 512 + self.hidden_dim = hidden_dim + token_dim = transformer.token_dim + # Action head state_dim = action_dim + # self.action_head = nn.Linear(hidden_dim, state_dim) # replace MLP? + self.action_head = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, state_dim), + ) + + self.token_head = nn.Sequential( + nn.Linear( + hidden_dim, hidden_dim + ), # Hardcode patch size * path size * patch dim + nn.SiLU(), + nn.Linear(hidden_dim, token_dim * patch_size * patch_size), + ) + + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.query_embed = nn.Embedding(num_queries, hidden_dim) # for decoder as PE + self.diffusion_timestep_type = transformer.diffusion_timestep_type + if backbones is not None and disable_resnet == False: + self.input_proj = nn.Conv2d( + backbones[0].num_channels * (history_step + 1), + hidden_dim, + kernel_size=1, + ) # = MLP c h w -> c' h' w' frame stack + self.backbones = nn.ModuleList(backbones) # N encoders for N view + self.input_proj_robot_state = nn.Linear(14, hidden_dim) # proprioception + elif backbones is not None and disable_resnet == True: + self.pos_encoder = backbones[0][1] # for 2D PositionEmbedding + self.backbones = None # HardCDcode + # Hardcode divide the latent feature representation into non-overlapping patches + self.input_proj_token = nn.Conv2d( + token_dim, hidden_dim, kernel_size=5, stride=5, bias=False + ) # Hardcode deal with image token pa + 
self.input_proj_robot_state = nn.Linear(14, hidden_dim) # proprioception + else: + # input_dim = 14 + 7 # robot_state + env_state + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + self.input_proj_env_state = nn.Linear(7, hidden_dim) + self.pos = torch.nn.Embedding(2, hidden_dim) + self.backbones = None + + # encoder extra parameters + self.latent_dim = 32 # final size of latent z # TODO tune + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_action_proj = nn.Linear( + 14, hidden_dim + ) # project action to embedding TODO action dim = 14/16 + self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding + self.latent_proj = nn.Linear( + hidden_dim, self.latent_dim * 2 + ) # project hidden state to latent std, var + self.register_buffer( + "pos_table", + get_sinusoid_encoding_table(1 + 1 + history_step + num_queries, hidden_dim), + ) # [CLS], qpos, a_seq + + # decoder extra parameters vae latent to decoder token space + self.history_step = history_step + self.latent_out_proj = nn.Linear( + self.latent_dim, hidden_dim + ) # project latent sample to embedding + self.additional_pos_embed = nn.Embedding( + 2 + history_step, hidden_dim + ) # learned position embedding for proprio and latent and denpoe + self.denoise_step_pos_embed = nn.Embedding( + 1, hidden_dim + ) # learned position embedding for denoise step + # setting for token prediction + if self.predict_only_last: + self.num_temporal_token = 1 + else: + # print('predict frame', predict_frame, 'temporal_downsample_rate', temporal_downsample_rate, 'tokenizer_model_temporal_rate', tokenizer_model_temporal_rate) + self.num_temporal_token = math.ceil( + predict_frame + // temporal_downsample_rate + / tokenizer_model_temporal_rate + ) + self.patch_size = patch_size + self.image_h = ( + image_height + // image_downsample_rate + // tokenizer_model_spatial_rate + // resize_rate + // self.patch_size + ) # + self.image_w = ( + image_width + // 
image_downsample_rate + // tokenizer_model_spatial_rate + // resize_rate + // self.patch_size + ) # + self.num_pred_token_per_timestep = ( + self.image_h * self.image_w * len(self.future_camera_names) + ) # + self.token_shape = ( + self.num_temporal_token, + self.image_h, + self.image_w, + self.patch_size, + ) + if token_pe_type == "learned": + self.query_embed_token = nn.Embedding( + self.num_temporal_token * self.num_pred_token_per_timestep, hidden_dim + ) # for decoder as PE TODO replace with temporal spatial PE + print( + "predict token shape", # B T' N_view D*P*P H' W' + ( + self.num_pred_token_per_timestep, # H' W' N_view + self.num_temporal_token, # T' + self.image_h, + self.image_w, + self.patch_size, + ), + ) + query_embed_token_fixed = get_nd_sincos_pos_embed_from_grid( + hidden_dim, + (self.num_temporal_token, len(self.future_camera_names), self.image_h, self.image_w), + ) + self.query_embed_token_fixed = ( + torch.from_numpy(query_embed_token_fixed).view(-1, hidden_dim).float() + ) + self.token_pe_type = token_pe_type + + def forward( + self, + qpos, + current_image, + env_state, + actions, + is_pad, + noisy_actions, + noise_tokens, + is_tokens_pad, + denoise_steps, + ): + # qpos: batch, 1+ history, qpos_dim + # current_image: 1. batch, T', num_cam, 3, height, width - image_norm T' = 1+ history + # 2. 
batch, T', num_cam*6/16, height', width' - image_token + # env_state: None + # actions: batch, seq, action_dim clean action for vae encoder + # is_pad: batch, seq, for vae encoder + # noisy_actions: batch, seq, action_dim, noisy action for denoise + # noise_tokens: batch, seq, num_cam, 6/16, height', width', noise token for denoise + # is_tokens_pad: batch, seq + # denoise_steps: int, the step of denoise batch + is_training = actions is not None # train or val + bs = qpos.shape[0] + is_actions_pad = is_pad + + ### Obtain latent z from action sequence + mu = logvar = None + latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to( + qpos.device + ) + latent_input = self.latent_out_proj( + latent_sample + ) # TODO maybe add tacile input + + # proprioception features + proprio_input = self.input_proj_robot_state(qpos) + # Image observation features and position embeddings + image = current_image[ + 0 + ] # shape: batch, T', num_cam, 3, height, width image_norm + image = image.view( + -1, *image.shape[2:] + ) # frame stack shape: batch*T', num_cam, 3, height, width + all_cam_features = [] + for cam_id, cam_name in enumerate(self.camera_names): + features, pos = self.backbones[cam_id]( + image[:, cam_id] + ) # shape: batch*T', C, H, W + features = features[0] # take the last layer feature + features = features.view( + bs, -1, *features.shape[-2:] + ) # shape: batch, T'*C, H, W + all_cam_features.append(self.input_proj(features)) + src = torch.stack(all_cam_features, axis=-3) # shape: batch,D,N_view, H, W + pos = get_nd_sincos_pos_embed_from_grid( + self.hidden_dim, src.shape[2:] + ) # N_view, H, W, D numpy + pos = ( + torch.from_numpy(pos).to(src.device).unsqueeze(0).float() + ) # 1 N_view, H, W, D + src = rearrange( + src, + "b d n_view h w -> b d h (w n_view)", + ) + pos = rearrange( + pos, + "b n_view h w d -> b d h (w n_view)", + ) # will b d h*w*n_view in the following transformer + + # deal with token patchfy + noise_tokens = ( + rearrange( + 
noise_tokens, + "b s n d (ph p1) (pw p2) -> b s n (d p1 p2) ph pw", + p1=self.patch_size, + p2=self.patch_size, + ) + if noise_tokens is not None + else None + ) + + if is_tokens_pad is not None: # B T -> B T' N_view*H'*W' -> B T'*N_view*H'*W' + is_tokens_pad = ( + is_tokens_pad.unsqueeze(2) + .repeat(1, 1, self.num_pred_token_per_timestep) + .reshape(bs, -1) + ) + else: + is_tokens_pad = torch.zeros( + bs, + self.num_pred_token_per_timestep * self.num_temporal_token, + dtype=torch.bool, + ).to(qpos.device) + is_pad = torch.zeros(bs, self.num_queries, dtype=torch.bool).to(qpos.device) + + if self.token_pe_type == "learned": + query_embed_token = self.query_embed_token.weight + else: + query_embed_token = self.query_embed_token_fixed.to(qpos.device) + # print('detr query_embed_token', query_embed_token.shape) + hs_action, hs_token = self.transformer( + src, # obeserved image token + None, + self.query_embed.weight, # for action token pe + query_embed_token, # for future token pe + pos, # obeserved image token pe + latent_input, + proprio_input, + self.additional_pos_embed.weight, # latent & proprio token pe + noisy_actions, + noise_tokens, + denoise_steps, + self.denoise_step_pos_embed.weight, # denoise step token pe + is_actions_pad, + is_tokens_pad, + ) + + a_hat = self.action_head(hs_action) + is_pad_a_hat = self.is_pad_head(hs_action) + pred_token = ( + self.token_head(hs_token) if hs_token is not None else None + ) # B T'*N*H'*W' 6*self.patch_size*self.patch_size + is_pad_token_hat = self.is_pad_head(hs_token) if hs_token is not None else None + is_pad_hat = ( + torch.cat([is_pad_a_hat, is_pad_token_hat], axis=1) + if is_pad_token_hat is not None + else is_pad_a_hat + ) + + pred_token = ( + rearrange( + pred_token, + "b (t n hp wp) (c ph pw) -> b t n c (hp ph) (wp pw)", + t=self.num_temporal_token, + n=len(self.future_camera_names), + hp=self.image_h, + wp=self.image_w, + ph=self.patch_size, + pw=self.patch_size, + ) + if pred_token is not None + else None + ) 
+ + return a_hat, is_pad_hat, pred_token, [mu, logvar] + + +class DETRVAE_Denoise_Tactile(nn.Module): + def __init__(self, + backbones, + transformer, + tactile_encoder, + state_dim, + num_queries, + camera_names,): + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + hidden_dim = transformer.d_model # 512 + self.hidden_dim = hidden_dim + self.tactile_dim = tactile_encoder.tactile_dim + + # tokenize input obeservation + self.input_proj_robot_state = nn.Linear(14, hidden_dim) # proprioception + self.input_proj = nn.Conv2d( + backbones[0].num_channels, + hidden_dim, + kernel_size=1, + ) # = MLP c h w -> c' h' w' frame stack + self.backbones = nn.ModuleList(backbones) # N encoders for N view + self.tactile_encoder = tactile_encoder + self.transformer = transformer + + # ouput head + self.action_head = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, state_dim), + ) + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.query_embed = nn.Embedding(num_queries, hidden_dim) # for decoder as PE + self.additional_pos_embed = nn.Embedding( + 1, hidden_dim + ) # learned position embedding for proprio + + def forward( + self, + qpos, + image, + tactile, + env_state, + actions=None, + is_pad=None, + denoise_steps=None, + is_training=True, + ): + """ + qpos: batch, 1, qpos_dim + image: batch, 1, num_cam, channel, height, width + tactile: batch, 1, 4, 3, 960, 960 + env_state: None + actions: batch, seq, action_dim noisy action + denoise_step: int, the step of denoise + """ + bs = qpos.shape[0] + + # proprioception features + proprio_input = self.input_proj_robot_state(qpos) # B 1 D + # Image observation features and position embeddings + image = image.view( -1, *image.shape[2:]) # frame stack shape: batch*T', num_cam, 3, height, width + all_cam_features = [] + for cam_id, cam_name in enumerate(self.camera_names): + features, pos = self.backbones[cam_id]( + image[:, cam_id] + ) # shape: batch*T', C, H, W + 
features = features[0] # take the last layer feature + features = features.view( + bs, -1, *features.shape[-2:] + ) # shape: batch, T'*C, H, W + all_cam_features.append(self.input_proj(features)) + src = torch.stack(all_cam_features, axis=-3) # shape: batch,D,N_view, H, W + pos = get_nd_sincos_pos_embed_from_grid( + self.hidden_dim, src.shape[2:] + ) + pos = torch.from_numpy(pos).to(src.device).unsqueeze(0).float() + src = rearrange(src,"b d n_view h w -> b d h (w n_view)",) + pos = rearrange(pos,"b n_view h w d -> b d h (w n_view)",) # will b d h*w*n_view in the following transformer + # tactile features + tactile_input = self.tactile_encoder(tactile) # batch, d, 4 h w + # print('detr tactile_input', tactile_input.shape) + tactile_pos = get_nd_sincos_pos_embed_from_grid( + self.hidden_dim, tactile_input.shape[2:] + ) # 4,H,W,d -> 1,4,h,w,d + tactile_pos = torch.from_numpy(tactile_pos).to(src.device).unsqueeze(0).float() + tactile_input = rearrange(tactile_input,"b d n_view h w -> b d h (w n_view)",) + tactile_pos = rearrange(tactile_pos,"b n_view h w d -> b d h (w n_view)",) # will b d h*w*n_view in the following transformer + # tactile_pos = tactile_pos.permute(0, 3, 1, 2) # batch, d, 4, h, w + hs = self.transformer( + src, # B D h w + None, + self.query_embed.weight, # H D + pos, # B D h w + tactile_input, # B D h 4*w + tactile_pos, # 1 D h 4*w + proprio_input, # B 1 D + self.additional_pos_embed.weight, # D + actions, # B H D + denoise_steps, # B + )[0] + a_hat = self.action_head(hs) # predict action or noise output of module + is_pad_hat = self.is_pad_head(hs) + return a_hat, is_pad_hat + +class DETRVAE_Denoise_Token_Prediction_Dual_Visual_Token(nn.Module): + """This is the DETR module that performs object detection""" + + def __init__( + self, + backbones, + transformer, + encoder, + state_dim, + num_queries, + camera_names, + history_step, + predict_frame, + image_downsample_rate, + temporal_downsample_rate, + disable_vae_latent, + disable_resnet, + 
patch_size=5, + token_pe_type="learned", + image_height=480, + image_width=640, + tokenizer_model_temporal_rate=8, + tokenizer_model_spatial_rate=16, + ): + """Initializes the model. + both resnet feature and token feature are used for token prediction & action prediction + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.transformer = transformer + self.transformer_share_decoder = transformer.share_decoder + self.predict_only_last = transformer.predict_only_last + self.encoder = encoder + self.disable_vae_latent = disable_vae_latent + self.disable_resnet = disable_resnet + hidden_dim = transformer.d_model # 512 + self.hidden_dim = hidden_dim + self.action_head = nn.Linear(hidden_dim, state_dim) # TODO replace MLP + if self.transformer_share_decoder == False: + self.token_head = nn.Linear( + hidden_dim, 6 * patch_size * patch_size + ) # HardCode TODO replace MLP + else: # default share decoder + self.token_head = nn.Sequential( + nn.Linear( + hidden_dim, hidden_dim + ), # Hardcode patch size * path size * patch dim + nn.SiLU(), + nn.Linear(hidden_dim, 6 * patch_size * patch_size), + ) + + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.query_embed = nn.Embedding(num_queries, hidden_dim) # for decoder as PE + self.diffusion_timestep_type = transformer.diffusion_timestep_type + if backbones is not None: # resnet encoder to extract visual feature + self.backbones = nn.ModuleList(backbones) + self.input_proj = nn.Conv2d( + 
backbones[0].num_channels, hidden_dim, kernel_size=1 + ) + self.input_proj_token = nn.Conv2d( + 6, hidden_dim, kernel_size=patch_size, stride=patch_size, bias=False + ) # Hardcode deal with image token patch size = 5 + self.input_proj_robot_state = nn.Linear(14, hidden_dim) # proprioception + else: + # input_dim = 14 + 7 # robot_state + env_state + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + self.input_proj_env_state = nn.Linear(7, hidden_dim) + self.pos = torch.nn.Embedding(2, hidden_dim) + self.backbones = None + + # encoder extra parameters + self.latent_dim = 32 # final size of latent z # TODO tune + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_action_proj = nn.Linear( + 14, hidden_dim + ) # project action to embedding + self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding + self.latent_proj = nn.Linear( + hidden_dim, self.latent_dim * 2 + ) # project hidden state to latent std, var + self.register_buffer( + "pos_table", + get_sinusoid_encoding_table(1 + 1 + history_step + num_queries, hidden_dim), + ) # [CLS], qpos, a_seq + + # decoder extra parameters vae latent to decoder token space + self.history_step = history_step + self.latent_out_proj = nn.Linear( + self.latent_dim, hidden_dim + ) # project latent sample to embedding + self.additional_pos_embed = nn.Embedding( + 2 + history_step, hidden_dim + ) # learned position embedding for proprio and latent and denpoe + self.denoise_step_pos_embed = nn.Embedding( + 1, hidden_dim + ) # learned position embedding for denoise step + # setting for token prediction hyper + if self.predict_only_last: + self.num_temporal_token = 1 + else: + self.num_temporal_token = math.ceil( + predict_frame + // temporal_downsample_rate + / tokenizer_model_temporal_rate + ) # align with tokenizer output + self.patch_size = patch_size + self.image_h = ( + image_height + // image_downsample_rate + // tokenizer_model_spatial_rate + // self.patch_size + ) # 
+ self.image_w = ( + image_width + // image_downsample_rate + // tokenizer_model_spatial_rate + // self.patch_size + ) # + self.num_pred_token_per_timestep = ( + self.image_h * self.image_w * len(camera_names) + ) + # setting for token prediction position embedding + if token_pe_type == "learned": + self.current_embed_token = nn.Embedding( + 1 * self.num_pred_token_per_timestep, hidden_dim + ) # for current image token encoder + self.query_embed_token = nn.Embedding( + self.num_temporal_token * self.num_pred_token_per_timestep, hidden_dim + ) # for decoder as PE TODO replace with temporal spatial PE + + spatial_temporal_pe = get_spatial_temporal_positional_encoding( + self.image_h, self.image_w, self.num_temporal_token + 1, hidden_dim + ) + self.current_embed_token_fixed = spatial_temporal_pe[0].view( + -1, hidden_dim + ) # for current image token encoder + self.query_embed_token_fixed = spatial_temporal_pe[1:].view( + -1, hidden_dim + ) # for decoder as PE # (seq_len, height, width, dim) -> (seq_len*height*width, dim) + + self.token_pe_type = token_pe_type + print( + "predict token shape", + ( + self.num_pred_token_per_timestep, + self.num_temporal_token, + self.image_h, + self.image_w, + self.patch_size, + ), + ) + + def forward( + self, + qpos, + current_image, + env_state, + actions, + is_pad, + noisy_actions, + noise_tokens, + is_tokens_pad, + denoise_steps, + ): + # qpos: batch, 1+ history, qpos_dim + # current_image: 1. batch, T', num_cam, 3, height, width - image_norm T' = 1+ history defualt = 1 + # 2. 
batch, T', num_cam, 6, height', width' - image_token (current_image_norm, current_image_tokens) + # env_state: None + # actions: batch, seq, action_dim clean action for vae encoder + # is_pad: batch, seq, for vae encoder + # noisy_actions: batch, seq, action_dim, noisy action for denoise + # noise_tokens: batch, seq, num_cam, 6, height', width', noise token for denoise + # is_tokens_pad: batch, seq + # denoise_steps: int, the step of denoise + is_training = actions is not None # train or val + bs = qpos.shape[0] + is_actions_pad = is_pad + + ### Obtain latent z from action sequence + if is_training and not self.disable_vae_latent: + # project action sequence to embedding dim, and concat with a CLS token + action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) + qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim) + qpos_embed = ( + torch.unsqueeze(qpos_embed, axis=1) + if len(qpos_embed.shape) == 2 + else qpos_embed + ) # (bs, 1, hidden_dim) + cls_embed = self.cls_embed.weight # (1, hidden_dim) + cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat( + bs, 1, 1 + ) # (bs, 1, hidden_dim) + encoder_input = torch.cat( + [cls_embed, qpos_embed, action_embed], axis=1 + ) # (bs, seq+1, hidden_dim) + encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim) + # do not mask cls token + cls_joint_is_pad = torch.full((bs, 2 + self.history_step), False).to( + qpos.device + ) # False: not a padding + + is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1) + # obtain position embedding + pos_embed = self.pos_table.clone().detach() + pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) + # query model + encoder_output = self.encoder( + encoder_input, pos=pos_embed, src_key_padding_mask=is_pad + ) + encoder_output = encoder_output[0] # take cls output only + latent_info = self.latent_proj(encoder_output) # predict mu and logvar + mu = latent_info[:, : self.latent_dim] + logvar = latent_info[:, self.latent_dim :] 
+ latent_sample = reparametrize(mu, logvar) + latent_input = self.latent_out_proj(latent_sample) + else: + mu = logvar = None + latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to( + qpos.device + ) + latent_input = self.latent_out_proj(latent_sample) + + # Image observation features and position embeddings for resnet visual encoder + all_cam_features = [] + all_cam_pos = [] + if self.backbones is not None: + image = current_image[0] # shape: batch, T', num_cam, 3, height, width + for t in range(self.history_step + 1): # only current frame + for cam_id, cam_name in enumerate(self.camera_names): + features, pos = self.backbones[0]( + image[:, t, cam_id] + ) # resnet feature and pos + features = features[0] + pos = pos[0] + all_cam_features.append(self.input_proj(features)) + all_cam_pos.append(pos) # resnet feature and pos + src = torch.cat( + all_cam_features, axis=3 + ) # B C H W*num_view*T # fold camera dimension into width dimension + pos = torch.cat(all_cam_pos, axis=3) # B C H W*num_view*T + + # Image observation features and position embeddings for visual tokenizer encoder, only 1 frame + all_token_features = [] + all_token_pos = [] + current_visual_token = current_image[ + 1 + ] # shape: batch, T', num_cam, 6, height', width' Depth - image_token + for cam_id, cam_name in enumerate(self.camera_names): + token_features = current_visual_token[:, -1, cam_id] # B C H W + token_features = self.input_proj_token(token_features) # B C H W + token_pos = self.current_embed_token_fixed.to( + qpos.device + ) # height*width, C + all_token_features.append(token_features) + all_token_pos.append(token_pos) + addition_visual_token = torch.cat( + all_token_features, axis=3 + ) # B C H W*num_view + addition_visual_token_pos = torch.cat(all_token_pos, axis=0) # H*W*num_view C + + proprio_input = self.input_proj_robot_state(qpos) + + # deal with token patchfy + noise_tokens = rearrange( + noise_tokens, + "b s n d (ph p1) (pw p2) -> b s n (d p1 p2) ph pw", + 
p1=self.patch_size, + p2=self.patch_size, + ) # b seq_len num_cam 6*patch_size*atch_size height width + + if ( + is_tokens_pad is not None + ): # B seq_len -> B seq_len*self.num_pred_token_per_timestep useless for decoder + is_tokens_pad = ( + is_tokens_pad.unsqueeze(2) + .repeat(1, 1, self.num_pred_token_per_timestep) + .reshape(bs, -1) + ) + else: + is_tokens_pad = torch.zeros( + bs, + self.num_pred_token_per_timestep * self.num_temporal_token, + dtype=torch.bool, + ).to(qpos.device) + + if self.token_pe_type == "learned": + query_embed_token = self.query_embed_token.weight + else: + query_embed_token = self.query_embed_token_fixed.to( + qpos.device + ) # consider the visual token from current frame + hs_action, hs_token = self.transformer( + src, + None, + self.query_embed.weight, + query_embed_token, + pos, + latent_input, + proprio_input, + self.additional_pos_embed.weight, + noisy_actions, + noise_tokens, + denoise_steps, + self.denoise_step_pos_embed.weight, + is_actions_pad, + is_tokens_pad, + addition_visual_token, + addition_visual_token_pos, + ) + + a_hat = self.action_head(hs_action) + is_pad_a_hat = self.is_pad_head(hs_action) + pred_token = self.token_head( + hs_token + ) # B T'*N*H'*W' 6*self.patch_size*self.patch_size + is_pad_token_hat = self.is_pad_head(hs_token) + is_pad_hat = torch.cat([is_pad_a_hat, is_pad_token_hat], axis=1) + + pred_token = rearrange( + pred_token, + "b (t n hp wp) (c ph pw) -> b t n c (hp ph) (wp pw)", + t=self.num_temporal_token, + n=len(self.camera_names), + hp=self.image_h, + wp=self.image_w, + ph=self.patch_size, + pw=self.patch_size, + ) + + ## if no visal observation, + # qpos = self.input_proj_robot_state(qpos) + # env_state = self.input_proj_env_state(env_state) + # transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2 + # hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0] + return a_hat, is_pad_hat, pred_token, [mu, logvar] + + +class 
DETRVAE_Denoise_Pixel_Prediction(nn.Module): + def __init__( + self, + backbones, + transformer, + encoder, + state_dim, + num_queries, + camera_names, + history_step, + predict_frame, + image_downsample_rate, + temporal_downsample_rate, + disable_vae_latent, + disable_resnet, + patch_size=5, + token_pe_type="learned", + image_height=480, + image_width=640, + resize_rate=8, + ): + """Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.transformer = transformer + self.transformer_share_decoder = transformer.share_decoder + self.predict_only_last = transformer.predict_only_last + self.encoder = encoder + self.disable_vae_latent = disable_vae_latent + self.disable_resnet = disable_resnet + hidden_dim = transformer.d_model # 512 + self.hidden_dim = hidden_dim + self.action_head = nn.Linear(hidden_dim, state_dim) # TODO replace MLP + if self.transformer_share_decoder == False: + self.pixel_head = nn.Linear( + hidden_dim, 3 * patch_size * patch_size + ) # HardCode TODO replace MLP + else: + self.pixel_head = nn.Sequential( + nn.Linear( + hidden_dim, hidden_dim + ), # Hardcode patch size * path size * patch dim + nn.SiLU(), + nn.Linear(hidden_dim, 3 * patch_size * patch_size), + ) + + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.query_embed = nn.Embedding(num_queries, hidden_dim) # for decoder as PE + self.diffusion_timestep_type = transformer.diffusion_timestep_type + if backbones is not None and disable_resnet == False: + 
self.input_proj = nn.Conv2d(
+                backbones[0].num_channels, hidden_dim, kernel_size=1
+            )
+            self.backbones = nn.ModuleList(backbones)
+            self.input_proj_robot_state = nn.Linear(14, hidden_dim)  # proprioception
+        elif backbones is not None and disable_resnet == True:
+            self.pos_encoder = backbones[0][1]  # for 2D PositionEmbedding
+            self.backbones = None  # Hardcode
+            # Hardcode divide the latent feature representation into non-overlapping patches
+            self.input_proj_token = nn.Conv2d(
+                6, hidden_dim, kernel_size=5, stride=5, bias=False
+            )  # Hardcode deal with image token patchify
+            self.input_proj_robot_state = nn.Linear(14, hidden_dim)  # proprioception
+        else:
+            # input_dim = 14 + 7 # robot_state + env_state
+            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
+            self.input_proj_env_state = nn.Linear(7, hidden_dim)
+            self.pos = torch.nn.Embedding(2, hidden_dim)
+            self.backbones = None
+
+        # encoder extra parameters
+        self.latent_dim = 32  # final size of latent z # TODO tune
+        self.cls_embed = nn.Embedding(1, hidden_dim)  # extra cls token embedding
+        self.encoder_action_proj = nn.Linear(
+            14, hidden_dim
+        )  # project action to embedding
+        self.encoder_joint_proj = nn.Linear(14, hidden_dim)  # project qpos to embedding
+        self.latent_proj = nn.Linear(
+            hidden_dim, self.latent_dim * 2
+        )  # project hidden state to latent std, var
+        self.register_buffer(
+            "pos_table",
+            get_sinusoid_encoding_table(1 + 1 + history_step + num_queries, hidden_dim),
+        )  # [CLS], qpos, a_seq
+
+        # decoder extra parameters vae latent to decoder token space
+        self.history_step = history_step
+        self.latent_out_proj = nn.Linear(
+            self.latent_dim, hidden_dim
+        )  # project latent sample to embedding
+        self.additional_pos_embed = nn.Embedding(
+            2 + history_step, hidden_dim
+        )  # learned position embedding for proprio and latent and denoise
+        self.denoise_step_pos_embed = nn.Embedding(
+            1, hidden_dim
+        )  # learned position embedding for denoise step
+        # setting for token prediction
+        # 
self.num_temporal_token = predict_frame // temporal_downsample_rate // tokenizer_model_temporal_rate + if self.predict_only_last: + self.num_frames = 1 + else: + self.num_frames = math.ceil(predict_frame // temporal_downsample_rate) + self.patch_size = patch_size + self.image_h = ( + image_height // image_downsample_rate // resize_rate // self.patch_size + ) # + self.image_w = ( + image_width // image_downsample_rate // resize_rate // self.patch_size + ) # + self.num_pred_pixel_per_timestep = ( + self.image_h * self.image_w * len(camera_names) + ) + + if token_pe_type == "learned": + self.query_embed_token = nn.Embedding( + self.num_frames * self.num_pred_pixel_per_timestep, hidden_dim + ) # for decoder as PE TODO replace with temporal spatial PE + self.query_embed_token_fixed = get_spatial_temporal_positional_encoding( + self.image_h, self.image_w, self.num_frames, hidden_dim + ).view( + -1, hidden_dim + ) # for decoder as PE # (seq_len, height, width, dim) -> (seq_len*height*width, dim) + self.token_pe_type = token_pe_type + print( + "predict pixel shape", + ( + self.num_frames, + self.num_pred_pixel_per_timestep, + self.image_h, + self.image_w, + self.patch_size, + ), + ) + + def forward( + self, + qpos, + current_image, + env_state, + actions, + is_pad, + noisy_actions, + noise_tokens, + is_tokens_pad, + denoise_steps, + ): + # qpos: batch, 1+ history, qpos_dim + # current_image: 1. batch, T', num_cam, 3, height, width - image_norm T' = 1+ history + # 2. 
batch, T', num_cam*6, height', width' - image_token + # env_state: None + # actions: batch, seq, action_dim clean action for vae encoder + # is_pad: batch, seq, for vae encoder + # noisy_actions: batch, seq, action_dim, noisy action for denoise + # noise_tokens: batch, seq, num_cam, 6, height', width', noise token for denoise + # is_tokens_pad: batch, seq + # denoise_steps: int, the step of denoise + + # print('qpos',qpos.shape) + # print('current_image',current_image[0].shape) + # print('actions',actions.shape) + # print('is_pad',is_pad.shape) + # print('noisy_actions',noisy_actions.shape) + # print('noise_tokens',noise_tokens.shape) + # print('is_tokens_pad',is_tokens_pad.shape) + # print('denoise_steps',denoise_steps) + + is_training = actions is not None # train or val + bs = qpos.shape[0] + is_actions_pad = is_pad + + ### Obtain latent z from action sequence + if is_training and not self.disable_vae_latent: + # project action sequence to embedding dim, and concat with a CLS token + action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) + qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim) + qpos_embed = ( + torch.unsqueeze(qpos_embed, axis=1) + if len(qpos_embed.shape) == 2 + else qpos_embed + ) # (bs, 1, hidden_dim) + cls_embed = self.cls_embed.weight # (1, hidden_dim) + cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat( + bs, 1, 1 + ) # (bs, 1, hidden_dim) + encoder_input = torch.cat( + [cls_embed, qpos_embed, action_embed], axis=1 + ) # (bs, seq+1, hidden_dim) + encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim) + # do not mask cls token + cls_joint_is_pad = torch.full((bs, 2 + self.history_step), False).to( + qpos.device + ) # False: not a padding + + is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1) + # obtain position embedding + pos_embed = self.pos_table.clone().detach() + pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) + # query model + encoder_output = 
self.encoder(
+                encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
+            )
+            encoder_output = encoder_output[0]  # take cls output only
+            latent_info = self.latent_proj(encoder_output)  # predict mu and logvar
+            mu = latent_info[:, : self.latent_dim]
+            logvar = latent_info[:, self.latent_dim :]
+            latent_sample = reparametrize(mu, logvar)
+            latent_input = self.latent_out_proj(latent_sample)
+        else:
+            mu = logvar = None
+            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
+                qpos.device
+            )
+            latent_input = self.latent_out_proj(latent_sample)
+
+        # Image observation features and position embeddings
+        all_cam_features = []
+        all_cam_pos = []
+        if self.backbones is not None and self.disable_resnet == False:
+            image = current_image[0]  # shape: batch, T', num_cam, 3, height, width
+            for t in range(self.history_step + 1):
+                for cam_id, cam_name in enumerate(self.camera_names):
+                    features, pos = self.backbones[0](image[:, t, cam_id])
+                    features = features[0]
+                    pos = pos[0]
+                    all_cam_features.append(self.input_proj(features))
+                    all_cam_pos.append(pos)
+        else:
+            image = current_image[
+                1
+            ]  # [:,:self.num_temporal_token] # shape: batch, T', num_cam, 6, height', width' hardcode keep consistent with token prediction
+            for t in range(self.history_step + 1):
+                for cam_id, cam_name in enumerate(self.camera_names):
+                    features = image[:, t, cam_id]  # B C H W
+                    features = self.input_proj(features)
+                    pos = self.pos_encoder(features).to(features.dtype)
+                    all_cam_features.append(features)
+                    all_cam_pos.append(pos)
+        # proprioception features
+        proprio_input = self.input_proj_robot_state(qpos)
+        # fold camera dimension into width dimension
+        src = torch.cat(all_cam_features, axis=3)  # B C H W*num_view*T
+        pos = torch.cat(all_cam_pos, axis=3)  # B C H W*num_view*T
+        # deal with token patchify
+        noise_tokens = rearrange(
+            noise_tokens,
+            "b s n d (ph p1) (pw p2) -> b s n (d p1 p2) ph pw",
+            p1=self.patch_size,
+            p2=self.patch_size,
+        )
+        # print('noise_tokens
after rearrange',noise_tokens.shape) + if is_tokens_pad is not None: + is_tokens_pad = ( + is_tokens_pad.unsqueeze(2) + .repeat(1, 1, self.num_pred_pixel_per_timestep) + .reshape(bs, -1) + ) + else: + is_tokens_pad = torch.zeros( + bs, self.num_frames * self.num_pred_pixel_per_timestep, dtype=torch.bool + ).to(qpos.device) + is_pad = torch.zeros(bs, self.num_queries, dtype=torch.bool).to(qpos.device) + + if self.token_pe_type == "learned": + query_embed_token = self.query_embed_token.weight + else: + query_embed_token = self.query_embed_token_fixed.to(qpos.device) + + hs_action, hs_token = self.transformer( + src, + None, + self.query_embed.weight, + query_embed_token, + pos, + latent_input, + proprio_input, + self.additional_pos_embed.weight, + noisy_actions, + noise_tokens, + denoise_steps, + self.denoise_step_pos_embed.weight, + is_actions_pad, + is_tokens_pad, + ) + + a_hat = self.action_head(hs_action) + is_pad_a_hat = self.is_pad_head(hs_action) + pred_images = self.pixel_head( + hs_token + ) # B T'*N*H'*W' 6*self.patch_size*self.patch_size + is_pad_image_hat = self.is_pad_head(hs_token) + is_pad_hat = torch.cat([is_pad_a_hat, is_pad_image_hat], axis=1) + + pred_images = rearrange( + pred_images, + "b (t n hp wp) (c ph pw) -> b t n c (hp ph) (wp pw)", + t=self.num_frames, + n=len(self.camera_names), + hp=self.image_h, + wp=self.image_w, + ph=self.patch_size, + pw=self.patch_size, + ) + + return a_hat, is_pad_hat, pred_images, [mu, logvar] + + +class DETRVAEDino(nn.Module): + """This is the DETR module that performs object detection""" + + def __init__( + self, backbones, transformer, encoder, state_dim, num_queries, camera_names + ): + """Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. 
This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.transformer = transformer + self.encoder = encoder + hidden_dim = transformer.d_model + self.action_head = nn.Linear(hidden_dim, state_dim) + self.is_pad_head = nn.Linear(hidden_dim, 1) + # self.cls_token_head = nn.Linear(hidden_dim, 384) + self.num_cls_tokens = 50 + self.query_embed = nn.Embedding(num_queries + self.num_cls_tokens, hidden_dim) + if backbones is not None: + self.input_proj = nn.Conv2d( + backbones[0].num_channels, hidden_dim, kernel_size=1 + ) + self.backbones = nn.ModuleList(backbones) + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + else: + # input_dim = 14 + 7 # robot_state + env_state + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + self.input_proj_env_state = nn.Linear(7, hidden_dim) + self.pos = torch.nn.Embedding(2, hidden_dim) + self.backbones = None + + # encoder extra parameters + self.latent_dim = 32 # final size of latent z # TODO tune + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_action_proj = nn.Linear( + 14, hidden_dim + ) # project action to embedding + self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding + self.latent_proj = nn.Linear( + hidden_dim, self.latent_dim * 2 + ) # project hidden state to latent std, var + self.register_buffer( + "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim) + ) # [CLS], qpos, a_seq + + # decoder extra parameters + self.latent_out_proj = nn.Linear( + self.latent_dim, hidden_dim + ) # project latent sample to embedding + self.additional_pos_embed = nn.Embedding( + 2, hidden_dim + ) # learned position embedding for proprio and latent + + # settings for cls token prediction + # 
self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) + # self.n_patch = (self.image_size//self.patch_size)**2 + # self.k = 1 # number of next frames + # self.n_patch = (self.img_h//self.patch_size)*(self.img_w//self.patch_size)*(self.k) + # self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.n_patch, hidden_dim), requires_grad=False) # (1, n_patch, h) + # self.patch_embed = nn.Embedding(self.n_patch, hidden_dim) + # self.decoder_embed = nn.Linear(hidden_dim, hidden_dim, bias=True) + + decoder_depth = 2 # hardcode + self.decoder_blocks = nn.ModuleList( + [ + Block( + hidden_dim, + 16, + 4, + qkv_bias=True, + qk_scale=None, + norm_layer=nn.LayerNorm, + ) + for i in range(decoder_depth) + ] + ) + + self.decoder_norm = nn.LayerNorm(hidden_dim) + self.decoder_pred = nn.Linear(hidden_dim, 384, bias=True) # decoder to patch + + # decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], (self.image_size//self.patch_size), cls_token=False) + # decoder_pos_embed = get_2d_sincos_pos_embed_v2(self.decoder_pos_embed.shape[-1], (self.img_h//self.patch_size, self.img_w//self.patch_size)) + # self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0).repeat(1,self.k,1)) + + def forward(self, qpos, image, env_state, actions=None, is_pad=None): + """ + qpos: batch, qpos_dim + image: batch, num_cam, channel, height, width + env_state: None + actions: batch, seq, action_dim + """ + is_training = actions is not None # train or val + bs, _ = qpos.shape + ### Obtain latent z from action sequence + if is_training: + # project action sequence to embedding dim, and concat with a CLS token + action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) + qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim) + qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim) + cls_embed = self.cls_embed.weight # (1, hidden_dim) + cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat( + bs, 1, 1 + ) # 
(bs, 1, hidden_dim) + encoder_input = torch.cat( + [cls_embed, qpos_embed, action_embed], axis=1 + ) # (bs, seq+1, hidden_dim) + encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim) + # do not mask cls token + cls_joint_is_pad = torch.full((bs, 2), False).to( + qpos.device + ) # False: not a padding + is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1) + # obtain position embedding + pos_embed = self.pos_table.clone().detach() + pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) + # query model + encoder_output = self.encoder( + encoder_input, pos=pos_embed, src_key_padding_mask=is_pad + ) + encoder_output = encoder_output[0] # take cls output only + latent_info = self.latent_proj(encoder_output) + mu = latent_info[:, : self.latent_dim] + logvar = latent_info[:, self.latent_dim :] + latent_sample = reparametrize(mu, logvar) + latent_input = self.latent_out_proj(latent_sample) + else: + mu = logvar = None + latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to( + qpos.device + ) + latent_input = self.latent_out_proj(latent_sample) + + if self.backbones is not None: + # Image observation features and position embeddings + all_cam_features = [] + all_cam_pos = [] + for cam_id, cam_name in enumerate(self.camera_names): + features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED + features = features[0] # take the last layer feature + pos = pos[0] + all_cam_features.append(self.input_proj(features)) + all_cam_pos.append(pos) + # proprioception features + proprio_input = self.input_proj_robot_state(qpos) + # fold camera dimension into width dimension + src = torch.cat(all_cam_features, axis=3) + pos = torch.cat(all_cam_pos, axis=3) + hs = self.transformer( + src, + None, + self.query_embed.weight, + pos, + latent_input, + proprio_input, + self.additional_pos_embed.weight, + )[0] + else: + qpos = self.input_proj_robot_state(qpos) + env_state = self.input_proj_env_state(env_state) + 
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2 + hs = self.transformer( + transformer_input, None, self.query_embed.weight, self.pos.weight + )[0] + a_hat = self.action_head(hs[:, : self.num_queries]) + is_pad_hat = self.is_pad_head(hs[:, : self.num_queries]) + + for blk in self.decoder_blocks: + cls_token = blk(hs[:, self.num_queries :]) + cls_token = self.decoder_norm(cls_token) + cls_token_hat = self.decoder_pred(cls_token) + return a_hat, is_pad_hat, [mu, logvar], cls_token_hat + + +class DETRVAEjpeg(nn.Module): + """This is the DETR module that performs object detection""" + + def __init__( + self, backbones, transformer, encoder, state_dim, num_queries, camera_names + ): + """Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 
+ """ + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.transformer = transformer + self.encoder = encoder + hidden_dim = transformer.d_model + self.action_head = nn.Linear(hidden_dim, state_dim) + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.num_jpeg = ( + 20 * 4 + ) # 80 tokens for the 20 next frame jpegs, 4 token will represent 1 jpeg + self.query_embed = nn.Embedding(num_queries + self.num_jpeg, hidden_dim) + if backbones is not None: + self.input_proj = nn.Conv2d( + backbones[0].num_channels, hidden_dim, kernel_size=1 + ) + self.backbones = nn.ModuleList(backbones) + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + else: + # input_dim = 14 + 7 # robot_state + env_state + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + self.input_proj_env_state = nn.Linear(7, hidden_dim) + self.pos = torch.nn.Embedding(2, hidden_dim) + self.backbones = None + + # encoder extra parameters + self.latent_dim = 32 # final size of latent z # TODO tune + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_action_proj = nn.Linear( + 14, hidden_dim + ) # project action to embedding + self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding + self.latent_proj = nn.Linear( + hidden_dim, self.latent_dim * 2 + ) # project hidden state to latent std, var + self.register_buffer( + "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim) + ) # [CLS], qpos, a_seq + + # decoder extra parameters + self.latent_out_proj = nn.Linear( + self.latent_dim, hidden_dim + ) # project latent sample to embedding + self.additional_pos_embed = nn.Embedding( + 2, hidden_dim + ) # learned position embedding for proprio and latent + + # settings for cls token prediction + # self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) + # self.n_patch = (self.image_size//self.patch_size)**2 + # self.k = 1 # number of next frames + # self.n_patch = 
(self.img_h//self.patch_size)*(self.img_w//self.patch_size)*(self.k) + # self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.n_patch, hidden_dim), requires_grad=False) # (1, n_patch, h) + # self.patch_embed = nn.Embedding(self.n_patch, hidden_dim) + # self.decoder_embed = nn.Linear(hidden_dim, hidden_dim, bias=True) + + self.jpeg_head = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 400) + ) + + # self.jpeg_head = nn.Linear(hidden_dim, 400) + + # decoder_depth = 2 # hardcode + # self.decoder_blocks = nn.ModuleList([ + # Block(hidden_dim, 16, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm) + # for i in range(decoder_depth)]) + + # self.decoder_norm = nn.LayerNorm(hidden_dim) + # self.decoder_pred = nn.Linear(hidden_dim, 400, bias=True) # decoder to patch + + # decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], (self.image_size//self.patch_size), cls_token=False) + # decoder_pos_embed = get_2d_sincos_pos_embed_v2(self.decoder_pos_embed.shape[-1], (self.img_h//self.patch_size, self.img_w//self.patch_size)) + # self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0).repeat(1,self.k,1)) + + def forward(self, qpos, image, env_state, actions=None, is_pad=None): + """ + qpos: batch, qpos_dim + image: batch, num_cam, channel, height, width + env_state: None + actions: batch, seq, action_dim + """ + is_training = actions is not None # train or val + bs, _ = qpos.shape + ### Obtain latent z from action sequence + if is_training: + # project action sequence to embedding dim, and concat with a CLS token + action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) + qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim) + qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim) + cls_embed = self.cls_embed.weight # (1, hidden_dim) + cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat( + bs, 1, 1 + ) # (bs, 1, hidden_dim) + 
encoder_input = torch.cat( + [cls_embed, qpos_embed, action_embed], axis=1 + ) # (bs, seq+1, hidden_dim) + encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim) + # do not mask cls token + cls_joint_is_pad = torch.full((bs, 2), False).to( + qpos.device + ) # False: not a padding + is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1) + # obtain position embedding + pos_embed = self.pos_table.clone().detach() + pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) + # query model + encoder_output = self.encoder( + encoder_input, pos=pos_embed, src_key_padding_mask=is_pad + ) + encoder_output = encoder_output[0] # take cls output only + latent_info = self.latent_proj(encoder_output) + mu = latent_info[:, : self.latent_dim] + logvar = latent_info[:, self.latent_dim :] + latent_sample = reparametrize(mu, logvar) + latent_input = self.latent_out_proj(latent_sample) + else: + mu = logvar = None + latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to( + qpos.device + ) + latent_input = self.latent_out_proj(latent_sample) + + if self.backbones is not None: + # Image observation features and position embeddings + all_cam_features = [] + all_cam_pos = [] + for cam_id, cam_name in enumerate(self.camera_names): + features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED + features = features[0] # take the last layer feature + pos = pos[0] + all_cam_features.append(self.input_proj(features)) + all_cam_pos.append(pos) + # proprioception features + proprio_input = self.input_proj_robot_state(qpos) + # fold camera dimension into width dimension + src = torch.cat(all_cam_features, axis=3) + pos = torch.cat(all_cam_pos, axis=3) + hs = self.transformer( + src, + None, + self.query_embed.weight, + pos, + latent_input, + proprio_input, + self.additional_pos_embed.weight, + )[0] + else: + qpos = self.input_proj_robot_state(qpos) + env_state = self.input_proj_env_state(env_state) + transformer_input = torch.cat([qpos, 
env_state], axis=1) # seq length = 2 + hs = self.transformer( + transformer_input, None, self.query_embed.weight, self.pos.weight + )[0] + a_hat = self.action_head(hs[:, : self.num_queries]) + is_pad_hat = self.is_pad_head(hs[:, : self.num_queries]) + + # for blk in self.decoder_blocks: + # jpeg_token = blk(hs[:,self.num_queries:]) + # jpeg_token = self.decoder_norm(jpeg_token) + # jpeg_token_hat = self.decoder_pred(jpeg_token) + jpeg_token_hat = self.jpeg_head(hs[:, self.num_queries :]) + + return a_hat, is_pad_hat, [mu, logvar], jpeg_token_hat + + +class DETRVAEjpeg_diffusion(nn.Module): + # add timestep + """This is the DETR module that performs object detection""" + + def __init__( + self, + backbones, + transformer, + encoder, + state_dim, + num_queries, + camera_names, + disable_vae_latent=False, + predict_frame=20, + jpeg_token_num=4, + jpeg_dim=400, + ): + """Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 
+ """ + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.transformer = transformer + self.encoder = encoder + hidden_dim = transformer.d_model + self.disable_vae_latent = disable_vae_latent + self.action_head = nn.Linear(hidden_dim, state_dim) + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.predict_frame = predict_frame + self.jpeg_token_num = jpeg_token_num + self.jpeg_dim = jpeg_dim + self.num_jpeg = ( + predict_frame * jpeg_token_num + ) # 80 tokens for the 20 next frame jpegs, 4 token will represent 1 jpeg TODO tune + self.query_embed = nn.Embedding(num_queries + self.num_jpeg, hidden_dim) + if backbones is not None: + self.input_proj = nn.Conv2d( + backbones[0].num_channels, hidden_dim, kernel_size=1 + ) + self.backbones = nn.ModuleList(backbones) + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + else: + # input_dim = 14 + 7 # robot_state + env_state + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + self.input_proj_env_state = nn.Linear(7, hidden_dim) + self.pos = torch.nn.Embedding(2, hidden_dim) + self.backbones = None + + # encoder extra parameters + self.latent_dim = 32 # final size of latent z # TODO tune + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_action_proj = nn.Linear( + 14, hidden_dim + ) # project action to embedding + self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding + self.latent_proj = nn.Linear( + hidden_dim, self.latent_dim * 2 + ) # project hidden state to latent std, var + self.register_buffer( + "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim) + ) # [CLS], qpos, a_seq + + # decoder extra parameters + self.latent_out_proj = nn.Linear( + self.latent_dim, hidden_dim + ) # project latent sample to embedding + self.additional_pos_embed = nn.Embedding( + 2, hidden_dim + ) # learned position embedding for proprio and latent + + self.jpeg_head = nn.Sequential( + 
nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, jpeg_dim), + ) + + def forward( + self, + qpos, + image, + env_state, + actions=None, + is_pad=None, + jpegs=None, + is_jpeg_pad=None, + denoise_steps=None, + is_training=True, + ): + """ + qpos: batch, qpos_dim + image: batch, num_cam, channel, height, width + env_state: None + actions: batch, seq, action_dim + """ + # is_training = actions is not None # train or val + bs, _ = qpos.shape + ### Obtain latent z from action sequence + if is_training and self.disable_vae_latent == False: + # project action sequence to embedding dim, and concat with a CLS token + action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) + qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim) + qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim) + cls_embed = self.cls_embed.weight # (1, hidden_dim) + cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat( + bs, 1, 1 + ) # (bs, 1, hidden_dim) + encoder_input = torch.cat( + [cls_embed, qpos_embed, action_embed], axis=1 + ) # (bs, seq+1, hidden_dim) + encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim) + # do not mask cls token + cls_joint_is_pad = torch.full((bs, 2), False).to( + qpos.device + ) # False: not a padding + is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1) + # obtain position embedding + pos_embed = self.pos_table.clone().detach() + pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) + # query model + encoder_output = self.encoder( + encoder_input, pos=pos_embed, src_key_padding_mask=is_pad + ) + encoder_output = encoder_output[0] # take cls output only + latent_info = self.latent_proj(encoder_output) + mu = latent_info[:, : self.latent_dim] + logvar = latent_info[:, self.latent_dim :] + latent_sample = reparametrize(mu, logvar) + latent_input = self.latent_out_proj(latent_sample) + else: + mu = logvar = None + latent_sample = torch.zeros([bs, self.latent_dim], 
class DETRVAEjpeg_diffusion_seperate(nn.Module):
    # add timestep
    """ACT-style CVAE policy head that predicts an action chunk and future-frame
    JPEG tokens using two SEPARATE query embeddings (one set for actions, one
    for JPEG tokens), driven by a diffusion-style transformer that also
    receives a denoising timestep.
    """

    def __init__(
        self,
        backbones,
        transformer,
        encoder,
        state_dim,
        num_queries,
        camera_names,
        disable_vae_latent=False,
        predict_frame=20,
        jpeg_token_num=4,
        jpeg_dim=400,
    ):
        """Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            encoder: VAE-style encoder used to infer latent z from (qpos, actions)
            state_dim: robot state dimension of the environment
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
            DETR can detect in a single image. For COCO, we recommend 100 queries.
            camera_names: ordered camera identifiers; one image stream per entry
            disable_vae_latent: if True, forward() always uses a zero latent sample
            predict_frame: number of future frames whose JPEG tokens are predicted
            jpeg_token_num: number of tokens representing a single JPEG frame
            jpeg_dim: output dimensionality of each predicted JPEG token
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        hidden_dim = transformer.d_model
        self.disable_vae_latent = disable_vae_latent
        self.action_head = nn.Linear(hidden_dim, state_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.predict_frame = predict_frame
        self.jpeg_token_num = jpeg_token_num
        self.jpeg_dim = jpeg_dim
        self.num_jpeg = (
            predict_frame * jpeg_token_num
        )  # 80 tokens for the 20 next frame jpegs, 4 token will represent 1 jpeg TODO tune
        # Separate query tables for the action slots and the JPEG-token slots
        # (unlike DETRVAEjpeg_diffusion, which uses one combined table).
        self.query_embed_action = nn.Embedding(num_queries, hidden_dim)
        self.query_embed_jpeg = nn.Embedding(self.num_jpeg, hidden_dim)
        # self.query_embed = nn.Embedding(num_queries+self.num_jpeg, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(
                backbones[0].num_channels, hidden_dim, kernel_size=1
            )
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32  # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim)  # extra cls token embedding
        self.encoder_action_proj = nn.Linear(
            14, hidden_dim
        )  # project action to embedding
        self.encoder_joint_proj = nn.Linear(14, hidden_dim)  # project qpos to embedding
        self.latent_proj = nn.Linear(
            hidden_dim, self.latent_dim * 2
        )  # project hidden state to latent std, var
        self.register_buffer(
            "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
        )  # [CLS], qpos, a_seq

        # decoder extra parameters
        self.latent_out_proj = nn.Linear(
            self.latent_dim, hidden_dim
        )  # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(
            2, hidden_dim
        )  # learned position embedding for proprio and latent

        # Maps each JPEG-query decoder output to a jpeg_dim-sized token.
        self.jpeg_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, jpeg_dim),  # todo
        )

    def forward(
        self,
        qpos,
        image,
        env_state,
        actions=None,
        is_pad=None,
        denoise_steps=None,
        is_training=True,
    ):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None
        actions: batch, seq, action_dim
        Returns (a_hat, jpeg_token_hat, is_pad_hat, [mu, logvar]).
        """
        # is_training = actions is not None # train or val
        bs, _ = qpos.shape
        ### Obtain latent z from action sequence
        if is_training and self.disable_vae_latent == False:
            # project action sequence to embedding dim, and concat with a CLS token
            action_embed = self.encoder_action_proj(actions)  # (bs, seq, hidden_dim)
            qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
            qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, hidden_dim)
            cls_embed = self.cls_embed.weight  # (1, hidden_dim)
            cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
                bs, 1, 1
            )  # (bs, 1, hidden_dim)
            encoder_input = torch.cat(
                [cls_embed, qpos_embed, action_embed], axis=1
            )  # (bs, seq+1, hidden_dim)
            encoder_input = encoder_input.permute(1, 0, 2)  # (seq+1, bs, hidden_dim)
            # do not mask cls token
            cls_joint_is_pad = torch.full((bs, 2), False).to(
                qpos.device
            )  # False: not a padding
            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
            # obtain position embedding
            pos_embed = self.pos_table.clone().detach()
            pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
            # query model
            encoder_output = self.encoder(
                encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
            )
            encoder_output = encoder_output[0]  # take cls output only
            latent_info = self.latent_proj(encoder_output)
            mu = latent_info[:, : self.latent_dim]
            logvar = latent_info[:, self.latent_dim :]
            latent_sample = reparametrize(mu, logvar)
            latent_input = self.latent_out_proj(latent_sample)
        else:
            # Inference (or VAE disabled): use a deterministic zero latent.
            mu = logvar = None
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
                qpos.device
            )
            latent_input = self.latent_out_proj(latent_sample)

        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            for cam_id, cam_name in enumerate(self.camera_names):
                features, pos = self.backbones[0](image[:, cam_id])  # HARDCODED
                features = features[0]  # take the last layer feature
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            # need timestep?
            pos = torch.cat(all_cam_pos, axis=3)
            # hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight, actions, jpegs, denoise_steps,self.predict_frame, self.jpeg_token_num)[0]
            hs_action, hs_jpeg = self.transformer(
                src,
                None,
                self.query_embed_action.weight,
                self.query_embed_jpeg.weight,
                pos,
                latent_input,
                proprio_input,
                self.additional_pos_embed.weight,
                actions,
                denoise_steps,
            )
        else:
            # NOTE(review): this no-backbone path never assigns hs_action/hs_jpeg,
            # so the code below raises NameError when backbones is None; it looks
            # unfinished (transformer_input is built but never used) — confirm.
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1)  # seq length = 2
        # print(hs_action.shape,hs_jpeg.shape)
        a_hat = self.action_head(hs_action)
        is_pad_action_hat = self.is_pad_head(hs_action)
        is_pad_jpeg_hat = self.is_pad_head(hs_jpeg)
        # print(a_hat.shape)
        # print(is_pad_action_hat.shape,is_pad_jpeg_hat.shape)
        is_pad_hat = torch.cat([is_pad_action_hat, is_pad_jpeg_hat], axis=1)
        jpeg_token_hat = self.jpeg_head(hs_jpeg)

        return a_hat, jpeg_token_hat, is_pad_hat, [mu, logvar]
+ """ + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.transformer = transformer + self.encoder = encoder + hidden_dim = transformer.d_model + self.action_head = nn.Linear(hidden_dim, state_dim) + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.query_embed = nn.Embedding(num_queries, hidden_dim) + if backbones is not None: + self.input_proj = nn.Conv2d( + backbones[0].num_channels, hidden_dim, kernel_size=1 + ) + self.backbones = nn.ModuleList(backbones) + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + else: + # input_dim = 14 + 7 # robot_state + env_state + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + self.input_proj_env_state = nn.Linear(7, hidden_dim) + self.pos = torch.nn.Embedding(2, hidden_dim) + self.backbones = None + + # encoder extra parameters + self.latent_dim = 32 # final size of latent z # TODO tune + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_action_proj = nn.Linear( + 14, hidden_dim + ) # project action to embedding + self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding + self.latent_proj = nn.Linear( + hidden_dim, self.latent_dim * 2 + ) # project hidden state to latent std, var + self.register_buffer( + "pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim) + ) # [CLS], qpos, a_seq + + # decoder extra parameters + self.latent_out_proj = nn.Linear( + self.latent_dim, hidden_dim + ) # project latent sample to embedding + self.additional_pos_embed = nn.Embedding( + 2, hidden_dim + ) # learned position embedding for proprio and latent + + # settings for next frame prediction + self.patch_size = 16 + self.img_h, self.img_w = 224, 224 + self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) + # self.n_patch = (self.image_size//self.patch_size)**2 + self.k = predict_frame # number of next frames + self.n_patch = ( + (self.img_h // self.patch_size) * (self.img_w // self.patch_size) * 
    def forward(
        self,
        qpos,
        image,
        env_state,
        actions=None,
        noisy_actions=None,
        is_pad=None,
        denoise_steps=None,
    ):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
            (during training the extra entries along dim 1 carry the next frames
             used as reconstruction targets; only image[:, :1] feeds the backbone)
        env_state: None
        actions: batch, seq, action_dim
        noisy_actions: noised action sequence consumed by the diffusion decoder
        denoise_steps: diffusion timestep(s) passed to the denoising transformer
        Returns (a_hat, is_pad_hat, [mu, logvar], [obs_preds, obs_targets]).
        """
        is_training = actions is not None  # train or val
        # print('is_training',is_training)
        bs, _ = qpos.shape
        # ### Obtain latent z from action sequence
        # print('detr image shape',image.shape)
        if is_training and not self.disable_vae_latent:
            # project action sequence to embedding dim, and concat with a CLS token
            action_embed = self.encoder_action_proj(actions)  # (bs, seq, hidden_dim)
            qpos_embed = self.encoder_joint_proj(qpos)  # (bs, hidden_dim)
            qpos_embed = torch.unsqueeze(qpos_embed, axis=1)  # (bs, 1, hidden_dim)
            cls_embed = self.cls_embed.weight  # (1, hidden_dim)
            cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
                bs, 1, 1
            )  # (bs, 1, hidden_dim)
            encoder_input = torch.cat(
                [cls_embed, qpos_embed, action_embed], axis=1
            )  # (bs, seq+1, hidden_dim)
            encoder_input = encoder_input.permute(1, 0, 2)  # (seq+1, bs, hidden_dim)
            # do not mask cls token
            cls_joint_is_pad = torch.full((bs, 2), False).to(
                qpos.device
            )  # False: not a padding
            is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1)  # (bs, seq+1)
            # obtain position embedding
            pos_embed = self.pos_table.clone().detach()
            pos_embed = pos_embed.permute(1, 0, 2)  # (seq+1, 1, hidden_dim)
            # query model
            encoder_output = self.encoder(
                encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
            )
            encoder_output = encoder_output[0]  # take cls output only
            latent_info = self.latent_proj(encoder_output)
            mu = latent_info[:, : self.latent_dim]
            logvar = latent_info[:, self.latent_dim :]
            latent_sample = reparametrize(mu, logvar)
            latent_input = self.latent_out_proj(latent_sample)
        else:
            # Inference (or VAE disabled): use a deterministic zero latent.
            mu = logvar = None
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
                qpos.device
            )
            latent_input = self.latent_out_proj(latent_sample)

        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            if is_training:
                next_frame_images = image[:, 1:]  # should resize?
                image = image[:, :1]
            for cam_id, cam_name in enumerate(self.camera_names):
                features, pos = self.backbones[0](image[:, cam_id])  # HARDCODED?
                features = features[0]  # take the last layer feature
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            pos = torch.cat(all_cam_pos, axis=3)
            # query_embed = torch.cat([self.query_embed.weight, self.patch_embed.weight], axis=0)
            # should change
            # print('src',src.shape)
            hs_action, hs_patch = self.transformer(
                src,
                None,
                self.query_embed.weight,
                self.patch_embed.weight,
                pos,
                latent_input,
                proprio_input,
                self.additional_pos_embed.weight,
                noisy_actions,
                denoise_steps,
            )
            # hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
        else:
            # NOTE(review): this no-backbone path never assigns hs_action/hs_patch,
            # so the code below raises NameError when backbones is None; it looks
            # unfinished — confirm before relying on it.
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1)  # seq length = 2
            hs = self.transformer(
                transformer_input, None, self.query_embed.weight, self.pos.weight
            )[0]

        a_hat = self.action_head(hs_action)
        is_pad_hat = self.is_pad_head(hs_action)

        # print('hs',hs_action.shape)
        # print('a_hat',a_hat.shape)
        # next frame prediction
        mask_token = self.mask_token
        mask_tokens = mask_token.repeat(bs, self.n_patch, 1)
        mask_tokens = mask_tokens + self.decoder_pos_embed.repeat(bs, 1, 1)

        obs_pred = self.decoder_embed(hs_patch)
        obs_pred_ = torch.cat([obs_pred, mask_tokens], dim=1)
        # print('obs_pred_',obs_pred_.shape)
        for blk in self.decoder_blocks:
            obs_pred_ = blk(obs_pred_)
        obs_pred_ = self.decoder_norm(obs_pred_)
        obs_preds = self.decoder_pred(obs_pred_)
        # print('obs_preds',obs_preds.shape)
        # print(self.n_patch)
        # drop the first n_patch positions; presumably this keeps the predictions
        # at the appended mask-token slots — correct only if hs_patch also has
        # n_patch tokens, TODO confirm against the transformer's output length
        obs_preds = obs_preds[:, self.n_patch :]
        # print('obs_preds',obs_preds.shape)
        if is_training:
            # next_frame_images = image[:,1:]
            # print('next_frame_images',next_frame_images.shape)
            next_frame_images = nn.functional.interpolate(
                next_frame_images.reshape(bs, -1, *next_frame_images.shape[-2:]),
                size=(self.img_h, self.img_w),
            )
            # print('next_frame_images',next_frame_images.shape)
            p = self.patch_size
            h_p = self.img_h // p
            w_p = self.img_w // p
            # patchify ground-truth frames: (bs, k, 3, H, W) -> (bs, k*h_p*w_p, p*p*3)
            obs_targets = next_frame_images.reshape(
                shape=(bs, self.k, 3, h_p, p, w_p, p)
            )
            obs_targets = obs_targets.permute(0, 1, 3, 5, 4, 6, 2)
            obs_targets = obs_targets.reshape(
                shape=(bs, h_p * w_p * self.k, (p**2) * 3)
            )
        else:
            obs_targets = torch.zeros_like(obs_preds)

        return a_hat, is_pad_hat, [mu, logvar], [obs_preds, obs_targets]
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
    """Build a plain ReLU MLP as an ``nn.Sequential``.

    Parameters:
        input_dim: size of the input features.
        hidden_dim: width of every hidden layer.
        output_dim: size of the output features.
        hidden_depth: number of hidden layers; 0 yields a single Linear.
    Returns:
        nn.Sequential implementing the requested MLP.
    """
    # Depth 0 degenerates to a single linear map, no activation.
    if hidden_depth == 0:
        return nn.Sequential(nn.Linear(input_dim, output_dim))

    layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
    for _ in range(hidden_depth - 1):
        layers.append(nn.Linear(hidden_dim, hidden_dim))
        layers.append(nn.ReLU(inplace=True))
    layers.append(nn.Linear(hidden_dim, output_dim))
    return nn.Sequential(*layers)
def build_diffusion_tactile(args):
    """Assemble the tactile diffusion variant of the DETR-VAE policy.

    One image backbone is built per camera, the denoising transformer
    consumes the noisy action input with positional and timestep embeddings,
    and a dedicated tactile encoder is attached.
    """
    state_dim = 14  # TODO hardcode

    # One backbone per camera, all constructed from the same args.
    backbones = [build_backbone(args) for _ in args.camera_names]

    # decoder input noisy input & PE & time_embedding
    denoise_transformer = build_transformer_denoise_tactile(args)
    tactile_encoder = build_tactile_encoder(args)

    model = DETRVAE_Denoise_Tactile(
        backbones,
        denoise_transformer,
        tactile_encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))

    return model
def build_diffusion_tp_with_dual_visual_token(args):  # token prediction
    """Assemble the token-prediction DETR-VAE variant with dual visual
    token streams.

    One backbone per camera feeds a diffusion-style transformer whose
    decoder takes the noisy input plus positional and timestep embeddings.
    """
    state_dim = 14  # TODO hardcode

    # One backbone per camera.
    backbones = [build_backbone(args) for _ in range(len(args.camera_names))]

    # Decoder input: noisy actions + PE + time embedding.
    transformer = build_transformer_diffusion_prediction_with_dual_visual_token(args)
    encoder = build_encoder(args)

    model = DETRVAE_Denoise_Token_Prediction_Dual_Visual_Token(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
        history_step=args.history_step,
        predict_frame=args.predict_frame,
        image_downsample_rate=args.image_downsample_rate,
        temporal_downsample_rate=args.temporal_downsample_rate,
        disable_vae_latent=args.disable_vae_latent,
        disable_resnet=args.disable_resnet,
        patch_size=args.patch_size,
        token_pe_type=args.token_pe_type,
        image_width=args.image_width,
        image_height=args.image_height,
        tokenizer_model_spatial_rate=args.tokenizer_model_spatial_rate,
        tokenizer_model_temporal_rate=args.tokenizer_model_temporal_rate,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))

    return model
def build_dino(args):
    """Assemble the DINO-feature variant of the DETR-VAE policy.

    Uses a single shared backbone, the standard transformer, and the
    CVAE style encoder; the action/state dimension is hard-coded to 14.
    """
    state_dim = 14  # TODO hardcode

    # Single backbone shared across cameras.
    backbones = [build_backbone(args)]

    transformer = build_transformer(args)
    encoder = build_encoder(args)

    model = DETRVAEDino(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))

    return model
def build_jpeg_diffusion_seperate(args):
    """Assemble the DETR-VAE variant with separate action / JPEG-token
    query sets and a diffusion transformer
    (see ``DETRVAEjpeg_diffusion_seperate``)."""
    state_dim = 14  # TODO hardcode

    print("predict_frame", args.predict_frame)
    print("jpeg_token_num", args.jpeg_token_num)
    print("jpeg_dim", args.jpeg_dim)

    # Single backbone shared across cameras.
    backbones = [build_backbone(args)]

    transformer = build_transformer_diffusion_seperate(args)
    encoder = build_encoder(args)

    model = DETRVAEjpeg_diffusion_seperate(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
        disable_vae_latent=args.disable_vae_latent,
        predict_frame=args.predict_frame,
        jpeg_token_num=args.jpeg_token_num,
        jpeg_dim=args.jpeg_dim,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))

    return model
def build_cnnmlp(args):
    """Assemble the CNN+MLP baseline: one backbone per camera whose
    flattened features, concatenated with qpos, feed an MLP action head."""
    state_dim = 14  # TODO hardcode

    # One backbone per camera.
    backbones = [build_backbone(args) for _ in args.camera_names]

    model = CNNMLP(
        backbones,
        state_dim=state_dim,
        camera_names=args.camera_names,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))

    return model
+""" +import torch +from torch import nn +from torch.autograd import Variable +from .backbone import build_backbone +from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer +from .vision_transformer import Block, get_2d_sincos_pos_embed, get_2d_sincos_pos_embed_v2 +from .mr_mg.policy.model.vision_transformer import vit_base + +import numpy as np + +import IPython +e = IPython.embed + + +def reparametrize(mu, logvar): + std = logvar.div(2).exp() + eps = Variable(std.data.new(std.size()).normal_()) + return mu + std * eps + + +def get_sinusoid_encoding_table(n_position, d_hid): + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + +class DETRVAE(nn.Module): + """ This is the DETR module that performs object detection """ + def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names): + """ Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 
+ """ + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.transformer = transformer + self.encoder = encoder + hidden_dim = transformer.d_model + self.action_head = nn.Linear(hidden_dim, state_dim) + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.query_embed = nn.Embedding(num_queries, hidden_dim) + if backbones is not None: + self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1) + self.backbones = nn.ModuleList(backbones) + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + else: + # input_dim = 14 + 7 # robot_state + env_state + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + self.input_proj_env_state = nn.Linear(7, hidden_dim) + self.pos = torch.nn.Embedding(2, hidden_dim) + self.backbones = None + + # encoder extra parameters + self.latent_dim = 32 # final size of latent z # TODO tune + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding + self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding + self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var + self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq + + # decoder extra parameters + self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding + self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent + + # settings for next frame prediction + self.patch_size = 16 + # self.image_size = 224 + # self.img_h, self.img_w = 128, 160 + self.img_h, self.img_w = 224, 224 + self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) + # self.n_patch = (self.image_size//self.patch_size)**2 + self.k = 1 # number of next frames + self.n_patch = 
(self.img_h//self.patch_size)*(self.img_w//self.patch_size)*(self.k) + self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.n_patch, hidden_dim), requires_grad=False) # (1, n_patch, h) + self.patch_embed = nn.Embedding(self.n_patch, hidden_dim) + self.decoder_embed = nn.Linear(hidden_dim, hidden_dim, bias=True) + + decoder_depth = 2 # hardcode + self.decoder_blocks = nn.ModuleList([ + Block(hidden_dim, 16, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm) + for i in range(decoder_depth)]) + + self.decoder_norm = nn.LayerNorm(hidden_dim) + self.decoder_pred = nn.Linear(hidden_dim, self.patch_size**2 * 3, bias=True) # decoder to patch + + # decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], (self.image_size//self.patch_size), cls_token=False) + decoder_pos_embed = get_2d_sincos_pos_embed_v2(self.decoder_pos_embed.shape[-1], (self.img_h//self.patch_size, self.img_w//self.patch_size)) + self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0).repeat(1,self.k,1)) + + # fwd_params = sum(p.numel() for p in self.decoder_blocks.parameters() if p.requires_grad) + + + def forward(self, qpos, image, env_state, actions=None, is_pad=None): + """ + qpos: batch, qpos_dim + image: batch, num_cam, channel, height, width + env_state: None + actions: batch, seq, action_dim + """ + is_training = actions is not None # train or val + bs, _ = qpos.shape + ### Obtain latent z from action sequence + if is_training: + # project action sequence to embedding dim, and concat with a CLS token + action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) + qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim) + qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim) + cls_embed = self.cls_embed.weight # (1, hidden_dim) + cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim) + encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, 
seq+1, hidden_dim) + encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim) + # do not mask cls token + cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding + is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1) + # obtain position embedding + pos_embed = self.pos_table.clone().detach() + pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) + # query model + encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad) + encoder_output = encoder_output[0] # take cls output only + latent_info = self.latent_proj(encoder_output) + mu = latent_info[:, :self.latent_dim] + logvar = latent_info[:, self.latent_dim:] + latent_sample = reparametrize(mu, logvar) + latent_input = self.latent_out_proj(latent_sample) + else: + mu = logvar = None + latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device) + latent_input = self.latent_out_proj(latent_sample) + + if self.backbones is not None: + # Image observation features and position embeddings + all_cam_features = [] + all_cam_pos = [] + if is_training: + next_frame_images = image[:,1:] + image = image[:,:1] + for cam_id, cam_name in enumerate(self.camera_names): + features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED? 
+ features = features[0] # take the last layer feature + pos = pos[0] + all_cam_features.append(self.input_proj(features)) + all_cam_pos.append(pos) + # proprioception features + proprio_input = self.input_proj_robot_state(qpos) + # fold camera dimension into width dimension + src = torch.cat(all_cam_features, axis=3) + pos = torch.cat(all_cam_pos, axis=3) + query_embed = torch.cat([self.query_embed.weight, self.patch_embed.weight], axis=0) + hs = self.transformer(src, None, query_embed, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0] + # hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0] + else: + qpos = self.input_proj_robot_state(qpos) + env_state = self.input_proj_env_state(env_state) + transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2 + hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0] + a_hat = self.action_head(hs[:,:self.num_queries]) + is_pad_hat = self.is_pad_head(hs[:,:self.num_queries]) + + # next frame prediction + mask_token = self.mask_token + mask_tokens = mask_token.repeat(bs, self.n_patch, 1) + mask_tokens = mask_tokens + self.decoder_pos_embed.repeat(bs, 1, 1) + + obs_pred = self.decoder_embed(hs[:,self.num_queries:]) + obs_pred_ = torch.cat([obs_pred, mask_tokens], dim=1) + for blk in self.decoder_blocks: + obs_pred_ = blk(obs_pred_) + obs_pred_ = self.decoder_norm(obs_pred_) + obs_preds = self.decoder_pred(obs_pred_) + obs_preds = obs_preds[:,self.n_patch:] + + if is_training: + # next_frame_images = image[:,1:] + next_frame_images = nn.functional.interpolate(next_frame_images.reshape(bs, self.k*3, 224, 224), size=(self.img_h, self.img_w)) + p = self.patch_size + h_p = self.img_h // p + w_p = self.img_w // p + obs_targets = next_frame_images.reshape(shape=(bs, self.k, 3, h_p, p, w_p, p)) + obs_targets = obs_targets.permute(0,1,3,5,4,6,2) + obs_targets = 
obs_targets.reshape(shape=(bs, h_p*w_p*self.k, (p**2)*3)) + else: + obs_targets = torch.zeros_like(obs_preds) + + return a_hat, is_pad_hat, [mu, logvar], [obs_preds, obs_targets] + + +class DETRVAE_MAE(nn.Module): + """ This is the DETR module that performs object detection """ + def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names): + """ Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.num_queries = num_queries + self.camera_names = camera_names + self.transformer = transformer + self.encoder = encoder + hidden_dim = transformer.d_model + self.action_head = nn.Linear(hidden_dim, state_dim) + self.is_pad_head = nn.Linear(hidden_dim, 1) + self.query_embed = nn.Embedding(num_queries, hidden_dim) + + # self.model_mae = vits.__dict__['vit_base'](patch_size=16, num_classes=0) + self.model_mae = vit_base(patch_size=16, num_classes=0) + mae_ckpt = 'checkpoints/pretrained/mae_pretrain_vit_base.pth' + checkpoint = torch.load(mae_ckpt, map_location='cpu') + self.model_mae.load_state_dict(checkpoint['model'], strict=True) + print('Load MAE pretrained model') + # for name, p in self.model_mae.named_parameters(): + # p.requires_grad = False + + if backbones is not None: + self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1) + self.backbones = nn.ModuleList(backbones) + self.input_proj_robot_state = nn.Linear(14, hidden_dim) + else: + # input_dim = 14 + 7 # robot_state + env_state + self.input_proj_robot_state = nn.Linear(14, 
hidden_dim) + self.input_proj_env_state = nn.Linear(7, hidden_dim) + self.pos = torch.nn.Embedding(2, hidden_dim) + self.backbones = None + + # encoder extra parameters + self.latent_dim = 32 # final size of latent z # TODO tune + self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding + self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding + self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding + self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var + self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq + + # decoder extra parameters + self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding + self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent + + # settings for next frame prediction + self.patch_size = 16 + self.img_h, self.img_w = 224, 224 + self.n_patch = (self.img_h//self.patch_size)*(self.img_w//self.patch_size) + self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) + self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.n_patch, hidden_dim), requires_grad=False) # (1, n_patch, h) + self.patch_embed = nn.Embedding(self.n_patch, hidden_dim) + self.decoder_embed = nn.Linear(hidden_dim, hidden_dim, bias=True) + + decoder_depth = 2 # hardcode + self.decoder_blocks = nn.ModuleList([ + Block(hidden_dim, 16, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm) + for i in range(decoder_depth)]) + + self.decoder_norm = nn.LayerNorm(hidden_dim) + self.decoder_pred = nn.Linear(hidden_dim, self.patch_size**2 * 3, bias=True) # decoder to patch + + decoder_pos_embed = get_2d_sincos_pos_embed_v2(self.decoder_pos_embed.shape[-1], (self.img_h//self.patch_size, self.img_w//self.patch_size)) + self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) + 
+ + def forward(self, qpos, image, env_state, actions=None, is_pad=None): + """ + qpos: batch, qpos_dim + image: batch, num_cam, channel, height, width + env_state: None + actions: batch, seq, action_dim + """ + is_training = actions is not None # train or val + bs, _ = qpos.shape + ### Obtain latent z from action sequence + if is_training: + # project action sequence to embedding dim, and concat with a CLS token + action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim) + qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim) + qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim) + cls_embed = self.cls_embed.weight # (1, hidden_dim) + cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim) + encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim) + encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim) + # do not mask cls token + cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding + is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1) + # obtain position embedding + pos_embed = self.pos_table.clone().detach() + pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim) + # query model + encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad) + encoder_output = encoder_output[0] # take cls output only + latent_info = self.latent_proj(encoder_output) + mu = latent_info[:, :self.latent_dim] + logvar = latent_info[:, self.latent_dim:] + latent_sample = reparametrize(mu, logvar) + latent_input = self.latent_out_proj(latent_sample) + else: + mu = logvar = None + latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device) + latent_input = self.latent_out_proj(latent_sample) + + if self.backbones is not None: + # Image observation features and position embeddings + all_cam_features = [] + all_cam_pos = [] + if is_training: + 
next_frame_images = image[:,1:] + image = image[:,:1] + for cam_id, cam_name in enumerate(self.camera_names): + # features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED + # features = features[0] # take the last layer feature + # pos = pos[0] + # all_cam_features.append(self.input_proj(features)) + # all_cam_pos.append(pos) + + obs_embedings, patch_embedings, pos_mae = self.model_mae(image[:,cam_id]) + + # proprioception features + proprio_input = self.input_proj_robot_state(qpos) + # fold camera dimension into width dimension + # src = torch.cat(all_cam_features, axis=3) + # pos = torch.cat(all_cam_pos, axis=3) + query_embed = torch.cat([self.query_embed.weight, self.patch_embed.weight], axis=0) + hs = self.transformer(patch_embedings, None, query_embed, pos_mae[0,1:], latent_input, proprio_input, self.additional_pos_embed.weight)[0] + # hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0] + else: + qpos = self.input_proj_robot_state(qpos) + env_state = self.input_proj_env_state(env_state) + transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2 + hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0] + a_hat = self.action_head(hs[:,:self.num_queries]) + is_pad_hat = self.is_pad_head(hs[:,:self.num_queries]) + + # next frame prediction + mask_token = self.mask_token + mask_tokens = mask_token.repeat(bs, self.n_patch, 1) + mask_tokens = mask_tokens + self.decoder_pos_embed.repeat(bs, 1, 1) + + obs_pred = self.decoder_embed(hs[:,self.num_queries:]) + obs_pred_ = torch.cat([obs_pred, mask_tokens], dim=1) + for blk in self.decoder_blocks: + obs_pred_ = blk(obs_pred_) + obs_pred_ = self.decoder_norm(obs_pred_) + obs_preds = self.decoder_pred(obs_pred_) + obs_preds = obs_preds[:,self.n_patch:] + + if is_training: + # next_frame_images = image[:,1:] + # next_frame_images = nn.functional.interpolate(next_frame_images[:,0], 
size=(self.img_h, self.img_w)) + next_frame_images = next_frame_images[:,0] + p = self.patch_size + h_p = self.img_h // p + w_p = self.img_w // p + obs_targets = next_frame_images.reshape(shape=(bs, 3, h_p, p, w_p, p)) + obs_targets = obs_targets.permute(0,2,4,3,5,1) + obs_targets = obs_targets.reshape(shape=(bs, h_p*w_p, (p**2)*3)) + else: + obs_targets = torch.zeros_like(obs_preds) + + + return a_hat, is_pad_hat, [mu, logvar], [obs_preds, obs_targets] + + +class CNNMLP(nn.Module): + def __init__(self, backbones, state_dim, camera_names): + """ Initializes the model. + Parameters: + backbones: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + state_dim: robot state dimension of the environment + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 
+ """ + super().__init__() + self.camera_names = camera_names + self.action_head = nn.Linear(1000, state_dim) # TODO add more + if backbones is not None: + self.backbones = nn.ModuleList(backbones) + backbone_down_projs = [] + for backbone in backbones: + down_proj = nn.Sequential( + nn.Conv2d(backbone.num_channels, 128, kernel_size=5), + nn.Conv2d(128, 64, kernel_size=5), + nn.Conv2d(64, 32, kernel_size=5) + ) + backbone_down_projs.append(down_proj) + self.backbone_down_projs = nn.ModuleList(backbone_down_projs) + + mlp_in_dim = 768 * len(backbones) + 14 + self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2) + else: + raise NotImplementedError + + def forward(self, qpos, image, env_state, actions=None): + """ + qpos: batch, qpos_dim + image: batch, num_cam, channel, height, width + env_state: None + actions: batch, seq, action_dim + """ + is_training = actions is not None # train or val + bs, _ = qpos.shape + # Image observation features and position embeddings + all_cam_features = [] + for cam_id, cam_name in enumerate(self.camera_names): + features, pos = self.backbones[cam_id](image[:, cam_id]) + features = features[0] # take the last layer feature + pos = pos[0] # not used + all_cam_features.append(self.backbone_down_projs[cam_id](features)) + # flatten everything + flattened_features = [] + for cam_feature in all_cam_features: + flattened_features.append(cam_feature.reshape([bs, -1])) + flattened_features = torch.cat(flattened_features, axis=1) # 768 each + features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14 + a_hat = self.mlp(features) + return a_hat + + +def mlp(input_dim, hidden_dim, output_dim, hidden_depth): + if hidden_depth == 0: + mods = [nn.Linear(input_dim, output_dim)] + else: + mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)] + for i in range(hidden_depth - 1): + mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)] + mods.append(nn.Linear(hidden_dim, output_dim)) + trunk 
= nn.Sequential(*mods) + return trunk + + +def build_encoder(args): + d_model = args.hidden_dim # 256 + dropout = args.dropout # 0.1 + nhead = args.nheads # 8 + dim_feedforward = args.dim_feedforward # 2048 + num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder + normalize_before = args.pre_norm # False + activation = "relu" + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + return encoder + + +def build(args): + state_dim = 14 # TODO hardcode + + # From state + # backbone = None # from state for now, no need for conv nets + # From image + backbones = [] + backbone = build_backbone(args) + backbones.append(backbone) + + transformer = build_transformer(args) + + encoder = build_encoder(args) + + if not args.mae: + model = DETRVAE( + backbones, + transformer, + encoder, + state_dim=state_dim, + num_queries=args.num_queries, + camera_names=args.camera_names, + ) + else: + model = DETRVAE_MAE( + backbones, + transformer, + encoder, + state_dim=state_dim, + num_queries=args.num_queries, + camera_names=args.camera_names, + ) + + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print("number of parameters: %.2fM" % (n_parameters/1e6,)) + + return model + +def build_cnnmlp(args): + state_dim = 14 # TODO hardcode + + # From state + # backbone = None # from state for now, no need for conv nets + # From image + backbones = [] + for _ in args.camera_names: + backbone = build_backbone(args) + backbones.append(backbone) + + model = CNNMLP( + backbones, + state_dim=state_dim, + camera_names=args.camera_names, + ) + + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print("number of parameters: %.2fM" % (n_parameters/1e6,)) + + return model + diff --git 
a/ACT_DP_multitask/detr/models/mask_former/__init__.py b/ACT_DP_multitask/detr/models/mask_former/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0ef4d245cfcf49cf13a5195b60d57d92e0af016e --- /dev/null +++ b/ACT_DP_multitask/detr/models/mask_former/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from . import data # register all new datasets +from . import modeling + +# config +from .config import add_mask_former_config + +# dataset loading +from .data.dataset_mappers.detr_panoptic_dataset_mapper import DETRPanopticDatasetMapper +from .data.dataset_mappers.mask_former_panoptic_dataset_mapper import ( + MaskFormerPanopticDatasetMapper, +) +from .data.dataset_mappers.mask_former_semantic_dataset_mapper import ( + MaskFormerSemanticDatasetMapper, +) + +# models +from .mask_former_model import MaskFormer +from .test_time_augmentation import SemanticSegmentorWithTTA diff --git a/ACT_DP_multitask/detr/models/mask_former/__pycache__/__init__.cpython-38.pyc b/ACT_DP_multitask/detr/models/mask_former/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc4c5140e564d8b83ff0b8cd8605afbdf30fa5ca Binary files /dev/null and b/ACT_DP_multitask/detr/models/mask_former/__pycache__/__init__.cpython-38.pyc differ diff --git a/ACT_DP_multitask/detr/models/mask_former/config.py b/ACT_DP_multitask/detr/models/mask_former/config.py new file mode 100644 index 0000000000000000000000000000000000000000..90ba988efd4bad35106b44f2c17de4bf71382a04 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mask_former/config.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. +from detectron2.config import CfgNode as CN + + +def add_mask_former_config(cfg): + """ + Add config for MASK_FORMER. 
+ """ + # data config + # select the dataset mapper + cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic" + # Color augmentation + cfg.INPUT.COLOR_AUG_SSD = False + # We retry random cropping until no single category in semantic segmentation GT occupies more + # than `SINGLE_CATEGORY_MAX_AREA` part of the crop. + cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0 + # Pad image and segmentation GT in dataset mapper. + cfg.INPUT.SIZE_DIVISIBILITY = -1 + + # solver config + # weight decay on embedding + cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0 + # optimizer + cfg.SOLVER.OPTIMIZER = "ADAMW" + cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1 + + # mask_former model config + cfg.MODEL.MASK_FORMER = CN() + + # loss + cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True + cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1 + cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0 + cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0 + + # transformer config + cfg.MODEL.MASK_FORMER.NHEADS = 8 + cfg.MODEL.MASK_FORMER.DROPOUT = 0.1 + cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048 + cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0 + cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6 + cfg.MODEL.MASK_FORMER.PRE_NORM = False + + cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256 + cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100 + + cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5" + cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False + + # mask_former inference config + cfg.MODEL.MASK_FORMER.TEST = CN() + cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False + cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0 + cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0 + cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False + + # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. 
ResNet) + # you can use this config to override + cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32 + + # pixel decoder config + cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256 + # adding transformer in pixel decoder + cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0 + # pixel decoder + cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder" + + # swin transformer backbone + cfg.MODEL.SWIN = CN() + cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224 + cfg.MODEL.SWIN.PATCH_SIZE = 4 + cfg.MODEL.SWIN.EMBED_DIM = 96 + cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] + cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] + cfg.MODEL.SWIN.WINDOW_SIZE = 7 + cfg.MODEL.SWIN.MLP_RATIO = 4.0 + cfg.MODEL.SWIN.QKV_BIAS = True + cfg.MODEL.SWIN.QK_SCALE = None + cfg.MODEL.SWIN.DROP_RATE = 0.0 + cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0 + cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3 + cfg.MODEL.SWIN.APE = False + cfg.MODEL.SWIN.PATCH_NORM = True + cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"] diff --git a/ACT_DP_multitask/detr/models/mask_former/mask_former_model.py b/ACT_DP_multitask/detr/models/mask_former/mask_former_model.py new file mode 100644 index 0000000000000000000000000000000000000000..02a46b97ef9066eb7cf433821c0366cffc9bb57e --- /dev/null +++ b/ACT_DP_multitask/detr/models/mask_former/mask_former_model.py @@ -0,0 +1,304 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+from typing import Tuple + +import torch +from torch import nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.data import MetadataCatalog +from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head +from detectron2.modeling.backbone import Backbone +from detectron2.modeling.postprocessing import sem_seg_postprocess +from detectron2.structures import ImageList + +from .modeling.criterion import SetCriterion +from .modeling.matcher import HungarianMatcher + + +@META_ARCH_REGISTRY.register() +class MaskFormer(nn.Module): + """ + Main class for mask classification semantic segmentation architectures. + """ + + @configurable + def __init__( + self, + *, + backbone: Backbone, + sem_seg_head: nn.Module, + criterion: nn.Module, + num_queries: int, + panoptic_on: bool, + object_mask_threshold: float, + overlap_threshold: float, + metadata, + size_divisibility: int, + sem_seg_postprocess_before_inference: bool, + pixel_mean: Tuple[float], + pixel_std: Tuple[float], + ): + """ + Args: + backbone: a backbone module, must follow detectron2's backbone interface + sem_seg_head: a module that predicts semantic segmentation from backbone features + criterion: a module that defines the loss + num_queries: int, number of queries + panoptic_on: bool, whether to output panoptic segmentation prediction + object_mask_threshold: float, threshold to filter query based on classification score + for panoptic segmentation inference + overlap_threshold: overlap threshold used in general inference for panoptic segmentation + metadata: dataset meta, get `thing` and `stuff` category names for panoptic + segmentation inference + size_divisibility: Some backbones require the input height and width to be divisible by a + specific integer. We can use this to override such requirement. 
+ sem_seg_postprocess_before_inference: whether to resize the prediction back + to original input size before semantic segmentation inference or after. + For high-resolution dataset like Mapillary, resizing predictions before + inference will cause OOM error. + pixel_mean, pixel_std: list or tuple with #channels element, representing + the per-channel mean and std to be used to normalize the input image + """ + super().__init__() + self.backbone = backbone + self.sem_seg_head = sem_seg_head + self.criterion = criterion + self.num_queries = num_queries + self.overlap_threshold = overlap_threshold + self.panoptic_on = panoptic_on + self.object_mask_threshold = object_mask_threshold + self.metadata = metadata + if size_divisibility < 0: + # use backbone size_divisibility if not set + size_divisibility = self.backbone.size_divisibility + self.size_divisibility = size_divisibility + self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @classmethod + def from_config(cls, cfg): + backbone = build_backbone(cfg) + sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape()) + + # Loss parameters: + deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION + no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT + dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT + mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT + + # building criterion + matcher = HungarianMatcher( + cost_class=1, + cost_mask=mask_weight, + cost_dice=dice_weight, + ) + + weight_dict = {"loss_ce": 1, "loss_mask": mask_weight, "loss_dice": dice_weight} + if deep_supervision: + dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS + aux_weight_dict = {} + for i in range(dec_layers - 1): + aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + losses = 
["labels", "masks"] + + criterion = SetCriterion( + sem_seg_head.num_classes, + matcher=matcher, + weight_dict=weight_dict, + eos_coef=no_object_weight, + losses=losses, + ) + + return { + "backbone": backbone, + "sem_seg_head": sem_seg_head, + "criterion": criterion, + "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES, + "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON, + "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD, + "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD, + "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), + "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY, + "sem_seg_postprocess_before_inference": ( + cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE + or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON + ), + "pixel_mean": cfg.MODEL.PIXEL_MEAN, + "pixel_std": cfg.MODEL.PIXEL_STD, + } + + @property + def device(self): + return self.pixel_mean.device + + def forward(self, batched_inputs): + """ + Args: + batched_inputs: a list, batched outputs of :class:`DatasetMapper`. + Each item in the list contains the inputs for one image. + For now, each item in the list is a dict that contains: + * "image": Tensor, image in (C, H, W) format. + * "instances": per-region ground truth + * Other information that's included in the original dicts, such as: + "height", "width" (int): the output resolution of the model (may be different + from input resolution), used in inference. + Returns: + list[dict]: + each dict has the results for one image. The dict contains the following keys: + + * "sem_seg": + A Tensor that represents the + per-pixel segmentation prediced by the head. + The prediction has shape KxHxW that represents the logits of + each class for each pixel. + * "panoptic_seg": + A tuple that represent panoptic output + panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment. + segments_info (list[dict]): Describe each segment in `panoptic_seg`. 
+ Each dict contains keys "id", "category_id", "isthing". + """ + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors(images, self.size_divisibility) + + features = self.backbone(images.tensor) + outputs = self.sem_seg_head(features) + + if self.training: + # mask classification target + if "instances" in batched_inputs[0]: + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + targets = self.prepare_targets(gt_instances, images) + else: + targets = None + + # bipartite matching-based loss + losses = self.criterion(outputs, targets) + + for k in list(losses.keys()): + if k in self.criterion.weight_dict: + losses[k] *= self.criterion.weight_dict[k] + else: + # remove this loss if not specified in `weight_dict` + losses.pop(k) + + return losses + else: + mask_cls_results = outputs["pred_logits"] + mask_pred_results = outputs["pred_masks"] + # upsample masks + mask_pred_results = F.interpolate( + mask_pred_results, + size=(images.tensor.shape[-2], images.tensor.shape[-1]), + mode="bilinear", + align_corners=False, + ) + + processed_results = [] + for mask_cls_result, mask_pred_result, input_per_image, image_size in zip( + mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes + ): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + + if self.sem_seg_postprocess_before_inference: + mask_pred_result = sem_seg_postprocess( + mask_pred_result, image_size, height, width + ) + + # semantic segmentation inference + r = self.semantic_inference(mask_cls_result, mask_pred_result) + if not self.sem_seg_postprocess_before_inference: + r = sem_seg_postprocess(r, image_size, height, width) + processed_results.append({"sem_seg": r}) + + # panoptic segmentation inference + if self.panoptic_on: + panoptic_r = self.panoptic_inference(mask_cls_result, mask_pred_result) + 
processed_results[-1]["panoptic_seg"] = panoptic_r + + return processed_results + + def prepare_targets(self, targets, images): + h, w = images.tensor.shape[-2:] + new_targets = [] + for targets_per_image in targets: + # pad gt + gt_masks = targets_per_image.gt_masks + padded_masks = torch.zeros((gt_masks.shape[0], h, w), dtype=gt_masks.dtype, device=gt_masks.device) + padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks + new_targets.append( + { + "labels": targets_per_image.gt_classes, + "masks": padded_masks, + } + ) + return new_targets + + def semantic_inference(self, mask_cls, mask_pred): + mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred) + return semseg + + def panoptic_inference(self, mask_cls, mask_pred): + scores, labels = F.softmax(mask_cls, dim=-1).max(-1) + mask_pred = mask_pred.sigmoid() + + keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold) + cur_scores = scores[keep] + cur_classes = labels[keep] + cur_masks = mask_pred[keep] + cur_mask_cls = mask_cls[keep] + cur_mask_cls = cur_mask_cls[:, :-1] + + cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks + + h, w = cur_masks.shape[-2:] + panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device) + segments_info = [] + + current_segment_id = 0 + + if cur_masks.shape[0] == 0: + # We didn't detect any mask :( + return panoptic_seg, segments_info + else: + # take argmax + cur_mask_ids = cur_prob_masks.argmax(0) + stuff_memory_list = {} + for k in range(cur_classes.shape[0]): + pred_class = cur_classes[k].item() + isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values() + mask = cur_mask_ids == k + mask_area = mask.sum().item() + original_area = (cur_masks[k] >= 0.5).sum().item() + + if mask_area > 0 and original_area > 0: + if mask_area / original_area < self.overlap_threshold: + continue + + # merge stuff regions + if 
not isthing: + if int(pred_class) in stuff_memory_list.keys(): + panoptic_seg[mask] = stuff_memory_list[int(pred_class)] + continue + else: + stuff_memory_list[int(pred_class)] = current_segment_id + 1 + + current_segment_id += 1 + panoptic_seg[mask] = current_segment_id + + segments_info.append( + { + "id": current_segment_id, + "isthing": bool(isthing), + "category_id": int(pred_class), + } + ) + + return panoptic_seg, segments_info diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/__init__.py b/ACT_DP_multitask/detr/models/mask_former/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e7eb4843fdcb71329842de6750fcb721a4a4e1f --- /dev/null +++ b/ACT_DP_multitask/detr/models/mask_former/modeling/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .backbone.swin import D2SwinTransformer +from .heads.mask_former_head import MaskFormerHead +from .heads.per_pixel_baseline import PerPixelBaselineHead, PerPixelBaselinePlusHead +from .heads.pixel_decoder import BasePixelDecoder diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/backbone/__init__.py b/ACT_DP_multitask/detr/models/mask_former/modeling/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mask_former/modeling/backbone/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/backbone/swin.py b/ACT_DP_multitask/detr/models/mask_former/modeling/backbone/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..86a7762f025e7e702fc13d03d6a219a2e6c9acb8 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mask_former/modeling/backbone/swin.py @@ -0,0 +1,768 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu, Yutong Lin, Yixuan Wei +# -------------------------------------------------------- + +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec + + +class Mlp(nn.Module): + """Multilayer perceptron.""" + + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + 
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
# (continuation of WindowAttention.__init__/forward is consolidated into the
# class definition above)


class SwinTransformerBlock(nn.Module):
    """Swin Transformer block: (S)W-MSA + MLP, each with residual + LayerNorm.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA (0 for plain W-MSA).
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        qkv_bias (bool, optional): Learnable bias for q, k, v. Default: True.
        qk_scale (float | None, optional): Override default qk scale if set.
        drop (float, optional): Dropout rate. Default: 0.0.
        attn_drop (float, optional): Attention dropout rate. Default: 0.0.
        drop_path (float, optional): Stochastic depth rate. Default: 0.0.
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm.
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
        )

        # Spatial size is injected by the enclosing BasicLayer before forward.
        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """Apply one (shifted-)window attention block.

        Args:
            x: input features of shape (B, H*W, C); self.H/self.W must be set.
            mask_matrix: attention mask for the cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Pad the feature map so H and W are multiples of the window size.
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # Cyclic shift (SW-MSA) or identity (W-MSA).
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # Partition into windows and run attention within each.
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        attn_windows = self.attn(x_windows, mask=attn_mask)

        # Stitch windows back into a feature map.
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)

        # Undo the cyclic shift.
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        # Crop away the padding added above.
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # Residual connections around attention and MLP.
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
# (continuation of SwinTransformerBlock.__init__/forward is consolidated into
# the class definition above)


class PatchMerging(nn.Module):
    """Patch merging layer: 2x2 spatial downsample, channels 4*C -> 2*C.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm.
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """Merge 2x2 neighborhoods.

        Args:
            x: input features of shape (B, H*W, C).
            H, W: spatial resolution of the input feature.

        Returns:
            Tensor of shape (B, ceil(H/2)*ceil(W/2), 2*C).
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # Pad odd spatial dims so the 2x2 gather below is well defined.
        if (H % 2 == 1) or (W % 2 == 1):
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        # Gather the four pixels of each 2x2 block along channels.
        quads = [
            x[:, 0::2, 0::2, :],  # top-left
            x[:, 1::2, 0::2, :],  # bottom-left
            x[:, 0::2, 1::2, :],  # top-right
            x[:, 1::2, 1::2, :],  # bottom-right
        ]
        x = torch.cat(quads, -1).view(B, -1, 4 * C)  # B, H/2*W/2, 4*C

        x = self.norm(x)
        return self.reduction(x)
# (continuation of PatchMerging.forward is consolidated into the class
# definition above)


class BasicLayer(nn.Module):
    """One Swin stage: a stack of alternating W-MSA / SW-MSA blocks + optional
    patch-merging downsample.

    Args:
        dim (int): Number of feature channels.
        depth (int): Number of blocks in this stage.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): Learnable bias for q, k, v. Default: True.
        qk_scale (float | None, optional): Override default qk scale if set.
        drop (float, optional): Dropout rate. Default: 0.0.
        attn_drop (float, optional): Attention dropout rate. Default: 0.0.
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm.
        downsample (nn.Module | None, optional): Downsample layer at the end. Default: None.
        use_checkpoint (bool): Use activation checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads,
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
    ):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Even blocks use plain windows, odd blocks use shifted windows.
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        # Optional patch-merging downsample at the end of the stage.
        self.downsample = downsample(dim=dim, norm_layer=norm_layer) if downsample else None

    def forward(self, x, H, W):
        """Run all blocks of the stage.

        Args:
            x: input features of shape (B, H*W, C).
            H, W: spatial resolution of the input feature.

        Returns:
            (x_out, H, W, x_down, Wh, Ww): pre-downsample features with their
            resolution, and the (possibly downsampled) features for the next
            stage with theirs.
        """
        # Build the attention mask for SW-MSA on the padded grid.
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        w_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        region = 0
        for hs in h_slices:
            for ws in w_slices:
                img_mask[:, hs, ws, :] = region
                region += 1

        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        # -100 effectively zeroes cross-region attention after softmax.
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
            attn_mask == 0, float(0.0)
        )

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)

        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        return x, H, W, x, H, W
+ """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
# (continuation of PatchEmbed.__init__/forward is consolidated into the class
# definition above)


class SwinTransformer(nn.Module):
    """Swin Transformer backbone.

    A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer
    using Shifted Windows` - https://arxiv.org/pdf/2103.14030

    Args:
        pretrain_img_size (int): Input image size used to pretrain, used only
            for the absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention heads of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): Learnable bias for q, k, v. Default: True.
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): Add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): Add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Use activation checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        pretrain_img_size=224,
        patch_size=4,
        in_chans=3,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.2,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        frozen_stages=-1,
        use_checkpoint=False,
    ):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # Split image into non-overlapping patches.
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )

        # Optional absolute position embedding, sized for the pretrain resolution.
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [
                pretrain_img_size[0] // patch_size[0],
                pretrain_img_size[1] // patch_size[1],
            ]

            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
            )
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth: linearly increasing drop rate across all blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # Build the four stages; channel width doubles each stage.
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # A norm layer for each requested output stage.
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f"norm{i_layer}"
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        """Freeze parameters / switch to eval up to `self.frozen_stages`."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None (currently unused; pretrained loading is
                handled by the surrounding framework's checkpointer).
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        # FIX: the original defined `_init_weights` but never applied it,
        # making this method a silent no-op. Apply it so calling
        # init_weights() actually initializes Linear/LayerNorm weights.
        self.apply(_init_weights)

    def forward(self, x):
        """Return a dict of multi-scale features {"res2": ..., ..., "res5": ...}."""
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # Interpolate the position embedding to the current token grid.
            absolute_pos_embed = F.interpolate(
                self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
            )
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = {}
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f"norm{i}")
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs["res{}".format(i + 2)] = out

        return outs

    def train(self, mode=True):
        """Convert the model into training mode while keeping stages frozen."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()


@BACKBONE_REGISTRY.register()
class D2SwinTransformer(SwinTransformer, Backbone):
    """detectron2 Backbone wrapper around :class:`SwinTransformer`."""

    def __init__(self, cfg, input_shape):
        pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
        patch_size = cfg.MODEL.SWIN.PATCH_SIZE
        in_chans = 3
        embed_dim = cfg.MODEL.SWIN.EMBED_DIM
        depths = cfg.MODEL.SWIN.DEPTHS
        num_heads = cfg.MODEL.SWIN.NUM_HEADS
        window_size = cfg.MODEL.SWIN.WINDOW_SIZE
        mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
        qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
        qk_scale = cfg.MODEL.SWIN.QK_SCALE
        drop_rate = cfg.MODEL.SWIN.DROP_RATE
        attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
        drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
        norm_layer = nn.LayerNorm
        ape = cfg.MODEL.SWIN.APE
        patch_norm = cfg.MODEL.SWIN.PATCH_NORM

        super().__init__(
            pretrain_img_size,
            patch_size,
            in_chans,
            embed_dim,
            depths,
            num_heads,
            window_size,
            mlp_ratio,
            qkv_bias,
            qk_scale,
            drop_rate,
            attn_drop_rate,
            drop_path_rate,
            norm_layer,
            ape,
            patch_norm,
        )

        self._out_features = cfg.MODEL.SWIN.OUT_FEATURES

        self._out_feature_strides = {
            "res2": 4,
            "res3": 8,
            "res4": 16,
            "res5": 32,
        }
        self._out_feature_channels = {
            "res2": self.num_features[0],
            "res3": self.num_features[1],
            "res4": self.num_features[2],
            "res5": self.num_features[3],
        }

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of
                ``self.size_divisibility``.
        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        assert (
            x.dim() == 4
        ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        outputs = {}
        y = super().forward(x)
        for k in y.keys():
            if k in self._out_features:
                outputs[k] = y[k]
        return outputs

    def output_shape(self):
        """Per-feature ShapeSpec (channels, stride) for detectron2."""
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }

    @property
    def size_divisibility(self):
        return 32

# (the following file in the original chunk, mask_former/modeling/criterion.py,
# begins here: "MaskFormer criterion.")
+""" +import torch +import torch.nn.functional as F +from torch import nn + +from detectron2.utils.comm import get_world_size + +from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list + + +def dice_loss(inputs, targets, num_masks): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(-1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_masks + + +def sigmoid_focal_loss(inputs, targets, num_masks, alpha: float = 0.25, gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_masks + + +class SetCriterion(nn.Module): + """This class computes the loss for DETR. 
class SetCriterion(nn.Module):
    """This class computes the loss for DETR-style mask classification.

    The process happens in two steps:
        1) compute the Hungarian assignment between ground-truth and model
           outputs;
        2) supervise each matched ground-truth / prediction pair (class and
           mask).
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """Create the criterion.

        Parameters:
            num_classes: number of object categories, omitting the special
                no-object category.
            matcher: module able to compute a matching between targets and
                proposals.
            weight_dict: dict mapping loss names to their relative weight.
            eos_coef: relative classification weight applied to the no-object
                category.
            losses: list of losses to apply; see get_loss for the options.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        # Down-weight the (last) no-object class in the CE loss.
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

    def loss_labels(self, outputs, targets, indices, num_masks):
        """Classification loss (NLL).

        Target dicts must contain "labels": a tensor of dim [nb_target_boxes].
        Unmatched queries are supervised toward the no-object class.
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        # Default every query to the no-object class, then fill matched ones.
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        return {"loss_ce": loss_ce}

    def loss_masks(self, outputs, targets, indices, num_masks):
        """Mask losses: sigmoid focal loss and dice loss.

        Target dicts must contain "masks": a tensor of dim
        [nb_target_boxes, h, w].
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs["pred_masks"][src_idx]
        masks = [t["masks"] for t in targets]
        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)[tgt_idx]

        # Upsample predictions to the target size before comparing.
        src_masks = F.interpolate(
            src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
        )
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1).view(src_masks.shape)
        return {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks),
            "loss_dice": dice_loss(src_masks, target_masks, num_masks),
        }

    def _get_src_permutation_idx(self, indices):
        # Flatten (batch, query) indices of the matched predictions.
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # Flatten (batch, target) indices of the matched ground truths.
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_masks):
        """Dispatch to the loss named by ``loss`` ("labels" or "masks")."""
        loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_masks)

    def forward(self, outputs, targets):
        """Perform the loss computation.

        Parameters:
            outputs: dict of tensors; see the model's output specification.
            targets: list of dicts with len(targets) == batch_size; expected
                keys depend on the requested losses.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

        # Match the last decoder layer's outputs against the targets.
        indices = self.matcher(outputs_without_aux, targets)

        # Average number of target masks across all processes, for
        # normalization.
        num_masks = sum(len(t["labels"]) for t in targets)
        num_masks = torch.as_tensor(
            [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_masks)
        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()

        # Compute all requested losses for the final layer.
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))

        # Repeat for each auxiliary (intermediate decoder layer) output.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)
                    losses.update({k + f"_{i}": v for k, v in l_dict.items()})

        return losses
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/heads/mask_former_head.py b/ACT_DP_multitask/detr/models/mask_former/modeling/heads/mask_former_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..16f9d990953b33cecf6fd1e9a25131e2c9127ffe
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/heads/mask_former_head.py
@@ -0,0 +1,119 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from copy import deepcopy
from typing import Callable, Dict, List, Optional, Tuple, Union

import fvcore.nn.weight_init as weight_init
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY

from ..transformer.transformer_predictor import TransformerPredictor
from .pixel_decoder import build_pixel_decoder


@SEM_SEG_HEADS_REGISTRY.register()
class MaskFormerHead(nn.Module):
    # MaskFormer segmentation head: a pixel decoder that produces per-pixel mask
    # features plus a Transformer predictor that produces per-query class/mask outputs.

    # Bumped so _load_from_state_dict can migrate checkpoints saved before the
    # decoder weights were nested under the "pixel_decoder." prefix.
    _version = 2

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Backward compatibility: for version < 2 checkpoints, move every
        # non-predictor "sem_seg_head" key under "pixel_decoder." in place.
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Do not warn if train from scratch
            scratch = True
            logger = logging.getLogger(__name__)
            for k in list(state_dict.keys()):
                newk = k
                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
                    newk = k.replace(prefix, prefix + "pixel_decoder.")
                    # logger.debug(f"{k} ==> {newk}")
                if newk != k:
                    # Renamed at least one key, so this is a real old checkpoint.
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]
                    scratch = False

            if not scratch:
                logger.warning(
                    f"Weight format of {self.__class__.__name__} have changed! "
                    "Please upgrade your models. Applying automatic conversion now ..."
                )

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
        # extra parameters
        transformer_predictor: nn.Module,
        transformer_in_feature: str,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
            transformer_predictor: the transformer decoder that makes prediction
            transformer_in_feature: input feature name to the transformer_predictor
        """
        super().__init__()
        # Sort backbone features by stride so in_features runs fine-to-coarse.
        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]
        feature_strides = [v.stride for k, v in input_shape]
        feature_channels = [v.channels for k, v in input_shape]

        self.ignore_value = ignore_value
        self.common_stride = 4
        self.loss_weight = loss_weight

        self.pixel_decoder = pixel_decoder
        self.predictor = transformer_predictor
        self.transformer_in_feature = transformer_in_feature

        self.num_classes = num_classes

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        # Build constructor kwargs from a detectron2 config node.
        return {
            "input_shape": {
                k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
            },
            "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
            "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
            "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
            "transformer_predictor": TransformerPredictor(
                cfg,
                # Predictor input width depends on where its input comes from:
                # the encoder output (CONVS_DIM) or a raw backbone feature map.
                cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
                if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder"
                else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels,
                mask_classification=True,
            ),
        }

    def forward(self, features):
        return self.layers(features)

    def layers(self, features):
        # Decode pixels, then run the Transformer predictor on the configured input.
        mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features)
        if self.transformer_in_feature == "transformer_encoder":
            assert (
                transformer_encoder_features is not None
            ), "Please use the TransformerEncoderPixelDecoder."
            predictions = self.predictor(transformer_encoder_features, mask_features)
        else:
            predictions = self.predictor(features[self.transformer_in_feature], mask_features)
        return predictions
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/heads/per_pixel_baseline.py b/ACT_DP_multitask/detr/models/mask_former/modeling/heads/per_pixel_baseline.py
new file mode 100644
index 0000000000000000000000000000000000000000..a99f508e7b4a87ada0af6f10209f10edefa7e412
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/heads/per_pixel_baseline.py
@@ -0,0 +1,243 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from typing import Callable, Dict, List, Optional, Tuple, Union

import fvcore.nn.weight_init as weight_init
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY

from ..transformer.transformer_predictor import TransformerPredictor
from .pixel_decoder import build_pixel_decoder


@SEM_SEG_HEADS_REGISTRY.register()
class PerPixelBaselineHead(nn.Module):
    # Per-pixel classification baseline: pixel decoder followed by a 1x1 conv
    # that predicts a class logit per pixel (no Transformer predictor).

    # Bumped so _load_from_state_dict can migrate checkpoints saved before the
    # decoder weights were nested under the "pixel_decoder." prefix.
    _version = 2

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Backward compatibility: rename old (version < 2) checkpoint keys in place.
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            logger = logging.getLogger(__name__)
            # Do not warn if train from scratch
            scratch = True
            # NOTE(review): duplicate of the logger assignment two lines up; harmless.
            logger = logging.getLogger(__name__)
            for k in list(state_dict.keys()):
                newk = k
                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
                    newk = k.replace(prefix, prefix + "pixel_decoder.")
                    # logger.warning(f"{k} ==> {newk}")
                if newk != k:
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]
                    scratch = False

            if not scratch:
                logger.warning(
                    f"Weight format of {self.__class__.__name__} have changed! "
                    "Please upgrade your models. Applying automatic conversion now ..."
                )

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
        """
        super().__init__()
        # Sort backbone features by stride so in_features runs fine-to-coarse.
        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]
        feature_strides = [v.stride for k, v in input_shape]
        feature_channels = [v.channels for k, v in input_shape]

        self.ignore_value = ignore_value
        self.common_stride = 4
        self.loss_weight = loss_weight

        self.pixel_decoder = pixel_decoder
        # 1x1 conv classifier over the decoder's mask features.
        self.predictor = Conv2d(
            self.pixel_decoder.mask_dim, num_classes, kernel_size=1, stride=1, padding=0
        )
        weight_init.c2_msra_fill(self.predictor)

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        # Build constructor kwargs from a detectron2 config node.
        return {
            "input_shape": {
                k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
            },
            "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
            "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
        }

    def forward(self, features, targets=None):
        """
        Returns:
            In training, returns (None, dict of losses)
            In inference, returns (CxHxW logits, {})
        """
        x = self.layers(features)
        if self.training:
            return None, self.losses(x, targets)
        else:
            # Upsample logits from the decoder stride back to input resolution.
            x = F.interpolate(
                x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            return x, {}

    def layers(self, features):
        x, _ = self.pixel_decoder.forward_features(features)
        x = self.predictor(x)
        return x

    def losses(self, predictions, targets):
        predictions = predictions.float()  # https://github.com/pytorch/pytorch/issues/48163
        predictions = F.interpolate(
            predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
        )
        loss = F.cross_entropy(
            predictions, targets, reduction="mean", ignore_index=self.ignore_value
        )
        losses = {"loss_sem_seg": loss * self.loss_weight}
        return losses


@SEM_SEG_HEADS_REGISTRY.register()
class PerPixelBaselinePlusHead(PerPixelBaselineHead):
    # Baseline "plus": replaces the 1x1 conv classifier with a Transformer
    # predictor (mask_classification=False) and optional deep supervision.

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Same version < 2 key migration as the base class (logger.debug active here).
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Do not warn if train from scratch
            scratch = True
            logger = logging.getLogger(__name__)
            for k in list(state_dict.keys()):
                newk = k
                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
                    newk = k.replace(prefix, prefix + "pixel_decoder.")
                    logger.debug(f"{k} ==> {newk}")
                if newk != k:
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]
                    scratch = False

            if not scratch:
                logger.warning(
                    f"Weight format of {self.__class__.__name__} have changed! "
                    "Please upgrade your models. Applying automatic conversion now ..."
                )

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        # extra parameters
        transformer_predictor: nn.Module,
        transformer_in_feature: str,
        deep_supervision: bool,
        # inherit parameters
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            transformer_predictor: the transformer decoder that makes prediction
            transformer_in_feature: input feature name to the transformer_predictor
            deep_supervision: whether or not to add supervision to the output of
                every transformer decoder layer
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
        """
        super().__init__(
            input_shape,
            num_classes=num_classes,
            pixel_decoder=pixel_decoder,
            loss_weight=loss_weight,
            ignore_value=ignore_value,
        )

        # Drop the base class's 1x1 conv classifier and swap in the Transformer predictor.
        del self.predictor

        self.predictor = transformer_predictor
        self.transformer_in_feature = transformer_in_feature
        self.deep_supervision = deep_supervision

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        ret = super().from_config(cfg, input_shape)
        ret["transformer_in_feature"] = cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE
        if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder":
            in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
        else:
            in_channels = input_shape[ret["transformer_in_feature"]].channels
        ret["transformer_predictor"] = TransformerPredictor(
            cfg, in_channels, mask_classification=False
        )
        ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
        return ret

    def forward(self, features, targets=None):
        """
        Returns:
            In training, returns (None, dict of losses)
            In inference, returns (CxHxW logits, {})
        """
        x, aux_outputs = self.layers(features)
        if self.training:
            if self.deep_supervision:
                # Supervise the final output plus every intermediate decoder layer.
                losses = self.losses(x, targets)
                for i, aux_output in enumerate(aux_outputs):
                    losses["loss_sem_seg" + f"_{i}"] = self.losses(
                        aux_output["pred_masks"], targets
                    )["loss_sem_seg"]
                return None, losses
            else:
                return None, self.losses(x, targets)
        else:
            x = F.interpolate(
                x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            return x, {}

    def layers(self, features):
        mask_features, transformer_encoder_features = self.pixel_decoder.forward_features(features)
        if self.transformer_in_feature == "transformer_encoder":
            assert (
                transformer_encoder_features is not None
            ), "Please use the TransformerEncoderPixelDecoder."
            predictions = self.predictor(transformer_encoder_features, mask_features)
        else:
            predictions = self.predictor(features[self.transformer_in_feature], mask_features)
        if self.deep_supervision:
            return predictions["pred_masks"], predictions["aux_outputs"]
        else:
            return predictions["pred_masks"], None
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/heads/pixel_decoder.py b/ACT_DP_multitask/detr/models/mask_former/modeling/heads/pixel_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ae74c8ce6000aa976d19c7f9223519f46ce7e6a
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/heads/pixel_decoder.py
@@ -0,0 +1,294 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from typing import Callable, Dict, List, Optional, Tuple, Union

import fvcore.nn.weight_init as weight_init
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY

from ..transformer.position_encoding import PositionEmbeddingSine
from ..transformer.transformer import TransformerEncoder, TransformerEncoderLayer


def build_pixel_decoder(cfg, input_shape):
    """
    Build a pixel decoder from `cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME`.
    """
    name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME
    model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
    # Pixel decoders must expose forward_features; plain forward is reserved.
    forward_features = getattr(model, "forward_features", None)
    if not callable(forward_features):
        raise ValueError(
            "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
            f"Please implement forward_features for {name} to only return mask features."
        )
    return model


@SEM_SEG_HEADS_REGISTRY.register()
class BasePixelDecoder(nn.Module):
    # FPN-style pixel decoder: lateral 1x1 convs + top-down nearest upsampling,
    # ending in a 3x3 conv that produces mask features.

    # @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        # *,
        conv_dim: int,
        mask_dim: int,
        norm: Optional[Union[str, Callable]] = None,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            conv_dim: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            norm (str or callable): normalization for all conv layers
        """
        super().__init__()

        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
        feature_channels = [v.channels for k, v in input_shape]

        lateral_convs = []
        output_convs = []

        # With a norm layer present, the conv bias is redundant.
        use_bias = norm == ""
        for idx, in_channels in enumerate(feature_channels):
            if idx == len(self.in_features) - 1:
                # Coarsest level: no lateral conv, only an output conv.
                output_norm = get_norm(norm, conv_dim)
                output_conv = Conv2d(
                    in_channels,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                weight_init.c2_xavier_fill(output_conv)
                self.add_module("layer_{}".format(idx + 1), output_conv)

                lateral_convs.append(None)
                output_convs.append(output_conv)
            else:
                lateral_norm = get_norm(norm, conv_dim)
                output_norm = get_norm(norm, conv_dim)

                lateral_conv = Conv2d(
                    in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
                )
                output_conv = Conv2d(
                    conv_dim,
                    conv_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=use_bias,
                    norm=output_norm,
                    activation=F.relu,
                )
                weight_init.c2_xavier_fill(lateral_conv)
                weight_init.c2_xavier_fill(output_conv)
                self.add_module("adapter_{}".format(idx + 1), lateral_conv)
                self.add_module("layer_{}".format(idx + 1), output_conv)

                lateral_convs.append(lateral_conv)
                output_convs.append(output_conv)
        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]

        self.mask_dim = mask_dim
        self.mask_features = Conv2d(
            conv_dim,
            mask_dim,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        weight_init.c2_xavier_fill(self.mask_features)

    # @classmethod
    # def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
    #     ret = {}
    #     ret["input_shape"] = {
    #         k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
    #     }
    #     ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
    #     ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
    #     ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
    #     return ret

    def forward_features(self, features):
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.in_features[::-1]):
            x = features[f]
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            if lateral_conv is None:
                # First (coarsest) level seeds the running top-down feature `y`.
                y = output_conv(x)
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
                y = output_conv(y)
        return self.mask_features(y), None

    def forward(self, features, targets=None):
        logger = logging.getLogger(__name__)
        logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
        return self.forward_features(features)


class TransformerEncoderOnly(nn.Module):
    # Encoder-only wrapper around the DETR-style Transformer (no decoder).

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        # Xavier-init all weight matrices (biases and norms keep their defaults).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        if mask is not None:
            mask = mask.flatten(1)

        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        # Restore the NxCxHxW layout for the conv pipeline.
        return memory.permute(1, 2, 0).view(bs, c, h, w)


@SEM_SEG_HEADS_REGISTRY.register()
class TransformerEncoderPixelDecoder(BasePixelDecoder):
    # BasePixelDecoder variant that runs a Transformer encoder on the coarsest
    # feature level and also returns the encoder output for the predictor.

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        transformer_dropout: float,
        transformer_nheads: int,
        transformer_dim_feedforward: int,
        transformer_enc_layers: int,
        transformer_pre_norm: bool,
        conv_dim: int,
        mask_dim: int,
        norm: Optional[Union[str, Callable]] = None,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            input_shape: shapes (channels and stride) of the input features
            transformer_dropout: dropout probability in transformer
            transformer_nheads: number of heads in transformer
            transformer_dim_feedforward: dimension of feedforward network
            transformer_enc_layers: number of transformer encoder layers
            transformer_pre_norm: whether to use pre-layernorm or not
            conv_dim: number of output channels for the intermediate conv layers.
            mask_dim: number of output channels for the final conv layer.
            norm (str or callable): normalization for all conv layers
        """
        super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm)

        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
        feature_strides = [v.stride for k, v in input_shape]
        feature_channels = [v.channels for k, v in input_shape]

        in_channels = feature_channels[len(self.in_features) - 1]
        self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
        weight_init.c2_xavier_fill(self.input_proj)
        self.transformer = TransformerEncoderOnly(
            d_model=conv_dim,
            dropout=transformer_dropout,
            nhead=transformer_nheads,
            dim_feedforward=transformer_dim_feedforward,
            num_encoder_layers=transformer_enc_layers,
            normalize_before=transformer_pre_norm,
        )
        # Half the channels encode y positions, half encode x positions.
        N_steps = conv_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        # update layer
        # Replace the coarsest-level output conv registered by the base class,
        # since its input is now the transformer output (conv_dim channels).
        use_bias = norm == ""
        output_norm = get_norm(norm, conv_dim)
        output_conv = Conv2d(
            conv_dim,
            conv_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=use_bias,
            norm=output_norm,
            activation=F.relu,
        )
        weight_init.c2_xavier_fill(output_conv)
        delattr(self, "layer_{}".format(len(self.in_features)))
        self.add_module("layer_{}".format(len(self.in_features)), output_conv)
        self.output_convs[0] = output_conv

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        # NOTE(review): BasePixelDecoder.from_config is commented out above, so this
        # super() call looks like it would raise AttributeError — confirm/restore it.
        ret = super().from_config(cfg, input_shape)
        ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
        ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
        ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
        ret[
            "transformer_enc_layers"
        ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS  # a separate config
        ret["transformer_pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
        return ret

    def forward_features(self, features):
        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, f in enumerate(self.in_features[::-1]):
            x = features[f]
            lateral_conv = self.lateral_convs[idx]
            output_conv = self.output_convs[idx]
            if lateral_conv is None:
                # Coarsest level: project, add positional encodings, run the encoder.
                transformer = self.input_proj(x)
                pos = self.pe_layer(x)
                transformer = self.transformer(transformer, None, pos)
                y = output_conv(transformer)
                # save intermediate feature as input to Transformer decoder
                transformer_encoder_features = transformer
            else:
                cur_fpn = lateral_conv(x)
                # Following FPN implementation, we use nearest upsampling here
                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
                y = output_conv(y)
        return self.mask_features(y), transformer_encoder_features

    def forward(self, features, targets=None):
        logger = logging.getLogger(__name__)
        logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
        return self.forward_features(features)
diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/matcher.py b/ACT_DP_multitask/detr/models/mask_former/modeling/matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4d11706577bf353bb0df13fe57032da3c66e8f5
--- /dev/null
+++ b/ACT_DP_multitask/detr/models/mask_former/modeling/matcher.py
@@ -0,0 +1,174 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
"""
Modules to compute the matching cost and solve the corresponding LSAP.
"""
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from torch import nn


def batch_dice_loss(inputs, targets):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs.
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
"""
Modules to compute the matching cost and solve the corresponding LSAP.
"""
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from torch import nn


def batch_dice_loss(inputs, targets):
    """
    Compute the DICE loss, similar to generalized IOU for masks, for every
    prediction/target pair.
    Args:
        inputs: A float tensor of shape [num_preds, C].
                The predictions (logits) for each example.
        targets: A float tensor of shape [num_targets, C]. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    Returns:
        A [num_preds, num_targets] pairwise loss matrix.
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    # Pairwise intersection/union via outer products over the flattened dim.
    numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
    denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
    # +1 smoothing keeps the ratio defined for empty masks.
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss


def batch_sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    """
    Pairwise focal loss, as used in RetinaNet for dense detection:
    https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of shape [num_preds, C].
                The predictions (logits) for each example.
        targets: A float tensor of shape [num_targets, C]. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Default = 0.25; a negative
               value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        A [num_preds, num_targets] pairwise loss matrix, averaged over C.
    """
    hw = inputs.shape[1]

    prob = inputs.sigmoid()
    # Per-element focal terms for an all-positive / all-negative target; the
    # einsum below mixes them according to each actual target mask.
    focal_pos = ((1 - prob) ** gamma) * F.binary_cross_entropy_with_logits(
        inputs, torch.ones_like(inputs), reduction="none"
    )
    focal_neg = (prob ** gamma) * F.binary_cross_entropy_with_logits(
        inputs, torch.zeros_like(inputs), reduction="none"
    )
    if alpha >= 0:
        focal_pos = focal_pos * alpha
        focal_neg = focal_neg * (1 - alpha)

    loss = torch.einsum("nc,mc->nm", focal_pos, targets) + torch.einsum(
        "nc,mc->nm", focal_neg, (1 - targets)
    )

    return loss / hw


class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
            cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_mask = cost_mask
        self.cost_dice = cost_dice
        assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs cant be 0"

    @torch.no_grad()
    def memory_efficient_forward(self, outputs, targets):
        """More memory-friendly matching: solves the LSAP one batch element at a time."""
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # Work out the mask padding size
        # NOTE(review): h_max/w_max are computed but never used below.
        masks = [v["masks"] for v in targets]
        h_max = max([m.shape[1] for m in masks])
        w_max = max([m.shape[2] for m in masks])

        indices = []

        # Iterate through batch size
        for b in range(bs):

            out_prob = outputs["pred_logits"][b].softmax(-1)  # [num_queries, num_classes]
            out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]

            tgt_ids = targets[b]["labels"]
            # gt masks are already padded when preparing target
            tgt_mask = targets[b]["masks"].to(out_mask)

            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be omitted.
            cost_class = -out_prob[:, tgt_ids]

            # Downsample gt masks to save memory
            tgt_mask = F.interpolate(tgt_mask[:, None], size=out_mask.shape[-2:], mode="nearest")

            # Flatten spatial dimension
            out_mask = out_mask.flatten(1)  # [num_queries, H_pred * W_pred]
            tgt_mask = tgt_mask[:, 0].flatten(1)  # [num_targets, H_pred * W_pred]

            # Compute the focal loss between masks
            cost_mask = batch_sigmoid_focal_loss(out_mask, tgt_mask)

            # Compute the dice loss between masks
            cost_dice = batch_dice_loss(out_mask, tgt_mask)

            # Final cost matrix
            C = (
                self.cost_mask * cost_mask
                + self.cost_class * cost_class
                + self.cost_dice * cost_dice
            )
            C = C.reshape(num_queries, -1).cpu()

            indices.append(linear_sum_assignment(C))
        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        return self.memory_efficient_forward(outputs, targets)

    def __repr__(self):
        head = "Matcher " + self.__class__.__name__
        body = [
            "cost_class: {}".format(self.cost_class),
            "cost_mask: {}".format(self.cost_mask),
            "cost_dice: {}".format(self.cost_dice),
        ]
        _repr_indent = 4
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)
+ """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/transformer.py b/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ea8caa0108f5e136a9739320ab69a3e1b6f40298 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/transformer.py @@ -0,0 +1,369 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py +""" +Transformer class. 
+ +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import List, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class Transformer(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, + num_decoder_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + if mask is not None: + mask = mask.flatten(1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder( + tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed + ) + return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, 
w) + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + output = src + + for layer in self.layers: + output = layer( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos, + ) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + 
super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn( + q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + )[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn( + q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + )[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class 
TransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = 
None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn( + q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + return self.forward_post( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") diff --git a/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/transformer_predictor.py 
b/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/transformer_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..45d04451444dc5532d60dd7a6b6072ba5b87357f --- /dev/null +++ b/ACT_DP_multitask/detr/models/mask_former/modeling/transformer/transformer_predictor.py @@ -0,0 +1,171 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py +import fvcore.nn.weight_init as weight_init +import torch +from torch import nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.layers import Conv2d + +from .position_encoding import PositionEmbeddingSine +from .transformer import Transformer + + +class TransformerPredictor(nn.Module): + @configurable + def __init__( + self, + in_channels, + mask_classification=True, + *, + num_classes: int, + hidden_dim: int, + num_queries: int, + nheads: int, + dropout: float, + dim_feedforward: int, + enc_layers: int, + dec_layers: int, + pre_norm: bool, + deep_supervision: bool, + mask_dim: int, + enforce_input_project: bool, + ): + """ + NOTE: this interface is experimental. 
+ Args: + in_channels: channels of the input features + mask_classification: whether to add mask classifier or not + num_classes: number of classes + hidden_dim: Transformer feature dimension + num_queries: number of queries + nheads: number of heads + dropout: dropout in Transformer + dim_feedforward: feature dimension in feedforward network + enc_layers: number of Transformer encoder layers + dec_layers: number of Transformer decoder layers + pre_norm: whether to use pre-LayerNorm or not + deep_supervision: whether to add supervision to every decoder layers + mask_dim: mask feature dimension + enforce_input_project: add input project 1x1 conv even if input + channels and hidden dim is identical + """ + super().__init__() + + self.mask_classification = mask_classification + + # positional encoding + N_steps = hidden_dim // 2 + self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) + + transformer = Transformer( + d_model=hidden_dim, + dropout=dropout, + nhead=nheads, + dim_feedforward=dim_feedforward, + num_encoder_layers=enc_layers, + num_decoder_layers=dec_layers, + normalize_before=pre_norm, + return_intermediate_dec=deep_supervision, + ) + + self.num_queries = num_queries + self.transformer = transformer + hidden_dim = transformer.d_model + + self.query_embed = nn.Embedding(num_queries, hidden_dim) + + if in_channels != hidden_dim or enforce_input_project: + self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1) + weight_init.c2_xavier_fill(self.input_proj) + else: + self.input_proj = nn.Sequential() + self.aux_loss = deep_supervision + + # output FFNs + if self.mask_classification: + self.class_embed = nn.Linear(hidden_dim, num_classes + 1) + self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) + + @classmethod + def from_config(cls, cfg, in_channels, mask_classification): + ret = {} + ret["in_channels"] = in_channels + ret["mask_classification"] = mask_classification + + ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES + 
ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM + ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES + # Transformer parameters: + ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS + ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT + ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD + ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS + ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS + ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM + ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION + ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ + + ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM + + return ret + + def forward(self, x, mask_features): + pos = self.pe_layer(x) + + src = x + mask = None + hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos) + + if self.mask_classification: + outputs_class = self.class_embed(hs) + out = {"pred_logits": outputs_class[-1]} + else: + out = {} + + if self.aux_loss: + # [l, bs, queries, embed] + mask_embed = self.mask_embed(hs) + outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features) + out["pred_masks"] = outputs_seg_masks[-1] + out["aux_outputs"] = self._set_aux_loss( + outputs_class if self.mask_classification else None, outputs_seg_masks + ) + else: + # FIXME h_boxes takes the last one computed, keep this in mind + # [bs, queries, embed] + mask_embed = self.mask_embed(hs[-1]) + outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) + out["pred_masks"] = outputs_seg_masks + return out + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_seg_masks): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. 
+ if self.mask_classification: + return [ + {"pred_logits": a, "pred_masks": b} + for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) + ] + else: + return [{"pred_masks": b} for b in outputs_seg_masks[:-1]] + + +class MLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x diff --git a/ACT_DP_multitask/detr/models/mask_former/test_time_augmentation.py b/ACT_DP_multitask/detr/models/mask_former/test_time_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..8d250b6bb7792b54ddeaaab62cc6c170d74d3bb9 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mask_former/test_time_augmentation.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +from itertools import count + +import numpy as np +import torch +from fvcore.transforms import HFlipTransform +from torch import nn +from torch.nn.parallel import DistributedDataParallel + +from detectron2.data.detection_utils import read_image +from detectron2.modeling import DatasetMapperTTA + +__all__ = [ + "SemanticSegmentorWithTTA", +] + + +class SemanticSegmentorWithTTA(nn.Module): + """ + A SemanticSegmentor with test-time augmentation enabled. + Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`. + """ + + def __init__(self, cfg, model, tta_mapper=None, batch_size=1): + """ + Args: + cfg (CfgNode): + model (SemanticSegmentor): a SemanticSegmentor to apply TTA on. + tta_mapper (callable): takes a dataset dict and returns a list of + augmented versions of the dataset dict. 
Defaults to + `DatasetMapperTTA(cfg)`. + batch_size (int): batch the augmented images into this batch size for inference. + """ + super().__init__() + if isinstance(model, DistributedDataParallel): + model = model.module + self.cfg = cfg.clone() + + self.model = model + + if tta_mapper is None: + tta_mapper = DatasetMapperTTA(cfg) + self.tta_mapper = tta_mapper + self.batch_size = batch_size + + def _batch_inference(self, batched_inputs): + """ + Execute inference on a list of inputs, + using batch size = self.batch_size, instead of the length of the list. + Inputs & outputs have the same format as :meth:`SemanticSegmentor.forward` + """ + outputs = [] + inputs = [] + for idx, input in zip(count(), batched_inputs): + inputs.append(input) + if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1: + with torch.no_grad(): + outputs.extend(self.model(inputs)) + inputs = [] + return outputs + + def __call__(self, batched_inputs): + """ + Same input/output format as :meth:`SemanticSegmentor.forward` + """ + + def _maybe_read_image(dataset_dict): + ret = copy.copy(dataset_dict) + if "image" not in ret: + image = read_image(ret.pop("file_name"), self.model.input_format) + image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))) # CHW + ret["image"] = image + if "height" not in ret and "width" not in ret: + ret["height"] = image.shape[1] + ret["width"] = image.shape[2] + return ret + + return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs] + + def _inference_one_image(self, input): + """ + Args: + input (dict): one dataset dict with "image" field being a CHW tensor + Returns: + dict: one output dict + """ + augmented_inputs, tfms = self._get_augmented_inputs(input) + # 1: forward with all augmented images + outputs = self._batch_inference(augmented_inputs) + # Delete now useless variables to avoid being out of memory + del augmented_inputs + # 2: merge the results + # handle flip specially + new_outputs = [] + for 
output, tfm in zip(outputs, tfms): + if any(isinstance(t, HFlipTransform) for t in tfm.transforms): + new_outputs.append(output.pop("sem_seg").flip(dims=[2])) + else: + new_outputs.append(output.pop("sem_seg")) + del outputs + # to avoid OOM with torch.stack + final_predictions = new_outputs[0] + for i in range(1, len(new_outputs)): + final_predictions += new_outputs[i] + final_predictions = final_predictions / len(new_outputs) + del new_outputs + return {"sem_seg": final_predictions} + + def _get_augmented_inputs(self, input): + augmented_inputs = self.tta_mapper(input) + tfms = [x.pop("transforms") for x in augmented_inputs] + return augmented_inputs, tfms diff --git a/ACT_DP_multitask/detr/models/mask_former/utils/__init__.py b/ACT_DP_multitask/detr/models/mask_former/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mask_former/utils/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. diff --git a/ACT_DP_multitask/detr/models/mask_former/utils/misc.py b/ACT_DP_multitask/detr/models/mask_former/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..874d9805b482f52bbffc1be620e36e0cffc07c46 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mask_former/utils/misc.py @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. 
+""" +from typing import List, Optional + +import torch +import torch.distributed as dist +import torchvision +from torch import Tensor + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("not supported") + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an 
def is_dist_avail_and_initialized():
    """Return True iff torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +======================================================================= +Apache SkyWalking Subcomponents: + +The Apache SkyWalking project contains subcomponents with separate copyright +notices and license terms. Your use of the source code for the these +subcomponents is subject to the terms and conditions of the following +licenses. + +======================================================================== +Apache 2.0 licenses +======================================================================== + +The following components are provided under the Apache License. See project link for details. +The text of each license is the standard Apache 2.0 license. 
+ + proto files from cncf/udpa: https://github.com/cncf/udpa Apache 2.0 + proto files from envoyproxy/data-plane-api: https://github.com/envoyproxy/data-plane-api Apache 2.0 + proto files from prometheus/client_model: https://github.com/prometheus/client_model Apache 2.0 + proto files from opentelemetry: https://github.com/open-telemetry/opentelemetry-proto/tree/main/opentelemetry/proto Apache 2.0 + proto files from opentelemetry: https://github.com/open-telemetry/opentelemetry-proto/tree/v0.7.0 Apache 2.0 + flatbuffers files from istio/proxy: https://github.com/istio/proxy Apache 2.0 + mvnw files from https://github.com/apache/maven-wrapper Apache 2.0 + svg files from skywalking-ui/src/assets/icons: https://github.com/google/material-design-icons Apache 2.0 + ZipkinQueryHandler.java reference from zipkin2.server.internal.ZipkinQueryApiV2: https://github.com/openzipkin/zipkin Apache 2.0 \ No newline at end of file diff --git a/ACT_DP_multitask/detr/models/mr_mg/README.md b/ACT_DP_multitask/detr/models/mr_mg/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f5109c25686204f03e545d0655aeecee62f228c1 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/README.md @@ -0,0 +1,98 @@ +

GR-MG

+ +This repo contains code for the paper: +### Leveraging Partially Annotated Data via Multi-Modal Goal Conditioned Policy + +[Peiyan Li](https://github.com/LPY1219), [Hongtao Wu\*‡](https://scholar.google.com/citations?hl=zh-CN&user=7u0TYgIAAAAJ&view_op=list_works&sortby=pubdate), [Yan Huang\*](https://yanrockhuang.github.io/), [Chilam Cheang](https://github.com/bytedance/GR-MG/tree/main), [Liang Wang](https://scholar.google.com/citations?hl=zh-CN&user=8kzzUboAAAAJ&view_op=list_works&sortby=pubdate), [Tao Kong](https://www.taokong.org/) + +*Corresponding author ‡ Project lead + +### [🌠Project Website](https://gr-mg.github.io/) | [📄 Paper](https://arxiv.org/abs/2408.14368) + + +

+ Model Gif +

+ + +## News +- (🔥 New) **(2024.08.27)** We have released the code and checkpoints of GR-MG ! +## Preparation +**Note:** We only test GR-MG with CUDA 12.1 and python 3.9 + +```bash +# clone this repository +git clone https://github.com/bytedance/GR-MG.git +cd GR_MG +# install dependencies for goal image generation model +bash ./goal_gen/install.sh +# install dependencies for multi-modal goal conditioned policy +bash ./policy/install.sh +``` +Download the pretrained [InstructPix2Pix](https://huggingface.co/timbrooks/instruct-pix2pix) weights from Huggingface and save them in `resources/IP2P/`. +Download the pretrained MAE encoder [mae_pretrain_vit_base.pth ](https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth) and save it in `resources/MAE/`. +Download and unzip the [CALVIN](https://github.com/mees/calvin) dataset. + + +## Checkpoints +- [Multi-modal Goal Conditioned Policy](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/gr_mg_release/epoch=47-step=83712.ckpt) +- [Goal Image Generation Model](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/gr_mg_release/goal_gen.ckpt) + + +## Training + +### 1. Train Goal Image Generation Model +```bash +# modify the variables in the script before you execute the following instruction +bash ./goal_gen/train_ip2p.sh ./goal_gen/config/train.json +``` +### 2. Pretrain Multi-modal Goal Conditioned Policy +We use the method described in [GR-1](https://arxiv.org/abs/2312.13139) and pretrain our policy with Ego4D videos. You can download the pretrained model checkpoint [here](https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/gr_mg_release/pretrained.pt). You can also pretrain the policy yourself using the scripts we provide. Before doing this, you'll need to download the [Ego4D](https://ego4d-data.org/) dataset. + +```bash +# pretrain multi-modal goal conditioned policy +bash ./policy/main.sh ./policy/config/pretrain.json +``` +### 3. 
Train Multi-modal Goal Conditioned Policy +After pretraining, modify the pretrained_model_path in `/policy/config/train.json` and execute the following instruction to train the policy. +```bash +# train multi-modal goal conditioned policy +bash ./policy/main.sh ./policy/config/train.json +``` + + +## Evaluation +To evaluate our model on CALVIN, you can execute the following instruction: +```bash +# Evaluate GR-MG on CALVIN +bash ./evaluate/eval.sh ./policy/config/train.json +``` +In the `eval.sh` script, you can specify which goal image generation model and policy to use. Additionally, we provide multi-GPU evaluation code, allowing you to evaluate different training epochs of the policy simultaneously. + + +## Contact +If you have any questions about the project, please contact peiyan.li@cripac.ia.ac.cn. + + +## Acknowledgements + +We thank the authors of the following projects for making their code and dataset open source: + +- [CALVIN](https://github.com/mees/calvin) +- [InstructPix2Pix](https://github.com/timothybrooks/instruct-pix2pix) +- [T5](https://github.com/google-research/text-to-text-transfer-transformer) +- [GR-1](https://github.com/bytedance/GR-1) +- [CLIP](https://github.com/openai/CLIP) +- [MAE](https://github.com/facebookresearch/mae) + +## Citation + +If you find this project useful, please star the repository and cite our paper: +``` +@article{li2024gr, + title={GR-MG: Leveraging Partially Annotated Data via Multi-Modal Goal Conditioned Policy}, + author={Li, Peiyan and Wu, Hongtao and Huang, Yan and Cheang, Chilam and Wang, Liang and Kong, Tao}, + journal={arXiv preprint arXiv:2408.14368}, + year={2024} +} +``` diff --git a/ACT_DP_multitask/detr/models/mr_mg/evaluate/eval.py b/ACT_DP_multitask/detr/models/mr_mg/evaluate/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2c458dfd3f817f03b60e4e5ce8ce8391fc3570 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/evaluate/eval.py @@ -0,0 +1,220 @@ +# MIT License + +# 
Copyright (c) 2021 Oier Mees +# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+import argparse +import json +import logging +import os +from pathlib import Path +import sys +import time +import re +import copy +from copy import deepcopy +import os +# This is for using the locally installed repo clone when using slurm +import matplotlib.pyplot as plt +sys.path.insert(0, Path(__file__).absolute().parents[2].as_posix()) +from calvin_agent.evaluation.multistep_sequences import get_sequences +from calvin_agent.evaluation.utils import ( + count_success, + get_env_state_for_initial_condition +) +import hydra +import numpy as np +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything +from termcolor import colored +import torch +from tqdm.auto import tqdm +from utils.utils import print_and_save +from wrapper.model_wrapper import CustomModel +from goal_gen.evaluate import IP2PEvaluation + +logger = logging.getLogger(__name__) +EP_LEN = 360 +NUM_SEQUENCES = 1000 +SAVE_DIR = None +FAIL_COUNTER=0 + +def make_env(dataset_path, observation_space, device_id): + val_folder = Path(dataset_path) / "validation" + # insert your own env wrapper + from wrapper.calvin_env_wrapper_raw import CalvinEnvWrapperRaw + device = torch.device('cuda', device_id) + env = CalvinEnvWrapperRaw(val_folder, observation_space, device) + return env + + +def evaluate_policy(model, env, eval_sr_path, eval_result_path, ip2p_model): + """Run this function to evaluate a model on the CALVIN challenge.""" + conf_dir = Path("./calvin/calvin_models/conf") + task_cfg = OmegaConf.load(conf_dir / "callbacks/rollout/tasks/new_playtable_tasks.yaml") + task_oracle = hydra.utils.instantiate(task_cfg) + val_annotations = OmegaConf.load(conf_dir / "annotations/new_playtable_validation.yaml") + eval_sequences = get_sequences(NUM_SEQUENCES) + results = [] + sequence_i = 0 + for index,(initial_state, eval_sequence) in enumerate(eval_sequences): + result= evaluate_sequence(env, model, task_oracle, initial_state, eval_sequence, val_annotations, sequence_i,ip2p_model) + 
results.append(result) + success_list = count_success(results) + with open(eval_sr_path, 'a') as f: + line =f"{sequence_i}/{NUM_SEQUENCES}: " + for sr in success_list: + line += f"{sr:.3f} | " + sequence_i += 1 + line += "\n" + f.write(line) + + if index%100==0 and index!=0: #save every 100 sequences + print_and_save(results, eval_sequences[:index+1], eval_result_path[:-5]+f"_{index+1}"+".json", None) + print_and_save(results, eval_sequences, eval_result_path, None) + return results + + +def evaluate_sequence(env, model, task_checker, initial_state, eval_sequence, val_annotations, sequence_i,ip2p_model): + """Evaluates a sequence of language instructions.""" + robot_obs, scene_obs = get_env_state_for_initial_condition(initial_state) + env.reset(robot_obs=robot_obs, scene_obs=scene_obs) + success_counter = 0 + for subtask_i, subtask in enumerate(eval_sequence): + success = rollout(env, model, task_checker, subtask, val_annotations, subtask_i, sequence_i,ip2p_model) + if success: + success_counter += 1 + else: + return success_counter + return success_counter + +def rollout(env, model, task_oracle, subtask, val_annotations, subtask_i, sequence_i,ip2p_model): + """Run the actual rollout on one subtask.""" + obs = env.get_obs() + # get lang annotation for subtask + lang_annotation = val_annotations[subtask][0] + model.reset() + start_info = env.get_info() + debug_image=[] + progress=0 + for i in range(EP_LEN): + if i % 20 == 0: # hardcode + static_rgb = obs['rgb_obs']['rgb_static'] # (200, 200, 3) + hand_rgb = obs['rgb_obs']['rgb_gripper'] + image_patch=[static_rgb] + text_patch=[lang_annotation + f".And {progress}% of the instruction has been finished."] + print(text_patch) + goal_image=ip2p_model.inference(image_patch,text_patch) + temp_image=[static_rgb,goal_image[0],hand_rgb] + debug_image.append(temp_image) + + action,progress = model.step(obs,deepcopy(goal_image),[lang_annotation]) + obs, _, _, current_info = env.step(action) + + # check if current step solves a 
task + current_task_info = task_oracle.get_task_info_for_set(start_info, current_info, {subtask}) + if len(current_task_info) > 0: + print("success!") + return True + print("fail!") + + global FAIL_COUNTER + FAIL_COUNTER+=1 + if FAIL_COUNTER % 30 ==0: # save every 30 failure cases + length=len(debug_image) + fig, ax = plt.subplots(length, 2,figsize=(5.5, 46.58)) + for ax_ in ax.flat: + # ax_.plot([1, 2, 3], [4, 5, 6]) + ax_.axis('off') # hide the ticks and border of each subplot + for i in range(length): + ax[i][0].imshow(debug_image[i][0]) + ax[i][1].imshow(debug_image[i][1]) + # ax[i][2].imshow(debug_image[i][2]) + plt.tight_layout() + plt.axis('off') + plt.savefig(os.path.join(SAVE_DIR, f"{sequence_i}-{subtask_i}-{subtask}.png"),dpi=100) + plt.close() + return False + + +def main(): + seed_everything(0, workers=True) # type:ignore + parser = argparse.ArgumentParser(description="Evaluate a trained model on multistep sequences with language goals.") + parser.add_argument("--dataset_path", default='/mnt/bn/robotics/manipulation_data/calvin_data/task_ABC_D', + type=str, help="Path to the dataset root directory.") # modify it before opensource + # evaluation + parser.add_argument('--config_path', type=str, default="", help='path to the policy config file') + parser.add_argument('--ckpt_dir', type=str, default="",help="path to the policy ckpt file") + parser.add_argument('--epoch', type=int,default=41, help="epoch index for evaluating") + parser.add_argument('--device_id', default=0, type=int, help="CUDA device") + parser.add_argument('--ip2p_ckpt_path', default="", type=str, help="ip2p ckpt path") + args = parser.parse_args() + config_path = args.config_path + ckpt_dir = args.ckpt_dir + epoch = args.epoch + device_id = args.device_id + ip2p_ckpt_path=args.ip2p_ckpt_path + assert config_path != None + # Load config file + with open(config_path, 'r') as f: + configs = json.load(f) + + # Get checkpoint path + ckpt_path = None + ckpt_files = os.listdir(ckpt_dir) + for ckpt_file in 
ckpt_files: + match = re.search(r'epoch=(\d+)', ckpt_file) + if match: + temp_epoch = int(match.group(1)) + if temp_epoch == epoch: + ckpt_path = os.path.join(ckpt_dir, ckpt_file) + break + + device = torch.device('cuda', device_id) + model = CustomModel( + ckpt_path=ckpt_path, + configs=configs, + device=device) + observation_space = { + 'rgb_obs': ['rgb_static', 'rgb_gripper'], + 'depth_obs': [], + 'state_obs': ['robot_obs'], + 'actions': ['rel_actions'], + 'language': ['language']} + env = make_env(args.dataset_path, observation_space, device_id) + # Success rate and result files + flag="opensourcesd" + sub_dir=f"{flag}_{epoch}_epoch" + # set a global variable + global SAVE_DIR + SAVE_DIR=os.path.join(ckpt_dir,sub_dir) + if not os.path.exists(SAVE_DIR): + os.makedirs(SAVE_DIR) + sr_path = os.path.join(SAVE_DIR, f"success_rate.txt") + result_path = os.path.join(SAVE_DIR, f"results.json") + ip2p_model=IP2PEvaluation(ip2p_ckpt_path) + evaluate_policy( + model, + env, + eval_sr_path=sr_path, + eval_result_path=result_path, + ip2p_model=ip2p_model) +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/ACT_DP_multitask/detr/models/mr_mg/evaluate/eval.sh b/ACT_DP_multitask/detr/models/mr_mg/evaluate/eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..2249b0810d6b69defbe96199be44e15f9892ccd4 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/evaluate/eval.sh @@ -0,0 +1,21 @@ +cd /opt/tiger/GR_MG +export EPOCHS=(47) +export CKPT_DIR="PATH_TO_POLICY_DIR" +export SD_CKPT="/PATH_TO_GOAL_GEN_MODEL_CKPT/epoch=49-step=51450.ckpt" +export MESA_GL_VERSION_OVERRIDE=3.3 +echo $EPOCHS +echo $CKPT_DIR +sudo chmod 777 -R ${CKPT_DIR} + +export COUNTER=-1 +# Use a for loop to iterate through a list +for epoch in "${EPOCHS[@]}"; do + export COUNTER=$((${COUNTER} + 1)) + export CUDA_VISIBLE_DEVICES=${COUNTER} + python3 evaluate/eval.py \ + --ckpt_dir ${CKPT_DIR} \ + --epoch ${epoch} \ + --ip2p_ckpt_path ${SD_CKPT} \ + --config_path 
${@:1} & +done +wait \ No newline at end of file diff --git a/ACT_DP_multitask/detr/models/mr_mg/evaluate/install.sh b/ACT_DP_multitask/detr/models/mr_mg/evaluate/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..1d61fe29916418dda3fb66332bd229a278dee79d --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/evaluate/install.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# You should first install the packages needed by policy and goal image generation model +# install calvin +cd /opt/tiger/GR_MG +git clone --recurse-submodules https://github.com/mees/calvin.git +export CALVIN_ROOT=$(pwd)/calvin +cd calvin +cd calvin_env; git checkout main +cd .. + +cd $CALVIN_ROOT +pip3 install setuptools==57.5.0 +sh install.sh +cd /opt/tiger/GR_MG +export EVALUTION_ROOT=$(pwd) + +# Install dependency for calvin +sudo apt-get -y install libegl1-mesa libegl1 +sudo apt-get -y install libgl1 +sudo apt-get -y install libosmesa6-dev +sudo apt install ffmpeg +sudo apt-get -y install patchelf \ No newline at end of file diff --git a/ACT_DP_multitask/detr/models/mr_mg/evaluate/utils/utils.py b/ACT_DP_multitask/detr/models/mr_mg/evaluate/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..36d6e3e5518eadddb39d96ddf611091446521028 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/evaluate/utils/utils.py @@ -0,0 +1,158 @@ +# MIT License + +# Copyright (c) 2021 Oier Mees +# Copyright (c) 2024 Bytedance Ltd. 
and/or its affiliates + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +from collections import Counter +import json +import math + +import numpy as np + +def count_success(results): + count = Counter(results) + step_success = [] + for i in range(1, 6): + n_success = sum(count[j] for j in reversed(range(i, 6))) + sr = n_success / len(results) + step_success.append(sr) + return step_success + +def print_and_save(results, sequences, eval_result_path, epoch=None): + current_data = {} + print(f"Results for Epoch {epoch}:") + avg_seq_len = np.mean(results) + chain_sr = {i + 1: sr for i, sr in enumerate(count_success(results))} + print(f"Average successful sequence length: {avg_seq_len}") + print("Success rates for i instructions in a row:") + for i, sr in chain_sr.items(): + print(f"{i}: {sr * 100:.1f}%") + + cnt_success = Counter() + cnt_fail = Counter() + + for result, (_, sequence) in zip(results, sequences): + for successful_tasks in sequence[:result]: + cnt_success[successful_tasks] += 1 + if result < len(sequence): + failed_task = sequence[result] + cnt_fail[failed_task] += 1 + + total = cnt_success + cnt_fail + task_info = {} + for task in total: + task_info[task] = {"success": cnt_success[task], "total": total[task]} + print(f"{task}: {cnt_success[task]} / {total[task]} | SR: {cnt_success[task] / total[task] * 100:.1f}%") + + data = {"number of seq": len(results),"avg_seq_len": avg_seq_len, "chain_sr": chain_sr, "task_info": task_info} + + current_data[epoch] = data + + print() + previous_data = {} + json_data = {**previous_data, **current_data} + with open(eval_result_path, "w") as file: + json.dump(json_data, file) + print( + f"Best model: epoch {max(json_data, key=lambda x: json_data[x]['avg_seq_len'])} " + f"with average sequences length of {max(map(lambda x: x['avg_seq_len'], json_data.values()))}" + ) + +def alpha2rotm(a): + """Alpha euler angle to rotation matrix.""" + rotm = np.array([ + [1, 0, 0], + [0, np.cos(a), -np.sin(a)], + [0, np.sin(a), np.cos(a)] + ]) + return rotm + +def beta2rotm(b): + """Beta euler angle to 
rotation matrix.""" + rotm = np.array([ + [np.cos(b), 0, np.sin(b)], + [0, 1, 0], + [-np.sin(b), 0, np.cos(b)] + ]) + return rotm + +def gamma2rotm(c): + """Gamma euler angle to rotation matrix.""" + rotm = np.array([ + [np.cos(c), -np.sin(c), 0], + [np.sin(c), np.cos(c), 0], + [0, 0, 1] + ]) + return rotm + +def euler2rotm(euler_angles): + """Euler angle (ZYX) to rotation matrix.""" + alpha = euler_angles[0] + beta = euler_angles[1] + gamma = euler_angles[2] + + rotm_a = alpha2rotm(alpha) + rotm_b = beta2rotm(beta) + rotm_c = gamma2rotm(gamma) + + rotm = rotm_c @ rotm_b @ rotm_a + + return rotm + +def isRotm(R): + # Checks if a matrix is a valid rotation matrix. + # Forked from Andy Zeng + Rt = np.transpose(R) + shouldBeIdentity = np.dot(Rt, R) + I = np.identity(3, dtype=R.dtype) + n = np.linalg.norm(I - shouldBeIdentity) + return n < 1e-6 + +def rotm2euler(R): + # Forked from: https://learnopencv.com/rotation-matrix-to-euler-angles/ + # R = Rz * Ry * Rx + assert(isRotm(R)) + sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0]) + singular = sy < 1e-6 + + if not singular : + x = math.atan2(R[2,1] , R[2,2]) + y = math.atan2(-R[2,0], sy) + z = math.atan2(R[1,0], R[0,0]) + else : + x = math.atan2(-R[1,2], R[1,1]) + y = math.atan2(-R[2,0], sy) + z = 0 + + # (-pi , pi] + while x > np.pi: + x -= (2 * np.pi) + while x <= -np.pi: + x += (2 * np.pi) + while y > np.pi: + y -= (2 * np.pi) + while y <= -np.pi: + y += (2 * np.pi) + while z > np.pi: + z -= (2 * np.pi) + while z <= -np.pi: + z += (2 * np.pi) + return np.array([x, y, z]) diff --git a/ACT_DP_multitask/detr/models/mr_mg/evaluate/wrapper/calvin_env_wrapper_raw.py b/ACT_DP_multitask/detr/models/mr_mg/evaluate/wrapper/calvin_env_wrapper_raw.py new file mode 100644 index 0000000000000000000000000000000000000000..31d7093f874d18a8da774aed9033688fd90e1d3e --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/evaluate/wrapper/calvin_env_wrapper_raw.py @@ -0,0 +1,117 @@ +# MIT License + +# Copyright (c) 2021 Oier Mees +# 
Copyright (c) 2024 Bytedance Ltd. and/or its affiliates + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import logging +import os +from typing import Any, Dict, Tuple, Union +from calvin_agent.datasets.utils.episode_utils import process_depth, process_rgb, process_state +import gym +import numpy as np +import torch +from calvin_env.envs.play_table_env import get_env +from calvin_env.utils.utils import EglDeviceNotFoundError, get_egl_device_id +logger = logging.getLogger(__name__) +class CalvinEnvWrapperRaw(gym.Wrapper): + def __init__(self, abs_datasets_dir, observation_space, device, show_gui=False, **kwargs): + """Environment wrapper which returns raw observations. 
+ + Args: + abs_datasets_dir: absolute dataset directory + observation_space: {'rgb_obs': ['rgb_static', 'rgb_gripper'], 'depth_obs': [], 'state_obs': ['robot_obs'], 'actions': ['rel_actions'], 'language': ['language']} + """ + # self.set_egl_device(device) + env = get_env( + abs_datasets_dir, show_gui=show_gui, obs_space=observation_space, **kwargs + ) + super(CalvinEnvWrapperRaw, self).__init__(env) + self.observation_space_keys = observation_space + self.device = device + self.relative_actions = "rel_actions" in self.observation_space_keys["actions"] + logger.info(f"Initialized PlayTableEnv for device {self.device}") + + @staticmethod + def set_egl_device(device): + if "EGL_VISIBLE_DEVICES" in os.environ: + logger.warning("Environment variable EGL_VISIBLE_DEVICES is already set. Is this intended?") + cuda_id = device.index if device.type == "cuda" else 0 + try: + egl_id = get_egl_device_id(cuda_id) + except EglDeviceNotFoundError: + logger.warning( + "Couldn't find correct EGL device. Setting EGL_VISIBLE_DEVICE=0. " + "When using DDP with many GPUs this can lead to OOM errors. " + "Did you install PyBullet correctly? 
Please refer to calvin env README" + ) + egl_id = 0 + os.environ["EGL_VISIBLE_DEVICES"] = str(egl_id) + logger.info(f"EGL_DEVICE_ID {egl_id} <==> CUDA_DEVICE_ID {cuda_id}") + + def step( + self, action_tensor: torch.Tensor + ) -> Tuple[Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]], int, bool, Dict]: + if self.relative_actions: + action = action_tensor.squeeze().cpu().detach().numpy() + assert len(action) == 7 + else: + if action_tensor.shape[-1] == 7: + slice_ids = [3, 6] + elif action_tensor.shape[-1] == 8: + slice_ids = [3, 7] + else: + logger.error("actions are required to have length 8 (for euler angles) or 9 (for quaternions)") + raise NotImplementedError + action = np.split(action_tensor.squeeze().cpu().detach().numpy(), slice_ids) + action[-1] = 1 if action[-1] > 0 else -1 + o, r, d, i = self.env.step(action) + + # obs = self.transform_observation(o) + obs = o # use raw observation + return obs, r, d, i + + def reset( + self, + reset_info: Dict[str, Any] = None, + batch_idx: int = 0, + seq_idx: int = 0, + scene_obs: Any = None, + robot_obs: Any = None, + ) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]: + if reset_info is not None: + obs = self.env.reset( + robot_obs=reset_info["robot_obs"][batch_idx, seq_idx], + scene_obs=reset_info["scene_obs"][batch_idx, seq_idx], + ) + elif scene_obs is not None or robot_obs is not None: + obs = self.env.reset(scene_obs=scene_obs, robot_obs=robot_obs) + else: + obs = self.env.reset() + + # return self.transform_observation(obs) + return obs # use raw observation + + def get_info(self): + return self.env.get_info() + + def get_obs(self): + obs = self.env.get_obs() + # return self.transform_observation(obs) + return obs # use raw observation diff --git a/ACT_DP_multitask/detr/models/mr_mg/evaluate/wrapper/model_wrapper.py b/ACT_DP_multitask/detr/models/mr_mg/evaluate/wrapper/model_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..6b90fac7ed464ebdc7787444fb6ed3a15eea450c 
# Copyright (2024) Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from omegaconf import OmegaConf
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
import torch.nn.functional as F
import policy.model.vision_transformer as vits
from utils.utils import euler2rotm, rotm2euler
from copy import deepcopy
from calvin_agent.models.calvin_base_model import CalvinBaseModel
from policy.model.model import GR_MG
import time
import clip

GRIPPER_OPEN = 1
GRIPPER_CLOSE = 0


class CustomModel(CalvinBaseModel):
    """CALVIN evaluation wrapper around the GR_MG policy.

    Keeps a rolling window of the last ``seq_len`` observations (static RGB,
    gripper RGB, proprioceptive state), builds the batched model input, and
    converts the policy's relative end-effector prediction back into the
    ``ee_rel_pose_local`` action format that CALVIN expects.
    """

    def __init__(self,
                 ckpt_path,
                 configs,
                 device):
        """Build the GR_MG policy and load weights from ``ckpt_path``.

        Args:
            ckpt_path: pytorch-lightning checkpoint containing 'epoch' and
                'state_dict' entries.
            configs: dict with 'policy' / 'input' / 'trainer' sections.
            device: torch device used for inference.
        """
        self.device = device
        # model config
        self.configs = configs
        self.seq_len = configs["policy"]['seq_len']
        self.act_len = configs["policy"]["act_len"]
        input_size = (224, 224)
        # ImageNet normalization statistics applied after ToTensor ([0, 1]).
        clip_mean = (0.485, 0.456, 0.406)
        clip_std = (0.229, 0.224, 0.225)
        self.preprocess = T.Compose([
            T.Resize(input_size, interpolation=Image.BICUBIC),
            T.Normalize(clip_mean, clip_std)])
        # MAE vision backbone.
        model_mae = vits.__dict__['vit_base'](patch_size=16, num_classes=0)
        model_mae.to(self.device)
        # Assemble the prediction heads this checkpoint was trained with.
        training_target = []
        if configs["trainer"]['act_pred']:
            training_target.append('act_pred')
        if configs["trainer"]['fwd_pred']:
            training_target.append('fwd_pred')
        if configs["trainer"]['fwd_pred_hand']:
            training_target.append('fwd_pred_hand')
        if configs["trainer"]["progress_pred"]:
            training_target.append('progress_pred')
        print(f"training target: {training_target}")

        # Frozen CLIP text encoder (the returned preprocess fn is unused here).
        clip_name = "ViT-B/32"
        clip_model, _ = clip.load(clip_name)
        for _, param in clip_model.named_parameters():
            param.requires_grad = False

        policy_config = self.configs['policy']
        input_config = self.configs["input"]
        trainer_config = self.configs["trainer"]
        self.policy = GR_MG(
            state_dim=input_config['state_dim'],
            act_dim=input_config['act_dim'],
            act_len=policy_config['act_len'],
            act_latent_dim=policy_config['act_latent_dim'],
            act_encoder_dim=policy_config['act_encoder_dim'],
            act_decoder_dim=policy_config['act_decoder_dim'],
            progress_decoder_dim=self.configs["policy"]["progress_decoder_dim"],
            hidden_size=policy_config['embed_dim'],
            model_mae=model_mae,
            clip_model=clip_model,
            img_feat_dim=policy_config["img_feat_dim"],
            lang_feat_dim=policy_config["lang_feat_dim"],
            patch_feat_dim=policy_config["patch_feat_dim"],
            resampler_params=policy_config['resampler_params'],
            max_length=policy_config['seq_len'],
            training_target=training_target,
            without_norm_pix_loss=trainer_config['without_norm_pix_loss'],
            use_hand_rgb=input_config['use_hand_rgb'],
            use_state=input_config['use_state'],
            use_resampler=policy_config['use_resampler'],
            n_layer=policy_config['n_layer'],
            n_head=policy_config['n_head'],
            n_inner=4 * policy_config['embed_dim'],
            activation_function=policy_config['activation_function'],
            n_positions=1024,
            resid_pdrop=policy_config['dropout'],
            attn_pdrop=policy_config['dropout'],
            device=self.device)

        # Set up the model weights from the checkpoint.
        payload = torch.load(ckpt_path)
        epoch = payload['epoch']
        state_dict = payload['state_dict']
        print(f"loading state dict from epoch {epoch}...")
        del payload
        # pytorch-lightning prefixes every key with "model."; strip it.
        # Use startswith (not a substring test) so keys that merely *contain*
        # "model." deeper in the name are never truncated incorrectly.
        pure_state_dict = dict()
        for k, v in state_dict.items():
            if k.startswith("model."):
                pure_state_dict[k[len("model."):]] = v
        msg = self.policy.load_state_dict(pure_state_dict, strict=True)
        print(msg)
        self.policy.to(self.device)
        self.policy.eval()

    def reset(self):
        """Clear the rolling observation buffers before a new rollout."""
        self.rgb_list = []
        self.hand_rgb_list = []
        self.state_list = []
        self.rollout_step_counter = 0

    @staticmethod
    def compute_rel_state(states):
        """Express every buffered state relative to the first buffered step.

        Args:
            states: list of (xyz, rotation-matrix, gripper) tuples.

        Returns:
            arm_states: (len, 6) array of relative xyz + relative euler angles.
            gripper_states: (len,) array of raw gripper values.
        """
        first_xyz = states[0][0]
        first_rotm = states[0][1]
        first_gripper = states[0][2]
        seq_len = len(states)
        arm_states = np.zeros((seq_len, 6))
        gripper_states = np.zeros(seq_len)
        gripper_states[0] = first_gripper
        for i in range(1, seq_len):
            curr_xyz = states[i][0]
            curr_rotm = states[i][1]
            curr_gripper = states[i][2]
            # Pose of step i expressed in the frame of step 0.
            rel_rotm = first_rotm.T @ curr_rotm
            rel_xyz = np.dot(first_rotm.T, curr_xyz - first_xyz)
            arm_states[i, 0:3] = rel_xyz
            arm_states[i, 3:6] = rotm2euler(rel_rotm)
            gripper_states[i] = curr_gripper
        return arm_states, gripper_states

    def step(self, obs, goal, text):
        """Run one policy step.

        Args:
            obs: CALVIN observation dict ('rgb_obs', 'robot_obs').
            goal: sequence whose first element is the generated goal image.
            text: language instruction string.

        Returns:
            (action, progress): a (7,) relative-pose action tensor and an
            integer progress estimate (multiple of 10).
        """
        goal_rgb = goal[0]

        rgb = obs['rgb_obs']['rgb_static']  # (200, 200, 3)
        hand_rgb = obs['rgb_obs']['rgb_gripper']

        goal_rgb = Image.fromarray(goal_rgb)
        goal_rgb = T.ToTensor()(goal_rgb.convert("RGB"))
        goal_rgb = self.preprocess(goal_rgb)  # (3, 224, 224)

        rgb = Image.fromarray(rgb)
        rgb = T.ToTensor()(rgb.convert("RGB"))
        rgb = self.preprocess(rgb)  # (3, 224, 224)
        self.rgb_list.append(rgb)

        hand_rgb = Image.fromarray(hand_rgb)
        hand_rgb = T.ToTensor()(hand_rgb.convert("RGB"))
        hand_rgb = self.preprocess(hand_rgb)
        self.hand_rgb_list.append(hand_rgb)

        state = obs['robot_obs']  # (15,)
        xyz_state = state[:3]
        rpy_state = state[3:6]
        rotm_state = euler2rotm(rpy_state)
        gripper_state = state[-1]
        state = (xyz_state, rotm_state, gripper_state)
        self.state_list.append(state)

        # Keep only the most recent `seq_len` observations.
        buffer_len = len(self.rgb_list)
        if buffer_len > self.seq_len:
            self.rgb_list.pop(0)
            self.hand_rgb_list.pop(0)
            self.state_list.pop(0)
            assert len(self.rgb_list) == self.seq_len
            assert len(self.hand_rgb_list) == self.seq_len
            assert len(self.state_list) == self.seq_len
            buffer_len = len(self.rgb_list)

        # Static RGB, right-padded with zeros up to seq_len.
        c, h, w = rgb.shape
        c2, h2, w2 = goal_rgb.shape
        assert c == c2 and h == h2 and w == w2
        rgb_data = torch.zeros((1, self.seq_len, c, h, w))
        rgb_tensor = torch.stack(self.rgb_list, dim=0)  # (len, c, h, w)
        rgb_data[0, :buffer_len] = rgb_tensor
        goal_rgb_data = torch.zeros((1, c, h, w))
        goal_rgb_data[0] = goal_rgb

        # Hand RGB
        c, h, w = hand_rgb.shape
        hand_rgb_data = torch.zeros((1, self.seq_len, c, h, w))
        hand_rgb_tensor = torch.stack(self.hand_rgb_list, dim=0)  # (len, c, h, w)
        hand_rgb_data[0, :buffer_len] = hand_rgb_tensor

        # State, relative to the first buffered step.
        arm_state, gripper_state = CustomModel.compute_rel_state(self.state_list)
        arm_state_data = torch.zeros((1, self.seq_len, 6))
        arm_state_tensor = torch.from_numpy(arm_state)
        arm_state_data[0, :buffer_len] = arm_state_tensor
        gripper_state_tensor = torch.from_numpy(gripper_state)
        # Map gripper from {-1, 1} to class indices {0, 1}, then one-hot.
        gripper_state_tensor = (gripper_state_tensor + 1.0) / 2.0
        gripper_state_tensor = gripper_state_tensor.long()
        gripper_state_data = torch.zeros((1, self.seq_len)).long()
        gripper_state_data[0, :buffer_len] = gripper_state_tensor
        gripper_state_data = F.one_hot(gripper_state_data, num_classes=2).type_as(arm_state_data)

        # Attention mask over valid (non-padded) steps.
        attention_mask = torch.zeros(1, self.seq_len).long()
        attention_mask[0, :buffer_len] = 1

        # Action placeholder (filled in by the policy).
        arm_action_data = torch.zeros((1, self.seq_len, self.configs["policy"]['act_len'], 6))
        gripper_action_data = torch.zeros(1, self.seq_len, self.configs["policy"]['act_len'])

        # Progress placeholder
        progress_data = torch.zeros(1, self.seq_len)

        input_dict = dict()
        input_dict['rgb'] = rgb_data.to(self.device)
        input_dict['hand_rgb'] = hand_rgb_data.to(self.device)
        input_dict["goal_rgb"] = goal_rgb_data.to(self.device)
        input_dict['arm_state'] = arm_state_data.to(self.device)
        input_dict['gripper_state'] = gripper_state_data.to(self.device)
        input_dict['arm_action'] = arm_action_data.to(self.device)
        input_dict['gripper_action'] = gripper_action_data.to(self.device)
        input_dict['attention_mask'] = attention_mask.to(self.device)
        input_dict["text"] = [text]
        input_dict["progress"] = progress_data

        # Forward pass
        with torch.no_grad():
            action, progress = self.policy.evaluate(input_dict, return_progress=True)
            # Quantize progress to a multiple of 10 in [0, 100].
            progress = int(int(progress * 10) * 10)
            action = action.numpy()

        # Action mode: ee_rel_pose_local — integrate the policy's relative
        # action (expressed in the local end-effector frame) and re-express
        # it in CALVIN's scaled relative action format.
        state = obs['robot_obs']  # (15,)
        xyz_state = state[:3]
        rpy_state = state[3:6]
        rotm_state = euler2rotm(rpy_state)
        rel_action = action
        xyz_action = rel_action[:3] / 50  # scale down by 50
        rpy_action = rel_action[3:6] / 20  # scale down by 20
        gripper_action = rel_action[6]
        rotm_action = euler2rotm(rpy_action)
        xyz_next_state = xyz_state + rotm_state @ xyz_action
        rotm_next_state = rotm_state @ rotm_action
        rpy_next_state = rotm2euler(rotm_next_state)
        action = np.zeros(7)
        action[:3] = (xyz_next_state - xyz_state) * 50
        action[3:6] = (rpy_next_state - rpy_state) * 20
        action[-1] = gripper_action
        action = torch.from_numpy(action)
        self.rollout_step_counter += 1

        return action, progress
"PATH_TO/resources/IP2P/instruct-pix2pix", + "ckpt_root": "SAVE_PATH/goal_gen/checkpoints/", + "log_root": "LOG_PATH/goal_gen/logs/", + "resume": null, + "color_aug": false, + + + "conditioning_dropout_prob": 0.05, + "use_ema": true, + "gradient_checkpointing":false, + + "adam_beta1": 0.95, + "adam_beta2": 0.999, + "adam_weight_decay": 1e-2, + "adam_epsilon": 1e-08, + + "trainer": { + "accelerator": "gpu", + "strategy": "ddp", + "precision": "bf16", + "logger": ["tensorboard"], + "use_distributed_sampler": true, + "gradient_clip_val": 0.7, + "log_every_n_steps": 50, + "max_epochs": 50 + } +} diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/evaluate.py b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..880c551169a2856c6218f5e09efd5949c8ade489 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/evaluate.py @@ -0,0 +1,142 @@ +# Copyright (2024) Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os
import yaml
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torchvision.transforms as transforms
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from goal_gen.utils.pipeline import Pipeline
from goal_gen.data.calvindataset import CalvinDataset_Goalgen


class IP2PEvaluation(object):
    """Qualitative evaluation harness for the goal-image generation model.

    Loads a fine-tuned InstructPix2Pix UNet (EMA weights) into a diffusion
    pipeline and renders side-by-side comparison figures:
    source frame | ground-truth goal frame | model prediction.
    """

    def __init__(self,
                 ckpt_path,
                 res=256):
        """Build the pipeline.

        Args:
            ckpt_path: lightning checkpoint whose state_dict holds 'unet_ema'.
            res: square resolution the input images are resized to.
        """
        # Init models
        pretrained_model_dir = "/mnt/bn/lpy-lq/stable_diffusion/instruct-pix2pix"

        self.tokenizer = T5Tokenizer.from_pretrained("t5-base")
        self.text_encoder = T5EncoderModel.from_pretrained("t5-base")
        self.vae = AutoencoderKL.from_pretrained(
            pretrained_model_dir, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(
            pretrained_model_dir, subfolder="unet")

        # Load the fine-tuned EMA weights for the UNet only.
        payload = torch.load(ckpt_path)
        state_dict = payload['state_dict']
        del payload
        msg = self.unet.load_state_dict(state_dict['unet_ema'], strict=True)
        print(msg)

        self.pipe = Pipeline.from_pretrained(
            pretrained_model_dir,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            vae=self.vae,
            unet=self.unet,
            revision=None,
            variant=None,
            torch_dtype=torch.bfloat16
        ).to("cuda")

        self.pipe.safety_checker = None
        self.pipe.requires_safety_checker = False
        # Fixed seed so repeated evaluations are reproducible.
        self.generator = torch.Generator("cuda").manual_seed(42)

        # Diffusion hyperparameters
        self.num_inference_steps = 50
        self.image_guidance_scale = 2.5
        self.guidance_scale = 7.5

        # Image transform
        self.res = res
        self.transform = transforms.Resize((res, res))

    def evaluate(self, eval_result_dir, eval_data_dir, is_training):
        """Render comparison figures for every 100th sample of the dataset.

        NOTE: each sample overwrites the same debug.png in eval_result_dir,
        so only the last processed sample remains on disk.
        """
        os.makedirs(eval_result_dir, exist_ok=True)
        save_dir = os.path.join(eval_result_dir, "debug.png")
        dataset = CalvinDataset_Goalgen(
            eval_data_dir,
            resolution=256,
            resolution_before_crop=288,
            center_crop=True,
            forward_n_min_max=(20, 22),
            is_training=is_training,
            use_full=True,
            color_aug=False
        )
        for i in range(0, len(dataset), 100):
            example = dataset[i]
            text = example['input_text']
            original_pixel_values = example['original_pixel_values']
            edited_pixel_values = example['edited_pixel_values']
            progress = example["progress"]

            # Append the progress hint exactly as it is formatted at training time.
            progress = progress * 10
            text[0] = text[0] + f".And {progress}% of the instruction has been finished."
            print(text[0])
            input_image_batch = [original_pixel_values]
            predict_image = self.inference(input_image_batch, text)

            # Draw source | ground truth | prediction once. (The previous
            # version redrew the identical panels three times in a loop.)
            fig, ax = plt.subplots(1, 3)

            original_image = original_pixel_values.permute(1, 2, 0).numpy()
            original_image = (original_image + 1) / 2 * 255
            original_image = np.clip(original_image, 0, 255)
            original_image = original_image.astype(np.uint8)
            ax[0].imshow(original_image)

            edited_image = edited_pixel_values.permute(1, 2, 0).numpy()
            edited_image = (edited_image + 1) / 2 * 255
            edited_image = np.clip(edited_image, 0, 255)
            edited_image = edited_image.astype(np.uint8)
            ax[1].imshow(edited_image)

            ax[2].imshow(predict_image[0])

            plt.savefig(save_dir, dpi=300)
            plt.close()

    def inference(self, image_batch, text_batch):
        """Run the IP2P pipeline on a batch of images and prompts.

        Args:
            image_batch: list of PIL images or HWC numpy arrays.
            text_batch: list of prompt strings.

        Returns:
            list of predicted goal images as numpy arrays.
        """
        input_images = []
        for image in image_batch:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            input_image = self.transform(image)
            input_images.append(input_image)
        edited_images = self.pipe(
            prompt=text_batch,
            image=input_images,
            num_inference_steps=self.num_inference_steps,
            image_guidance_scale=self.image_guidance_scale,
            guidance_scale=self.guidance_scale,
            generator=self.generator,
            safety_checker=None,
            requires_safety_checker=False).images
        edited_images = [np.array(image) for image in edited_images]

        return edited_images


if __name__ == "__main__":
    ckpt_path = "PATH_TO_IP2P_CKPT/epoch=49-step=102900.ckpt"
    # Renamed from `eval` to avoid shadowing the builtin.
    evaluator = IP2PEvaluation(ckpt_path)
    eval_data_dir = "PATH_TO_CALVIN/calvin/task_ABC_D/"
    eval_result_dir = "SAVE_DIR"
    evaluator.evaluate(eval_result_dir, eval_data_dir, is_training=False)
#!/usr/bin/env python
# coding=utf-8
import torch
import torch.nn as nn
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from diffusers.training_utils import EMAModel
from transformers import T5Tokenizer, T5EncoderModel
import os
# modified from https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
class IP2P(nn.Module):
    """InstructPix2Pix model.

    Wraps a pretrained Stable-Diffusion VAE + UNet with a T5 text encoder
    (replacing IP2P's usual CLIP text encoder) for goal-image generation
    training. Only the UNet is trainable; the VAE and text encoder are frozen.
    """
    def __init__(self,
                 pretrained_model_dir,
                 device,
                 seed=123,
                 conditioning_dropout_prob=None,
                 gradient_checkpointing=False):
        # pretrained_model_dir: local instruct-pix2pix checkpoint directory
        #   providing the 'scheduler', 'vae' and 'unet' subfolders.
        # conditioning_dropout_prob: probability used for classifier-free
        #   guidance dropout during training; None disables dropout entirely.
        super().__init__()
        self.device = device
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            pretrained_model_dir, subfolder="scheduler")
        text_encoder_name = "t5-base"
        self.tokenizer = T5Tokenizer.from_pretrained(text_encoder_name)
        self.text_encoder = T5EncoderModel.from_pretrained(text_encoder_name)
        self.vae = AutoencoderKL.from_pretrained(
            pretrained_model_dir, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(
            pretrained_model_dir, subfolder="unet")

        # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
        # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
        # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
        # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
        # initialized to zero.
        self.in_channels = 8
        self.unet.register_to_config(in_channels=self.in_channels)
        # Freeze vae and text_encoder; only the UNet receives gradients.
        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)


        # Conditioning dropout probability used for classifier free guidance
        self.conditioning_dropout_prob = conditioning_dropout_prob

        if gradient_checkpointing:
            self.unet.enable_gradient_checkpointing() # it will reduce GPU memory but add computing burden

        # Seeded generator so the conditioning-dropout masks are reproducible.
        self.generator = torch.Generator(device=self.device).manual_seed(seed)

    def tokenize_texts(self, texts):
        # Pad/truncate every prompt to exactly 77 tokens (CLIP-style length).
        inputs = self.tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=77)
        return inputs

    def forward(self, input_dict):
        """One training forward pass.

        Expects input_dict with:
          'original_pixel_values' — conditioning (source) images,
          'edited_pixel_values'   — target (goal) images,
          'input_text'            — list whose first element is the prompt batch,
          'progress'              — per-sample progress, scaled by 10 below.

        Returns (prediction, target): the UNet's noise prediction and the
        sampled noise it is trained to match (epsilon prediction).
        """
        original_pixel_values = input_dict['original_pixel_values']
        edited_pixel_values = input_dict['edited_pixel_values']
        input_text = input_dict['input_text'][0]
        progress=input_dict["progress"]*10  # (b,) progress values in percent
        # Append the progress hint to every prompt, matching evaluation-time formatting.
        input_text=[ text+f".And {curr_progress}% of the instruction has been finished." for (text,curr_progress) in zip(input_text,progress)]
        input_ids=self.tokenize_texts(input_text)

        # We want to learn the denoising process w.r.t the edited images which
        # are conditioned on the original image (which was edited) and the edit instruction.
        # So, first, convert images to latent space.
        latents = self.vae.encode(edited_pixel_values).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (bsz,)).to(latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # Get the text embedding for conditioning
        encoder_hidden_states = self.text_encoder(**(input_ids.to(self.device))).last_hidden_state

        # Get the additional image embedding for conditioning.
        # Instead of getting a diagonal Gaussian here, we simply take the mode.
        original_image_embeds = self.vae.encode(original_pixel_values).latent_dist.mode()

        # Conditioning dropout to support classifier-free guidance during inference.
        if self.conditioning_dropout_prob is not None:
            # NOTE(review): self.generator was created on self.device; torch.rand
            # requires the generator's device to match `device=` — confirm
            # latents.device always equals self.device in training.
            random_p = torch.rand(bsz, device=latents.device, generator=self.generator)
            # Sample masks for the edit prompts: drop the text condition with
            # probability conditioning_dropout_prob (random_p < 2p minus the
            # image-drop band yields the standard IP2P three-way split).
            prompt_mask = random_p < 2 * self.conditioning_dropout_prob
            prompt_mask = prompt_mask.reshape(bsz, 1, 1)
            # Final text conditioning.
            null_conditioning = self.text_encoder(**(self.tokenize_texts([""]).to(self.device))).last_hidden_state
            encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)

            # Sample masks for the original images: zero the image condition
            # for samples with p <= random_p < 3p.
            image_mask_dtype = original_image_embeds.dtype
            image_mask = 1 - (
                (random_p >= self.conditioning_dropout_prob).to(image_mask_dtype)
                * (random_p < 3 * self.conditioning_dropout_prob).to(image_mask_dtype)
            )
            image_mask = image_mask.reshape(bsz, 1, 1, 1)
            # Final image conditioning.
            original_image_embeds = image_mask * original_image_embeds

        # Concatenate the `original_image_embeds` with the `noisy_latents`.
        concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)

        # Get the target for loss depending on the prediction type
        # (epsilon prediction: the target is the sampled noise itself).
        target = noise

        # Predict the noise residual and compute loss
        prediction = self.unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

        return prediction, target
import os
import argparse
import json
from pathlib import Path
import importlib
import copy
import numpy as np
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.strategies import DDPStrategy
from lightning import seed_everything
import torch
import random
from utils.utils import SetupCallback
import datetime
from data.calvindataset import CalvinDataset_Goalgen
from goal_gen.training.trainer import Goalgen_Trainer
from torch.utils.data import DataLoader


def get_date_str():
    """Return today's date as a string, used to namespace log/ckpt dirs."""
    return str(datetime.date.today())


def init_setup_callback(config):
    """Build the SetupCallback that records run metadata into log/ckpt dirs."""
    return SetupCallback(
        now=str(datetime.datetime.now()).replace(' ', '_'),
        logdir=config['log_dir'],
        ckptdir=config['ckpt_dir'],
    )

def init_trainer_config(configs):
    """Derive the lightning Trainer kwargs from the merged config dict.

    Side effects: creates the log and checkpoint directories and writes their
    paths back into `configs` ('log_dir', 'ckpt_dir').
    """
    trainer_config = copy.deepcopy(configs['trainer'])
    trainer_config['devices'] = configs.get('gpus', 'auto')
    trainer_config['num_nodes'] = configs.get('num_nodes', 1)
    if 'strategy' not in trainer_config or trainer_config['strategy'] == 'ddp':
        trainer_config['strategy'] = DDPStrategy(find_unused_parameters=False)
    exp_name = configs['exp_name']

    # init loggers — directory layout is <log_root>/<date>/<exp_name>
    loggers = None
    log_dir = os.path.join(get_date_str(), exp_name)
    log_dir = os.path.join(configs['log_root'], log_dir)
    configs['log_dir'] = log_dir
    Path(configs['log_dir']).mkdir(parents=True, exist_ok=True)
    loggers = []
    loggers.append(TensorBoardLogger(log_dir, name=exp_name)) # you can also add other loggers
    trainer_config['logger'] = loggers
    ckpt_dir = os.path.join(get_date_str(), exp_name)
    ckpt_dir = os.path.join(configs['ckpt_root'], ckpt_dir)
    configs['ckpt_dir'] = ckpt_dir
    Path(configs['ckpt_dir']).mkdir(parents=True, exist_ok=True)
    # save_top_k=-1 keeps every epoch's checkpoint.
    trainer_config['callbacks'] = [
        init_setup_callback(configs),
        LearningRateMonitor(logging_interval='step'),
        ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1,every_n_epochs=1)
    ]
    return trainer_config


def set_seed(seed=0):
    """Seed every RNG source (python, numpy, torch, lightning)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    seed_everything(seed)


def experiment(variant):
    """Entry point: build trainer, model and dataloaders, then fit."""
    set_seed(variant['seed'])
    trainer_config = init_trainer_config(variant)

    trainer = Trainer(**trainer_config)
    variant['gpus'] = trainer.num_devices

    model = Goalgen_Trainer(variant)

    # dataset
    train_data= CalvinDataset_Goalgen(
        data_dir="/PATH_TO_CALVIN/calvin/task_ABC_D/",
        resolution=256,
        resolution_before_crop=288,
        center_crop=False,
        forward_n_min_max=[20, 22],
        use_full=True,
        is_training=True,
        color_aug=True)
    val_data= CalvinDataset_Goalgen(
        data_dir="/PATH_TO_CALVIN/calvin/task_ABC_D/",
        resolution=256,
        resolution_before_crop=288,
        center_crop=False,
        forward_n_min_max=[20, 22],
        use_full = True,
        is_training=False,
        color_aug=False)
    # NOTE(review): no shuffle=True on the training DataLoader — confirm
    # whether deterministic sample order is intended here.
    train_dataloader= DataLoader(train_data,
                                batch_size=variant["batch_size"],
                                num_workers=variant["num_workers"])
    val_dataloader= DataLoader(val_data,
                                batch_size=variant["batch_size"],
                                num_workers=variant["num_workers"])

    _kwargs = {
        'model': model,
        'train_dataloaders':train_dataloader,
        'val_dataloaders':val_dataloader,
        'ckpt_path': variant['resume']
    }
    if _kwargs['ckpt_path'] is not None:
        print(f"Resuming from {variant['resume']}...")
    trainer.fit(**_kwargs)

def deep_update(d1, d2):
    # use d2 to update d1 (recursive merge; d2 values win on conflicts)
    for k, v in d2.items():
        if isinstance(v, dict) and k in d1:
            assert isinstance(d1[k], dict)
            deep_update(d1[k], d2[k])
        else:
            d1[k] = d2[k]
    return d1

def load_config(config_file):
    """Load a JSON config, recursively merging in its optional 'parent'."""
    _config = json.load(open(config_file))
    config = {}
    if _config.get('parent', None):
        deep_update(config, load_config(_config['parent']))
    deep_update(config, _config)
    return config

def update_configs(configs, args):
    """Overlay non-None CLI arguments onto the loaded config dict."""
    for (k, v) in args.items():
        if k not in configs:
            print(f"{k} not in config. The value is {v}.")
            configs[k] = v
        if isinstance(v, dict):
            # Grouped args (e.g. 'trainer') update sub-keys individually.
            for (sub_k, sub_v) in v.items():
                assert sub_k in configs[k], f"{sub_k} - {configs[k]}"
                if sub_v != None:
                    configs[k][sub_k] = sub_v
        else:
            if v != None:
                configs[k] = v
    return configs

def parse_args():
    """Parse CLI args, separating top-level options from 'trainer' options.

    Uses two parse_known_args passes to compute which destination names
    belong to the trainer group; returns a dict with a nested 'trainer' dict.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str,default="")
    parser.add_argument('--gpus', default=1, type=int)
    parser.add_argument('--num_nodes', default=1, type=int)
    parser.add_argument('--seed', default=None, type=int)
    parser.add_argument('--log_root', default=None, type=str)
    parser.add_argument('--ckpt_root', default=None, type=str)
    parser.add_argument('--exp_name', default=None, type=str)
    parser.add_argument('--resume', default=None, type=str)

    # Training
    parser.add_argument('--batch_size', default=None, type=int)
    parser.add_argument('--learning_rate', default=None, type=float)
    parser.add_argument('--min_lr_scale', default=None, type=float)
    parser.add_argument('--warmup_steps', default=None, type=int)
    parser.add_argument('--adam_weight_decay', default=None, type=float)
    parser.add_argument('--adam_beta1', default=None, type=float)
    parser.add_argument('--adam_beta2', default=None, type=float)
    parser.add_argument('--adam_epsilon', default=None, type=float)

    # Diffusion
    parser.add_argument("--conditioning_dropout_prob", default=None, type=float)
    global_names = set(vars(parser.parse_known_args()[0]).keys())

    # Trainer
    trainer_parser = parser.add_argument_group('trainer')
    trainer_parser.add_argument('--strategy', default=None, type=str)
    trainer_parser.add_argument('--precision', default=None, type=str)
    trainer_parser.add_argument('--gradient_clip_val', default=None, type=float)
    trainer_parser.add_argument('--max_epochs', default=None, type=int)
    trainer_names = set(vars(parser.parse_known_args()[0]).keys()) - global_names

    args = {}
    trainer_args = {}
    temp_args = vars(parser.parse_args())
    for (k, v) in temp_args.items():
        if k in global_names:
            args[k] = v
        elif k in trainer_names:
            trainer_args[k] = v

    args['trainer'] = trainer_args

    return args

if __name__ == '__main__':
    args=parse_args()
    configs = load_config(args.pop('config'))
    configs = update_configs(configs, args)
    # NOTE(review): shelling out to `sudo chmod 777 -R` is a security smell
    # (world-writable trees, unquoted paths through os.system) — confirm this
    # is required by the cluster setup before keeping it.
    os.system(f"sudo chmod 777 -R {configs['ckpt_root']}")
    os.system(f"sudo chmod 777 -R {configs['log_root']}")
    experiment(variant=configs)
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import lightning.pytorch as pl
from utils.ema import requires_grad, update_ema
from utils.dist_train import get_rank
from model.model import IP2P


class Goalgen_Trainer(pl.LightningModule):
    """Lightning module that trains the IP2P goal-image generator.

    Only the UNet inside IP2P is optimized. When 'use_ema' is set, a second
    frozen IP2P instance tracks an exponential moving average of the UNet
    weights; validation and checkpointing then use the EMA copy.
    """

    def __init__(self, configs):
        super().__init__()
        self._main_rank_print('--------------- model configs ---------------')
        self._main_rank_print(configs)
        self.configs = configs
        self._initialize()
        self.save_hyperparameters()
        self.dataset_name = "calvin"

    @staticmethod
    def _main_rank_print(*args, **kwargs):
        """Print only from global rank 0 to avoid duplicated DDP output."""
        if get_rank() == 0:
            print(*args, **kwargs)

    @property
    def num_gpus(self):
        """Total number of devices across all nodes."""
        return self.trainer.num_devices * self.trainer.num_nodes

    def _initialize(self):
        """Build the trainable model and (optionally) its frozen EMA twin."""
        self.model = IP2P(
            pretrained_model_dir=self.configs['pretrained_model_dir'],  # load the pretrained instructpix2pix weight
            device=self.configs['device'],
            seed=self.configs['seed'],
            conditioning_dropout_prob=self.configs['conditioning_dropout_prob'],
            gradient_checkpointing=self.configs['gradient_checkpointing']
        )

        self.use_ema = self.configs["use_ema"]
        if self.use_ema:
            self.ema_model = IP2P(
                pretrained_model_dir=self.configs['pretrained_model_dir'],
                device=self.configs['device'],
                seed=self.configs['seed'],
                conditioning_dropout_prob=self.configs['conditioning_dropout_prob'],
                gradient_checkpointing=self.configs['gradient_checkpointing']
            )
            requires_grad(self.ema_model, False)  # ema model will not be trained. It will only be updated.
            self.ema_model.eval()


    @classmethod
    def from_checkpoint(cls, ckpt_dir=None, configs=None):
        """Build a trainer module from configs when no checkpoint is given."""
        if ckpt_dir is None:
            assert configs is not None, "ckpt_dir and configs are both None for initialization."
            return cls(configs)

    def configure_optimizers(self):
        """AdamW over the UNet parameters only."""
        lr = self.configs['learning_rate']
        eff_bsz = self.configs['batch_size'] * self.num_gpus
        self._main_rank_print('-' * 40)
        self._main_rank_print(f"learning rate: {lr}, effective batch size: {eff_bsz}")

        optimizer_params = [
            {'params': self.model.unet.parameters(), 'lr': lr},
        ]  # only unet will be trained

        optimizer = torch.optim.AdamW(
            optimizer_params,
            betas=(self.configs['adam_beta1'], self.configs['adam_beta2']),
            weight_decay=self.configs['adam_weight_decay'],
            eps=self.configs['adam_epsilon']
        )

        return {
            'optimizer': optimizer
        }

    def _log_output(self, output, phase, dataset=None, **kwargs):
        """Log every metric in `output` as '<dataset>_<phase>_<name>'."""
        for k, v in output.items():
            log_name = f"{phase}_{k}"
            if dataset is not None:
                log_name = f"{dataset}_{log_name}"
            self.log(log_name, v, prog_bar=True, **kwargs)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        # Validate with the EMA weights when available — they are what gets
        # deployed at inference time.
        with torch.no_grad():
            if self.use_ema:
                prediction, target = self.ema_model.forward(batch)
            else:
                prediction, target = self.model.forward(batch)
            loss = F.mse_loss(prediction.float(), target.float(), reduction="mean")
            output = {'loss': loss}

        self._log_output(output, phase="val", sync_dist=True,
                         on_epoch=True, on_step=False, dataset=self.dataset_name)

    def training_step(self, batch, batch_idx):
        prediction, target = self.model.forward(batch)
        loss = F.mse_loss(prediction.float(), target.float(), reduction="mean")
        output = {'loss': loss}
        self._log_output(output, phase="train", on_epoch=False, on_step=True, dataset=self.dataset_name)
        # Use the cached flag for consistency with validation_step /
        # on_save_checkpoint (same value as self.configs['use_ema']).
        if self.use_ema:
            update_ema(self.ema_model, self.model, decay=0.999)
        return output['loss']


    def on_save_checkpoint(self, checkpoint):
        """Shrink the checkpoint to just the UNet (and EMA UNet) weights."""
        if not self.use_ema:
            checkpoint['state_dict'] = {'unet': self.model.unet.state_dict()}
        else:
            checkpoint['state_dict'] = {'unet_ema': self.ema_model.unet.state_dict(), 'unet': self.model.unet.state_dict()}

    def on_load_checkpoint(self, checkpoint):
        # Only load the parameters of the UNet module.
        if not self.use_ema:
            self.model.unet.load_state_dict(checkpoint['state_dict']['unet'])
        else:
            # Resume both copies from the EMA weights.
            self.model.unet.load_state_dict(checkpoint['state_dict']['unet_ema'])
            self.ema_model.unet.load_state_dict(checkpoint['state_dict']['unet_ema'])
+import io +import os + +import torch +import torch.distributed as dist + + +_print = print + +def get_world_size(): return int(os.getenv('WORLD_SIZE', 1)) +def get_rank(): return int(os.getenv('RANK', 0)) +def get_local_rank(): return int(os.getenv('LOCAL_RANK', 0)) + + +def is_dist(): + return dist.is_available() and dist.is_initialized() and get_world_size() > 1 + +def print(*argc, all=False, **kwargs): + if not is_dist(): + _print(*argc, **kwargs) + return + + if not all and get_local_rank() != 0: + return + + output = io.StringIO() + kwargs['end'] = '' + kwargs['file'] = output + kwargs['flush'] = True + _print(*argc, **kwargs) + + s = output.getvalue() + output.close() + + s = '[rank {}] {}'.format(dist.get_rank(), s) + _print(s) + +def reduce_mean(tensor, nprocs=None): + if not is_dist(): + return tensor + if not isinstance(tensor, torch.Tensor): + device = torch.cuda.current_device() + rt = torch.tensor(tensor, device=device) + else: + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + nprocs = nprocs if nprocs else dist.get_world_size() + rt = rt / nprocs + if not isinstance(tensor, torch.Tensor): + rt = rt.item() + return rt + +def reduce_sum(tensor): + if not is_dist(): + return tensor + if not isinstance(tensor, torch.Tensor): + device = torch.cuda.current_device() + rt = torch.tensor(tensor, device=device) + else: + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + if not isinstance(tensor, torch.Tensor): + rt = rt.item() + return rt diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/ema.py b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..01c95b1c95c3a7d2d569aedd11fefd7d03789a13 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/ema.py @@ -0,0 +1,35 @@ +# Copyright (2024) Bytedance Ltd. 
and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import inf +from collections import OrderedDict +@torch.no_grad() +def update_ema(ema_model, model, decay=0.999): + """ + Step the EMA model towards the current model. + """ + ema_params = OrderedDict(ema_model.unet.named_parameters()) + model_params = OrderedDict(model.unet.named_parameters()) + + for name, param in model_params.items(): + # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed + ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) + +def requires_grad(model, flag=True): + """ + Set requires_grad flag for all parameters in a model. + """ + for p in model.parameters(): + p.requires_grad = flag + diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/format_calvin_data.py b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/format_calvin_data.py new file mode 100644 index 0000000000000000000000000000000000000000..cb34d9dd6c9d48990f80c9c8f66ec36cda62192c --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/format_calvin_data.py @@ -0,0 +1,79 @@ +# Copyright (2024) Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Format calvin data for training ip2p +import os +import argparse +import json +from tqdm import tqdm + +import numpy as np +import cv2 + + +SPLITS = ['training'] +# SPLITS = ['validation'] +def main(data_dir, target_dir, num_trajs_per_task=10000): + for split in SPLITS: + meta = dict() + split_dir = os.path.join(target_dir, split) + os.mkdir(split_dir) + dataset_dir = os.path.join(data_dir, split) + anns = np.load( + os.path.join(dataset_dir, "lang_annotations", "auto_lang_ann.npy"), + allow_pickle=True).item() + n_trajs = len(anns['info']['indx']) + task_dict = {} + for traj_idx in tqdm(range(n_trajs)): + if split == 'training': + # sample trajectories based on num_trajs_per_task + traj_task = anns['language']['task'][traj_idx] + if traj_task not in task_dict: + task_dict[traj_task] = 1 + else: + task_dict[traj_task] = task_dict[traj_task] + 1 + if task_dict[traj_task] > num_trajs_per_task: + continue + + traj_dir = os.path.join(split_dir, f"{traj_idx}") + os.mkdir(traj_dir) + traj_st, traj_ed = anns['info']['indx'][traj_idx] + traj_text = anns['language']['ann'][traj_idx] + for i in range(traj_st, traj_ed + 1): + frame = np.load(os.path.join(dataset_dir, f"episode_{i:07d}.npz")) + static_rgb = frame['rgb_static'] + hand_rgb = frame['rgb_gripper'] + cv2.imwrite(os.path.join(traj_dir, f"{i - traj_st}_static.png"), cv2.cvtColor(static_rgb, cv2.COLOR_BGR2RGB)) + cv2.imwrite(os.path.join(traj_dir, f"{i - traj_st}_hand.png"), cv2.cvtColor(hand_rgb, cv2.COLOR_BGR2RGB)) + meta[traj_idx] = {"text": traj_text, "num_frames": int(traj_ed - traj_st + 1)} + with 
open(os.path.join(split_dir, "meta.json"), "w") as f: + json.dump(meta, f) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data_dir", + type=str, + default="", + help="data directory") + parser.add_argument("--target_dir", + type=str, + default="", + help="target data directory") + parser.add_argument("--num_trajs_per_task", + type=int, + default=10000, # when you want to do few-shot experiments, change this number + help="number of trajectories per task") + args = parser.parse_args() + + main(args.data_dir, args.target_dir, args.num_trajs_per_task) \ No newline at end of file diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/pipeline.py b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..d098b8c5f13f75f8d746191f09d67067ca471c0c --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/pipeline.py @@ -0,0 +1,445 @@ +# Copyright 2024 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +from diffusers import StableDiffusionInstructPix2PixPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from typing import Callable, Dict, List, Optional, Union +import numpy as np +import torch + + +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +class Pipeline(StableDiffusionInstructPix2PixPipeline): + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "image_latents"] + def __init__( + self, + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker, + feature_extractor, + image_encoder=None, + requires_safety_checker=True, + ): + super().__init__( + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker, + feature_extractor, + ) + + @torch.no_grad() + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + # if isinstance(self, TextualInversionLoaderMixin): + # prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_tensors="pt", + ) + + + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should 
be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + # if isinstance(self, TextualInversionLoaderMixin): + # uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + attention_mask = uncond_input.attention_mask.to(device) + + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) + + return prompt_embeds + + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image =None, + num_inference_steps: int = 100, + guidance_scale: float = 7.5, + image_guidance_scale: float = 1.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image= None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 0. 
Check inputs + self.check_inputs( + prompt, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._image_guidance_scale = image_guidance_scale + + device = self._execution_device + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([image_embeds, negative_image_embeds, negative_image_embeds]) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 1. Define call parameters + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # check if scheduler is in sigmas space + scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") + + # 2. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + + + length=prompt_embeds.shape[0] + + # 3. Preprocess image + image = self.image_processor.preprocess(image) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare Image latents + image_latents = self.prepare_image_latents( + image, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + self.do_classifier_free_guidance, + ) + + height, width = image_latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + # 6. 
Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Check that shapes of latents and image match the UNet channels + num_channels_image = image_latents.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents+num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance. + # The latents are expanded 3 times because for pix2pix the guidance\ + # is applied for both the text and the input image. 
+ latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents + + # concat latents, image_latents in the channel dimension + scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) + + # predict the noise residual + noise_pred = self.unet( + scaled_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # Hack: + # For karras style schedulers the model does classifer free guidance using the + # predicted_original_sample instead of the noise_pred. So we need to compute the + # predicted_original_sample here if we are using a karras style scheduler. + if scheduler_is_in_sigma_space: + step_index = (self.scheduler.timesteps == t).nonzero()[0].item() + sigma = self.scheduler.sigmas[step_index] + noise_pred = latent_model_input - sigma * noise_pred + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + self.guidance_scale * (noise_pred_text - noise_pred_image) + + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond) + ) + + # Hack: + # For karras style schedulers the model does classifer free guidance using the + # predicted_original_sample instead of the noise_pred. But the scheduler.step function + # expects the noise_pred and computes the predicted_original_sample internally. So we + # need to overwrite the noise_pred here such that the value of the computed + # predicted_original_sample is correct. 
+ if scheduler_is_in_sigma_space: + noise_pred = (noise_pred - latents) / (-sigma) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + image_latents = callback_outputs.pop("image_latents", image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + diff --git a/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/utils.py b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..1989ba9e3859bf9e429fba7d131385f76f28172c --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/goal_gen/utils/utils.py @@ -0,0 +1,32 @@ +# Copyright (2024) Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from lightning.pytorch.callbacks import Callback +import json + + +class SetupCallback(Callback): + def __init__(self, now, logdir, ckptdir): + super().__init__() + self.now = now + self.logdir = logdir + self.ckptdir = ckptdir + + def on_train_start(self, trainer, model): + if trainer.global_rank == 0: + # Create logdirs and save configs + os.makedirs(self.logdir, exist_ok=True) + os.makedirs(self.ckptdir, exist_ok=True) + diff --git a/ACT_DP_multitask/detr/models/mr_mg/media/model.gif b/ACT_DP_multitask/detr/models/mr_mg/media/model.gif new file mode 100644 index 0000000000000000000000000000000000000000..9eae2e1691c3364f60b96023e68904a5e9c3b5b5 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/media/model.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3d96bce6dc9112fcfce234a15d48f8d3c1ee22354c3c62281a7bf10db25b2610 +size 1054873 diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/config/pretrain.json b/ACT_DP_multitask/detr/models/mr_mg/policy/config/pretrain.json new file mode 100644 index 0000000000000000000000000000000000000000..990ace12eae46a0053244890983bcadea7743734 --- /dev/null +++ 
b/ACT_DP_multitask/detr/models/mr_mg/policy/config/pretrain.json @@ -0,0 +1,77 @@ +{ + "seed": 123, + "ckpt_root": "SAVE_PATH/policy/checkpoints/", + "log_root": "LOG_PATH/policy/logs/", + "exp_name":"pretrain_policy", + "resume":null, + "policy": { + "resampler_params": { + "depth": 3, + "dim_head": 128, + "heads": 4, + "num_latents": 9, + "num_media_embeds": 1 + }, + "seq_len": 10, + "act_len": 5, + "act_latent_dim": 32, + "act_encoder_dim": 128, + "act_decoder_dim": 128, + "progress_decoder_dim":128, + "patch_feat_dim": 768, + "img_feat_dim": 768, + "lang_feat_dim":512, + "embed_dim": 384, + "use_resampler": true, + "n_layer": 12, + "n_head": 12, + "activation_function": "relu", + "dropout": 0.1, + "forward_n_max": 25 + }, + "trainer":{ + "pl_config":{ + "accelerator": "gpu", + "strategy": "ddp", + "precision": "bf16", + "logger": ["tensorboard"], + "gradient_clip_val": 1.0, + "accumulate_grad_batches": 1, + "use_distributed_sampler": true, + "log_every_n_steps": 50, + "max_epochs": 50 + }, + "lr_decay":true, + "batch_size": 64, + "without_norm_pix_loss": false, + "start_epoch": 0, + "learning_rate": 3.6e-4, + "min_learning_rate_scale": 1e-2, + "weight_decay": 0.0, + "betas_1": 0.9, + "betas_2": 0.999, + "warmup_epochs": 5, + "optimizer":"adam", + "num_workers": 8, + "save_epoch": 1, + "gripper_loss_ratio": 0.01, + "fwd_loss_ratio": 1, + "kl_loss_ratio": 1.0, + "progress_loss_ratio":1.0, + "act_pred": false, + "fwd_pred": true, + "progress_pred":false, + "fwd_pred_hand": false, + "fwd_pred_next_n": 3, + "finetune": false, + "use_pretrain": false, + "pretrained_model_path": "" + }, + + "input": { + "state_dim": 7, + "act_dim": 7, + "use_hand_rgb": false, + "use_state": false + } +} \ No newline at end of file diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/config/train.json b/ACT_DP_multitask/detr/models/mr_mg/policy/config/train.json new file mode 100644 index 0000000000000000000000000000000000000000..faf85374edc11cf550c4453e3385d3b75f8dcf21 --- /dev/null 
+++ b/ACT_DP_multitask/detr/models/mr_mg/policy/config/train.json @@ -0,0 +1,77 @@ +{ + "seed": 123, + "ckpt_root": "SAVE_PATH/policy/checkpoints/", + "log_root": "LOG_PATH/policy/logs/", + "exp_name":"train_policy", + "resume":null, + "policy": { + "resampler_params": { + "depth": 3, + "dim_head": 128, + "heads": 4, + "num_latents": 9, + "num_media_embeds": 1 + }, + "seq_len": 10, + "act_len": 5, + "act_latent_dim": 32, + "act_encoder_dim": 128, + "act_decoder_dim": 128, + "progress_decoder_dim":128, + "patch_feat_dim": 768, + "img_feat_dim": 768, + "lang_feat_dim":512, + "embed_dim": 384, + "use_resampler": true, + "n_layer": 12, + "n_head": 12, + "activation_function": "relu", + "dropout": 0.1, + "forward_n_max": 25 + }, + "trainer":{ + "pl_config":{ + "accelerator": "gpu", + "strategy": "ddp", + "precision": "bf16", + "logger": ["tensorboard"], + "gradient_clip_val": 1.0, + "accumulate_grad_batches": 1, + "use_distributed_sampler": true, + "log_every_n_steps": 50, + "max_epochs": 50 + }, + "lr_decay":true, + "batch_size": 32, + "without_norm_pix_loss": false, + "start_epoch": 0, + "learning_rate": 1e-3, + "min_learning_rate_scale": 1e-2, + "weight_decay": 0.0, + "betas_1": 0.9, + "betas_2": 0.999, + "warmup_epochs": 1, + "optimizer":"adam", + "num_workers": 13, + "save_epoch": 1, + "gripper_loss_ratio": 0.01, + "fwd_loss_ratio": 0.1, + "kl_loss_ratio": 1.0, + "progress_loss_ratio":1.0, + "act_pred": true, + "fwd_pred": true, + "progress_pred":true, + "fwd_pred_hand": true, + "fwd_pred_next_n": 3, + "finetune": true, + "use_pretrain": true, + "pretrained_model_path": "PATH_TO_PRETRAINED_CHECKPOINT" + }, + + "input": { + "state_dim": 7, + "act_dim": 7, + "use_hand_rgb": true, + "use_state": true + } +} \ No newline at end of file diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/install.sh b/ACT_DP_multitask/detr/models/mr_mg/policy/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..f5569897c9998d664029298e24598642d9e1eea8 --- 
/dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/policy/install.sh @@ -0,0 +1,14 @@ +sudo apt-get -y install libosmesa6-dev +sudo apt-get -y install patchelf +pip install --upgrade transformers +pip3 install flamingo_pytorch +pip3 install tensorboard +pip install opencv-python-headless +pip3 install ftfy regex tqdm +pip3 install matplotlib decord +pip install git+https://github.com/openai/CLIP.git +pip install sentencepiece +pip3 install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118 +pip3 install lightning==2.1.0 +pip3 install pytorch-lightning==2.1.0 +pip install "numpy<2.0.0" \ No newline at end of file diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/main.py b/ACT_DP_multitask/detr/models/mr_mg/policy/main.py new file mode 100644 index 0000000000000000000000000000000000000000..35dc18ff76a37eaa2fbdf55c14633f1f2d9bdb00 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/policy/main.py @@ -0,0 +1,181 @@ +# Copyright (2024) Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import argparse +import json +from pathlib import Path +import random +import numpy as np +import torch +from pathlib import Path +import copy +import datetime +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.trainer import Trainer +from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.strategies import DDPStrategy +from lightning import seed_everything +from data.calvin_dataset import CalvinDataset_Policy +from data.ego4d_dataset import Ego4DDataset_Policy +from training.trainer import Policy_Trainer +from torch.utils.data import DataLoader +def set_seed(seed=0): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + seed_everything(seed) + + +def get_date_str(): + return str(datetime.date.today()) + +class SetupCallback(Callback): + def __init__(self, now, logdir, ckptdir): + super().__init__() + self.now = now + self.logdir = logdir + self.ckptdir = ckptdir + + def on_train_start(self, trainer, model): + if trainer.global_rank == 0: + # Create logdirs and save configs + os.makedirs(self.logdir, exist_ok=True) + os.makedirs(self.ckptdir, exist_ok=True) + +def init_setup_callback(config): + return SetupCallback( + now=str(datetime.datetime.now()).replace(' ', '_'), + logdir=config['log_dir'], + ckptdir=config['ckpt_dir'] + ) + +def init_trainer_config(configs): + trainer_config = copy.deepcopy(configs['trainer']["pl_config"]) + trainer_config['devices'] = configs.get('devices', 'auto') + trainer_config['num_nodes'] = configs.get('num_nodes', 1) + + if 'strategy' not in trainer_config or trainer_config['strategy'] == 'ddp': + trainer_config['strategy'] = DDPStrategy(find_unused_parameters=False) + + exp_name = configs['exp_name'] + + # init loggers + log_dir = os.path.join(get_date_str(), exp_name) + log_dir = os.path.join(configs['log_root'], log_dir) + configs['log_dir'] = 
log_dir + Path(configs['log_dir']).mkdir(parents=True, exist_ok=True) + trainer_config['logger'] = [TensorBoardLogger(log_dir, name=exp_name)] + + + # TODO: make callbacks configurable + ckpt_dir = os.path.join(get_date_str(), exp_name) + ckpt_dir = os.path.join(configs['ckpt_root'], ckpt_dir) + configs['ckpt_dir'] = ckpt_dir + Path(configs['ckpt_dir']).mkdir(parents=True, exist_ok=True) + trainer_config['callbacks'] = [ + init_setup_callback(configs), + LearningRateMonitor(logging_interval='step'), + ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1, every_n_epochs=configs["trainer"]["save_epoch"]) # if you have only limited space, just set save_top_k=1 to save the best model + ] + + return trainer_config + +def experiment(variant): + set_seed(variant['seed']) + trainer_config = init_trainer_config(variant) + trainer = Trainer(**trainer_config) + model = Policy_Trainer(variant) + # dataset + if variant["trainer"]["finetune"]: + train_data= CalvinDataset_Policy( + data_dir="PATH_TO_CALVIN/calvin_data", + use_data_augmentation=True, + subfolder= "task_ABC_D", + mode= "train", + forward_n_max=25, + use_play=False, + use_labeled=True) + val_data= CalvinDataset_Policy( + data_dir="PATH_TO_CALVIN/calvin_data", + use_data_augmentation=False, + subfolder= "task_ABC_D", + mode= "validate", + forward_n_max=25, + use_play=False, + use_labeled=True) + else: + train_data= Ego4DDataset_Policy( + data_dir="PATH_TO_Ego4d_Videos", + preprocess=None, + video_sample_rate=2, + seq_len=10, + annotation_file= "PATH_TO_Ego4d_800k_annotations", + use_data_augmentation=True, + goal_interval=7) + val_data= Ego4DDataset_Policy( + data_dir="PATH_TO_Ego4d_Videos", + preprocess=None, + video_sample_rate=2, + seq_len=10, + annotation_file= "PATH_TO_Ego4d_800k_annotations", + use_data_augmentation=False, + goal_interval=7) + train_dataloader= DataLoader(train_data, + batch_size=variant["trainer"]["batch_size"], + num_workers=variant["trainer"]["num_workers"]) + val_dataloader= DataLoader(val_data, 
+ batch_size=variant["trainer"]["batch_size"], + num_workers=variant["trainer"]["num_workers"]) + + _kwargs = { + 'model': model, + 'train_dataloaders':train_dataloader, + 'val_dataloaders':val_dataloader, + 'ckpt_path': variant['resume'] # when you want to restore your training, modify this variant + } + if _kwargs['ckpt_path'] is not None: + print(f"Resuming from {variant['resume']}...") + trainer.fit(**_kwargs) + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Experiment + parser.add_argument('--config', type=str,default="") + parser.add_argument('--devices', default=1, type=int) + parser.add_argument('--num_nodes', default=1, type=int) + parser.add_argument('--seed', default=None, type=int) + parser.add_argument('--log_root', default=None, type=str) + parser.add_argument('--ckpt_root', default=None, type=str) + parser.add_argument('--resume', type=str) + temp_args = vars(parser.parse_args()) + config_path=temp_args.pop("config") + # load config files + configs = json.load(open(config_path)) + for (k, v) in temp_args.items(): + if k not in configs: + configs[k]=v + + return configs + + + +if __name__ == '__main__': + configs=parse_args() + os.system(f"sudo chmod 777 -R {configs['ckpt_root']}") + os.system(f"sudo chmod 777 -R {configs['log_root']}") + experiment(variant=configs) diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/main.sh b/ACT_DP_multitask/detr/models/mr_mg/policy/main.sh new file mode 100644 index 0000000000000000000000000000000000000000..e8117d6221b034d996a3bfd568b9ed6119ac9a05 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/policy/main.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +cd /opt/tiger/GR_MG +GPUS_PER_NODE=8 # number of gpus per machine +MASTER_ADDR={master_address}":"{port} # modify it with your own address and port +NNODES=1 # number of machines +JOB_ID=107 +torchrun \ + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank 0 \ + --rdzv_endpoint $MASTER_ADDR \ + --rdzv_id $JOB_ID \ + 
--rdzv_backend c10d \ + policy/main.py \ + --config ${@:1} \ + --devices $GPUS_PER_NODE \ + --num_nodes $NNODES \ No newline at end of file diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-310.pyc b/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4666a58d37efb7ef0cb52589be3293a73c20a24d Binary files /dev/null and b/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-310.pyc differ diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-37.pyc b/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d10c8d0dee67119ccb59c61d73c6a9bc4911c528 Binary files /dev/null and b/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-37.pyc differ diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-38.pyc b/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15e0c6a84ac4140e4fe9e6b5d950f9e1eafaccc5 Binary files /dev/null and b/ACT_DP_multitask/detr/models/mr_mg/policy/model/__pycache__/vision_transformer.cpython-38.pyc differ diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/model/gpt2.py b/ACT_DP_multitask/detr/models/mr_mg/policy/model/gpt2.py new file mode 100644 index 0000000000000000000000000000000000000000..d124e3d49ccf7be6127962b20c5e7650b12560b6 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/policy/model/gpt2.py @@ -0,0 +1,944 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel, SequenceSummary +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.models.gpt2.configuration_gpt2 import GPT2Config + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + +GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "gpt2", + "gpt2-medium", + "gpt2-large", + "gpt2-xl", + "distilgpt2", + # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 +] + + +def 
load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + + max_positions = config.max_position_embeddings + 
self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = 
self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate 
attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> 
Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = 
self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are 
passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPT2Block"] + _skip_keys_device_placement = "past_key_values" + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GPT2Model): + module.gradient_checkpointing = value + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. 
+ + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. 
If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - gpt2: 12 + - gpt2-medium: 24 + - gpt2-large: 36 + - gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. 
+ + Example: + + ```python + # On a 4 GPU machine with gpt2-large: + model = GPT2LMHeadModel.from_pretrained("gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + # self.wte = nn.Embedding(config.vocab_size, self.embed_dim) # ä¸ä¼šè¢«ç”¨åˆ°ï¼Œæ‰€ä»¥åŽ»æŽ‰ + # self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + warnings.warn( + "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" + " model with `device_map='balanced'` in the call to `from_pretrained`. 
You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," + " ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = 
inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # GPT2Attention mask. + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + # position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds # + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if 
use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/model/model.py b/ACT_DP_multitask/detr/models/mr_mg/policy/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c567481df54273008197973adf0f478d765dbe35 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/policy/model/model.py @@ -0,0 +1,611 @@ +# Copyright (2024) Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +import torch.nn as nn +from torch.autograd import Variable +import numpy as np + +import transformers +from flamingo_pytorch import PerceiverResampler +import clip +from policy.model.gpt2 import GPT2Model +from policy.utils.dist_train import print as dis_print +import os +from policy.model.vision_transformer import Block,get_2d_sincos_pos_embed + +def reparameterize(mu, logvar): + std = logvar.div(2).exp() + eps = Variable(std.data.new(std.size()).normal_()) + return mu + std * eps + + +class GR_MG(nn.Module): + def __init__( + self, + state_dim, + act_dim, + act_len, + act_latent_dim, + act_encoder_dim, + act_decoder_dim, + progress_decoder_dim, + hidden_size, + model_mae, + clip_model, + img_feat_dim, + lang_feat_dim, + patch_feat_dim, + resampler_params, + max_length=None, + training_target=['act_pred'], + without_norm_pix_loss=False, + use_hand_rgb=False, + use_state=False, + use_resampler=True, + **kwargs): + super(GR_MG, self).__init__() + self.state_dim = state_dim + self.act_dim = act_dim + self.max_length = max_length + self.hidden_size = hidden_size + config = transformers.GPT2Config( + vocab_size=1, # doesn't matter -- we don't use the vocab + n_embd=hidden_size, + **kwargs + ) + # note: the difference between this GPT2Model and the default Huggingface version + # is that the positional embeddings are removed (since we'll add those ourselves) + # and wte is removed ( to set find_unused_parameters=False) + self.transformer = GPT2Model(config) + transformer_params = sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) + dis_print(f"Transformer Parameters: {transformer_params / 1000000:.2f}M") + + self.use_resampler = use_resampler + if self.use_resampler: + self.n_patch_latents = resampler_params['num_latents'] + self.perceiver_resampler = PerceiverResampler( + dim=patch_feat_dim, + depth=resampler_params['depth'], + dim_head=resampler_params['dim_head'], + heads=resampler_params['heads'], + 
num_latents=self.n_patch_latents, + num_media_embeds=resampler_params['num_media_embeds']) + + resampler_params = sum(p.numel() for p in self.perceiver_resampler.parameters() if p.requires_grad) + dis_print(f"Perceiver Resampler Parameters: {resampler_params / 1000000:.2f}M") + + self.model_mae = model_mae + + self.act_len = act_len + self.act_latent_dim = act_latent_dim + self.act_encoder_dim = act_encoder_dim + self.act_decoder_dim = act_decoder_dim + self.progress_decoder_dim=progress_decoder_dim + + self.use_hand_rgb = use_hand_rgb + self.use_state = use_state + + self.text_tokenizer=clip.tokenize + self.text_encoder=clip_model + + + self.lang_feat_dim=lang_feat_dim # hardcode + self.img_feat_dim = img_feat_dim + self.patch_feat_dim = patch_feat_dim + self.n_patches = 49 # TODO: hardcode + self.patch_size = 16 # TODO: hardcode + self.image_size = 224 # TODO: hardcode + + self.act_pred = False + self.fwd_pred = False + self.fwd_pred_hand = False + self.progress_pred=False + if 'act_pred' in training_target: + self.act_pred = True + if 'fwd_pred' in training_target: + self.fwd_pred = True + if 'fwd_pred_hand' in training_target: + self.fwd_pred_hand = True + if 'progress_pred' in training_target: + self.progress_pred = True + + self.without_norm_pix_loss = without_norm_pix_loss + if self.use_state: + # state embedding + self.embed_arm_state = torch.nn.Linear(self.state_dim-1, self.hidden_size) + self.embed_gripper_state = torch.nn.Linear(2, self.hidden_size) # one-hot gripper state + self.embed_state = torch.nn.Linear(2*self.hidden_size, self.hidden_size) + + + # Embedding function for languages + self.embed_lang = torch.nn.Linear(self.lang_feat_dim, self.hidden_size) + # relative timestep embedding + self.embed_timestep = nn.Embedding(self.max_length, self.hidden_size) + + # image token embedding + if self.use_hand_rgb: + self.embed_hand_img = torch.nn.Linear(self.img_feat_dim, self.hidden_size) + self.embed_img = torch.nn.Linear(self.img_feat_dim, 
self.hidden_size) + self.embed_goal_image = torch.nn.Linear(self.img_feat_dim, self.hidden_size) + + # patch token embedding + if self.use_hand_rgb: + self.embed_hand_patch = torch.nn.Linear(self.patch_feat_dim, self.hidden_size) + self.embed_patch = torch.nn.Linear(self.patch_feat_dim, self.hidden_size) + self.embed_goal_patch = torch.nn.Linear(self.patch_feat_dim, self.hidden_size) + + # layer norm + self.embed_ln = nn.LayerNorm(self.hidden_size) + if self.act_pred: + # action query [ACT] + self.action_queries = nn.Embedding(1, self.hidden_size) # arm + gripper + + # action encoder (embed action trajectory as style vector) + self.embed_arm_action = torch.nn.Linear(self.act_dim - 1, self.act_encoder_dim) + self.embed_gripper_action = torch.nn.Embedding(2, self.act_encoder_dim) + self.embed_action = nn.Linear(2 * self.act_encoder_dim, self.act_encoder_dim) + self.action_encoder_cls_token = torch.nn.Embedding(1, self.act_encoder_dim) + action_encoder_depth = 4 + self.encode_action = nn.ModuleList([ + Block(self.act_encoder_dim, 8, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm) + for i in range(action_encoder_depth)]) + self.action_encoder_positional_embeddings = nn.Embedding(self.act_len + 1, self.act_encoder_dim) + self.pred_style_vector = nn.Linear(self.act_encoder_dim, 2 * self.act_latent_dim) + self.embed_style_vector = nn.Linear(self.act_latent_dim, self.act_decoder_dim) + + # action decoder + self.proj_action_output_embed = nn.Linear(self.hidden_size, self.act_decoder_dim) + action_decoder_depth = 4 + self.decode_action = nn.ModuleList([ + Block(self.act_decoder_dim, 8, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm) + for i in range(action_decoder_depth)]) + self.action_mask_token_embedding = nn.Embedding(1, self.act_decoder_dim) + self.action_decoder_positional_embeddings = nn.Embedding(self.act_len, self.act_decoder_dim) + self.pred_arm_act = nn.Linear(self.act_decoder_dim, self.act_dim - 1) # arm action + self.pred_gripper_act = 
nn.Linear(self.act_decoder_dim, 1) # gripper action (binary) + + # predict future image + if self.fwd_pred: + # add observation query for fwd prediction + self.obs_queries = nn.Embedding(self.n_patch_latents+1, self.hidden_size) # cls+resampler + if self.use_hand_rgb: + self.obs_hand_queries = nn.Embedding(self.n_patch_latents+1, self.hidden_size) # cls+resampler + self.decoder_embed = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, self.hidden_size)) + # fixed sin-cos embedding + self.decoder_pos_embed = nn.Parameter(torch.zeros(1, (self.image_size//self.patch_size)**2, + self.hidden_size), requires_grad=False) # (1, n_patch, h) + + decoder_depth = 2 # hardcode + self.decoder_blocks = nn.ModuleList([ + Block(self.hidden_size, 16, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm) + for i in range(decoder_depth)]) + + self.decoder_norm = nn.LayerNorm(self.hidden_size) + self.decoder_pred = nn.Linear(self.hidden_size, self.patch_size**2 * 3, bias=True) # decoder to patch + + decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], (self.image_size//self.patch_size), cls_token=False) + self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) + + fwd_params = sum(p.numel() for p in self.decoder_blocks.parameters() if p.requires_grad) + dis_print(f"Fwd Decoder Parameters: {fwd_params / 1000000:.2f}M") + if self.progress_pred: + self.progress_queries=nn.Embedding(1, self.hidden_size) # [PROG] + # progress decoder + self.proj_progress_output_embed = nn.Linear(self.hidden_size, self.progress_decoder_dim) + progress_decoder_depth = 2 + self.decode_progress = nn.ModuleList([ + Block(self.progress_decoder_dim, 8, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm) + for i in range(progress_decoder_depth)]) + self.progress_mask_token_embedding = nn.Embedding(1, self.progress_decoder_dim) + self.pred_progress = nn.Linear(self.progress_decoder_dim, 
1) # pred progress + self.sigmoid_progress=nn.Sigmoid() + + def encode_texts(self, texts): + inputs = self.text_tokenizer(texts) + device = next(self.text_encoder.parameters()).device + with torch.no_grad(): + encoder_hidden_states = self.text_encoder.encode_text(inputs.to(device)) + return encoder_hidden_states + + def forward(self, input_dict, is_training=True): + goal_rgb = input_dict['goal_rgb'] # (b, c, h, w) + rgb = input_dict['rgb'] # (b, l, c, h, w) + hand_rgb = input_dict['hand_rgb'] # (b, l, c, h, w) + attention_mask = input_dict['attention_mask'] # (b, l) + text=input_dict["text"][0] + progress_targets=input_dict["progress"] #(b,l,) + + obs_preds = None + obs_hand_preds = None + obs_target = None + obs_hand_target = None + arm_action_preds = None + gripper_action_preds = None + action_mu_preds = None + action_logvar_preds = None + progress_preds=None + + batch_size, seq_length, c, h, w = rgb.shape + + + if self.use_state: + arm_state = input_dict['arm_state'] # (b, l, state_dim - 1) + gripper_state = input_dict['gripper_state'] # (b, l, 2) + arm_state_embeddings = self.embed_arm_state(arm_state.view(batch_size, seq_length, self.state_dim-1)) # (b, l, h) + gripper_state_embeddings = self.embed_gripper_state(gripper_state) # (b, l, h) + state_embeddings = torch.cat((arm_state_embeddings, gripper_state_embeddings), dim=2) # (b, l, 2h) + state_embeddings = self.embed_state(state_embeddings) # (b, l, h) + + # goal rgb mae feature + goal_obs_embeddings, goal_patch_embeddings = self.model_mae(goal_rgb) # (b, img_feat_dim), (b, 196, patch_feat_dim) + + # rgb mae feature + obs_embeddings, patch_embeddings = self.model_mae( + rgb.view(batch_size*seq_length, c, h, w)) # (b * l, img_feat_dim), (b * l, 196, patch_feat_dim) + obs_embeddings = obs_embeddings.view(batch_size, seq_length, -1) # (b, l, img_feat_dim) + + # hand rgb mae feature + if self.use_hand_rgb: + hand_obs_embeddings, hand_patch_embeddings = self.model_mae( + hand_rgb.view(batch_size*seq_length, c, h, 
w)) # (b * l, img_feat_dim), (b * l, 196, patch_feat_dim) + hand_obs_embeddings = hand_obs_embeddings.view(batch_size, seq_length, -1) # (b, l, img_feat_dim) + + # compute obs target + if self.fwd_pred: + p = self.patch_size + h_p = h // p + w_p = w // p + rgb = rgb.reshape(shape=(batch_size, seq_length, 3, h_p, p, w_p, p)) # b,len,3,14,p,14,p + obs_target = rgb.permute(0, 1, 3, 5, 4, 6, 2) # b,len,14,14,p,p,3 + obs_target = obs_target.reshape(shape=(batch_size, seq_length, h_p * w_p, (p**2) * 3)) # b,len,14x14,p*p*3 + if not self.without_norm_pix_loss: + # norm the target + obs_target = (obs_target - obs_target.mean(dim=-1, keepdim=True) + ) / (obs_target.var(dim=-1, unbiased=True, keepdim=True).sqrt() + 1e-6) + + if self.fwd_pred_hand: + hand_rgb = hand_rgb.reshape(shape=(batch_size, seq_length, 3, h_p, p, w_p, p)) # b,len,3,14,p,14,p + obs_hand_target = hand_rgb.permute(0, 1, 3, 5, 4, 6, 2) # b,len,14,14,p,p,3 + obs_hand_target = obs_hand_target.reshape(shape=(batch_size, seq_length, h_p * w_p, (p**2)*3)) + if not self.without_norm_pix_loss: + # norm the target + obs_hand_target = (obs_hand_target - obs_hand_target.mean(dim=-1, keepdim=True) + ) / (obs_hand_target.var(dim=-1, unbiased=True, keepdim=True).sqrt() + 1e-6) + + if self.use_resampler: + goal_patch_embeddings = goal_patch_embeddings.unsqueeze(1) # (b, 1, 196, patch_feat_dim) + goal_patch_embeddings = self.perceiver_resampler(goal_patch_embeddings) # (b, 1, 9, patch_feat_dim) + goal_patch_embeddings = goal_patch_embeddings.squeeze(1) # (b, 9, patch_feat_dim) + + patch_embeddings = patch_embeddings.unsqueeze(1) # (b * l, 1, 196, patch_feat_dim) + patch_embeddings = self.perceiver_resampler(patch_embeddings) # (b * l, 1, 9, patch_feat_dim) + patch_embeddings = patch_embeddings.squeeze(1) # (b * l, 9, patch_feat_dim) + patch_embeddings = patch_embeddings.view(batch_size, seq_length, self.n_patch_latents, self.patch_feat_dim) # (b, l, 9, patch_feat_dim) + + if self.use_hand_rgb: + hand_patch_embeddings = 
hand_patch_embeddings.unsqueeze(1) # (b * l, 1, 196, patch_feat_dim) + hand_patch_embeddings = self.perceiver_resampler(hand_patch_embeddings) # (b * l, 1, 9, patch_feat_dim) + hand_patch_embeddings = hand_patch_embeddings.squeeze(1) # (b * l, 9, patch_feat_dim) + hand_patch_embeddings = hand_patch_embeddings.view(batch_size, seq_length, self.n_patch_latents, self.patch_feat_dim) # (b, l, 9, patch_feat_dim) + else: + raise NotImplementedError + + + # Embed language + lang_embeddings = self.encode_texts(text) + lang_embeddings = lang_embeddings / (lang_embeddings.norm(dim=1, keepdim=True) + 1e-6) # normalization + lang_embeddings = self.embed_lang(lang_embeddings.float()) # (b, h) + + # embed images and patches + goal_obs_embeddings = self.embed_goal_image(goal_obs_embeddings) # (b, h) + goal_patch_embeddings = self.embed_goal_patch(goal_patch_embeddings) # (b, 9, h) + obs_embeddings = self.embed_img(obs_embeddings) # (b, l, h) + patch_embeddings = self.embed_patch(patch_embeddings) # (b, l, 9, h) + if self.use_hand_rgb: + hand_obs_embeddings = self.embed_hand_img(hand_obs_embeddings) # (b, l, h) + hand_patch_embeddings = self.embed_hand_patch(hand_patch_embeddings) # (b, l, 9, h) + + # add timestep embeddings + time_embeddings = self.embed_timestep.weight # (l, h) + + lang_embeddings = lang_embeddings.view(batch_size, 1, -1).repeat(1,seq_length,1) + time_embeddings# 注æ„debug + patch_embeddings = patch_embeddings + time_embeddings.view(seq_length, 1, self.hidden_size) + obs_embeddings = obs_embeddings + time_embeddings + + + if self.use_hand_rgb: + hand_obs_embeddings = hand_obs_embeddings + time_embeddings + hand_patch_embeddings = hand_patch_embeddings + time_embeddings.view(seq_length, 1, self.hidden_size) + if self.use_state: + state_embeddings = state_embeddings + time_embeddings + + # Format sequence: lang, state, patch, obs, hand_patch, hand_obs, [ACT], [OBS], [OBS_HAND],[PROG] + lang_embeddings = lang_embeddings.view(batch_size, seq_length, 1, 
self.hidden_size) + obs_embeddings = obs_embeddings.view(batch_size, seq_length, 1, self.hidden_size) + if self.use_state: + state_embeddings = state_embeddings.view(batch_size, seq_length, 1, self.hidden_size) + stacked_inputs = torch.cat((lang_embeddings,state_embeddings, patch_embeddings, obs_embeddings), dim=2) # (b, l, n_tokens, h) + else: + stacked_inputs = torch.cat((lang_embeddings,patch_embeddings, obs_embeddings), dim=2) # (b, l, n_tokens, h) + if self.use_hand_rgb: + hand_obs_embeddings = hand_obs_embeddings.view(batch_size, seq_length, 1, self.hidden_size) + stacked_inputs = torch.cat((stacked_inputs, hand_patch_embeddings, hand_obs_embeddings), dim=2) # (b, l, n_tokens, h) + if self.act_pred: + action_queries = self.action_queries.weight # (1, h) + action_queries = action_queries.view(1, 1, 1, self.hidden_size).repeat(batch_size, seq_length, 1, 1) # (b, l, 1, h) + stacked_inputs = torch.cat((stacked_inputs, action_queries), dim=2) # (b, l, n_tokens, h) + if self.fwd_pred: + obs_queries = self.obs_queries.weight # (10, h) + obs_queries = obs_queries.view(1, 1, self.n_patch_latents + 1, self.hidden_size).repeat(batch_size, seq_length, 1, 1) + stacked_inputs = torch.cat((stacked_inputs, obs_queries), dim=2) + if self.fwd_pred_hand: + obs_hand_queries = self.obs_hand_queries.weight # (10, h) + obs_hand_queries = obs_hand_queries.view(1, 1, self.n_patch_latents + 1, self.hidden_size).repeat(batch_size, seq_length, 1, 1) + stacked_inputs = torch.cat((stacked_inputs, obs_hand_queries), dim=2) # (b, l, n_tokens, h) + if self.progress_pred: + progress_queries = self.progress_queries.weight # (1, h) + progress_queries = progress_queries.view(1, 1, 1, self.hidden_size).repeat(batch_size, seq_length, 1, 1) # (b, l, 1, h) + stacked_inputs = torch.cat((stacked_inputs, progress_queries), dim=2) # (b, l, n_tokens, h) + # number of tokens for different modalities + n_lang_tokens = 1 + n_state_tokens = 1 + n_patch_tokens = self.n_patch_latents + n_obs_tokens = 1 + 
n_hand_patch_tokens = self.n_patch_latents + n_hand_obs_tokens = 1 + n_act_pred_tokens = 1 + n_fwd_pred_tokens = n_patch_tokens + n_obs_tokens + n_fwd_pred_hand_tokens = n_patch_tokens + n_obs_tokens + n_progress_pred_tokens=1 + # compute number of tokens (does not include the conditioned goal image tokens) + n_tokens = n_lang_tokens + if self.use_state: + n_tokens += n_state_tokens + n_tokens += n_patch_tokens + n_tokens += n_obs_tokens + if self.use_hand_rgb: + n_tokens += n_hand_obs_tokens + n_tokens += n_hand_patch_tokens + if self.act_pred: + act_pred_token_i = n_tokens + n_tokens += n_act_pred_tokens + if self.fwd_pred: + obs_pred_token_i = n_tokens + n_tokens += n_fwd_pred_tokens + if self.fwd_pred_hand: + obs_pred_hand_token_i = n_tokens + n_tokens += n_fwd_pred_hand_tokens + if self.progress_pred: + progress_pred_token_i = n_tokens + n_tokens += n_progress_pred_tokens + # number of condtioned tokens (goal image) + n_condtioned_tokens = 1 + self.n_patch_latents + + # add goal image conditions at the front + stacked_inputs = stacked_inputs.reshape(batch_size, n_tokens * seq_length, self.hidden_size) + goal_obs_embeddings = goal_obs_embeddings.view(batch_size, 1, self.hidden_size) + stacked_inputs = torch.cat((goal_patch_embeddings, goal_obs_embeddings, stacked_inputs), dim=1) # (b, l * n_tokens + n_patch_latents + 1, h) + assert stacked_inputs.shape == (batch_size, seq_length * n_tokens + n_condtioned_tokens, self.hidden_size) + + # layer norm + stacked_inputs = self.embed_ln(stacked_inputs) + + # generate attention mask + attn_mask = attention_mask.view(batch_size, 1, seq_length) + lang_attn_mask=attn_mask.repeat(1, n_lang_tokens, 1) + state_attn_mask = attn_mask.repeat(1, n_state_tokens, 1) + patch_attn_mask = attn_mask.repeat(1, n_patch_tokens, 1) + obs_attn_mask = attn_mask.repeat(1, n_obs_tokens, 1) + hand_patch_attn_mask = attn_mask.repeat(1, n_hand_patch_tokens, 1) + hand_obs_attn_mask = attn_mask.repeat(1, n_hand_obs_tokens, 1) + + if self.use_state: 
+ stacked_attn_mask = torch.cat((lang_attn_mask,state_attn_mask, patch_attn_mask, obs_attn_mask), dim=1) + else: + stacked_attn_mask = torch.cat((lang_attn_mask,patch_attn_mask, obs_attn_mask), dim=1) + if self.use_hand_rgb: + stacked_attn_mask = torch.cat((stacked_attn_mask, hand_patch_attn_mask, hand_obs_attn_mask), dim=1) + if self.act_pred: + act_pred_attn_mask = torch.zeros((batch_size, n_act_pred_tokens, seq_length), dtype=torch.long).cuda() + stacked_attn_mask = torch.cat((stacked_attn_mask, act_pred_attn_mask), dim=1) + if self.fwd_pred: + fwd_pred_attn_mask = torch.zeros((batch_size, n_fwd_pred_tokens, seq_length), dtype=torch.long).cuda() + stacked_attn_mask = torch.cat((stacked_attn_mask, fwd_pred_attn_mask), dim=1) + if self.fwd_pred_hand: + fwd_pred_hand_attn_mask = torch.zeros((batch_size, n_fwd_pred_hand_tokens, seq_length), dtype=torch.long).cuda() + stacked_attn_mask = torch.cat((stacked_attn_mask, fwd_pred_hand_attn_mask), dim=1) + if self.progress_pred: + progress_pred_attn_mask = torch.zeros((batch_size, n_progress_pred_tokens, seq_length), dtype=torch.long).cuda() + stacked_attn_mask = torch.cat((stacked_attn_mask, progress_pred_attn_mask), dim=1) + + stacked_attn_mask = stacked_attn_mask.permute(0, 2, 1) # (b, l, n_tokens) + stacked_attn_mask = stacked_attn_mask.reshape(batch_size, n_tokens * seq_length) # (b, l * n_tokens) + goal_obs_attn_mask = torch.ones((batch_size, 1), dtype=torch.long).cuda() + goal_patch_attn_mask = torch.ones((batch_size, self.n_patch_latents), dtype=torch.long).cuda() + stacked_attn_mask = torch.cat((goal_patch_attn_mask, goal_obs_attn_mask, stacked_attn_mask), dim=1) # (b, l * n_tokens + n_patch_latens + 1) + assert stacked_attn_mask.shape == (batch_size, seq_length * n_tokens + n_condtioned_tokens) + + # we feed in the input embeddings (not word indices as in NLP) to the model + transformer_outputs = self.transformer( + inputs_embeds=stacked_inputs, + attention_mask=stacked_attn_mask, + ) + x = 
transformer_outputs['last_hidden_state'] + x = x[:, n_condtioned_tokens:] + x = x.reshape(batch_size, seq_length, n_tokens, self.hidden_size) # (b, l, n_tokens, h) + + # action prediction: predict next action given obs + # format sequence + if self.act_pred: + action_output_embedding = x[:, :, act_pred_token_i] # (b, l, h) + + # encode action + arm_action = input_dict['arm_action'] # b,len,act_len,act_dim-1 + gripper_action = input_dict['gripper_action'].long() # b,len,act_len + arm_action_embeddings = self.embed_arm_action(arm_action) # b,len,act_len,act_encoder_dim + gripper_action_embeddings = self.embed_gripper_action(gripper_action) # b,len,act_len,act_encoder_dim + action_embeddings = torch.cat((arm_action_embeddings, gripper_action_embeddings), dim=-1) # b,len,act_len,2*act_encoder_dim + action_embeddings = self.embed_action(action_embeddings) # b,len,act_len,act_encoder_dim + cls_token_embeddings = self.action_encoder_cls_token.weight # 1,act_encoder_dim + cls_token_embeddings = cls_token_embeddings.unsqueeze(0).unsqueeze(0).repeat(batch_size, seq_length, 1, 1) # b,len,1,act_encoder_dim + z = torch.cat((cls_token_embeddings, action_embeddings), dim=2) # b,len,1+act_len,act_encoder_dim + action_encoder_positional_embeddings = self.action_encoder_positional_embeddings.weight # 1+act_len,act_encoder_dim + z = z + action_encoder_positional_embeddings # b,len,1+act_len,act_encoder_dim + z = z.reshape(batch_size * seq_length, self.act_len + 1, self.act_encoder_dim) # b*len,1+1+act_len,act_encoder_dim + for blk in self.encode_action: + z = blk(z) + action_latent_embedding = z[:, 0] # b*len,act_encoder_dim + action_latent_embedding = action_latent_embedding.reshape(batch_size, seq_length, self.act_encoder_dim) # b,len,act_encoder_dim + action_latent_preds = self.pred_style_vector(action_latent_embedding) # b,len,2*act_latent_dim + action_mu_preds = action_latent_preds[:, :, :self.act_latent_dim] # b,len,act_latent_dim + action_logvar_preds = action_latent_preds[:, 
:, self.act_latent_dim:] # b,len,act_latent_dim + # sample style vector + action_mu_preds = action_mu_preds.view(-1, self.act_latent_dim) + action_logvar_preds = action_logvar_preds.view(-1, self.act_latent_dim) + action_style_vector = reparameterize(action_mu_preds, action_logvar_preds) # b*len,act_latent_dim + action_style_vector = action_style_vector.view(batch_size, seq_length, self.act_latent_dim) # b,len,act_latent_dim + if not is_training: # we set the mean=0 and var=1 during inference + action_style_vector = torch.zeros([batch_size, seq_length, self.act_latent_dim], dtype=torch.float32).to(rgb.device) + action_style_vector = action_style_vector.type_as(arm_action) + action_style_embeddings = self.embed_style_vector(action_style_vector) # b,len,act_decoder_dimv + + + action_output_embedding = self.proj_action_output_embed(action_output_embedding) # (b, l, act_decoder_dim) + action_mask_token = self.action_mask_token_embedding.weight # (1, act_decoder_dim) + action_mask_token = action_mask_token.unsqueeze(0).unsqueeze(0).repeat(batch_size, seq_length, self.act_len, 1) # (b, l, act_len, act_decoder_dim) + action_output_embedding = action_output_embedding.view(batch_size, seq_length, 1, self.act_decoder_dim) + action_decoder_positional_embeddings = self.action_decoder_positional_embeddings.weight + + action_style_embeddings = action_style_embeddings.view(batch_size, seq_length, 1, self.act_decoder_dim) + y = torch.cat((action_style_embeddings, action_output_embedding, action_mask_token), dim=2) # (b, l, 2 + act_len, act_decoder_dim) + y[:, :, 2:] = y[:, :, 2:] + action_decoder_positional_embeddings + y = y.reshape(batch_size * seq_length, 2 + self.act_len, self.act_decoder_dim) + + + # forward transformer + for blk in self.decode_action: + y = blk(y) + + + action_decoder_output_embeddings = y[:, 2:] # (b * l, act_len, act_decoder_dim) + + + + action_decoder_output_embeddings = action_decoder_output_embeddings.reshape(batch_size, seq_length, self.act_len, 
self.act_decoder_dim) + arm_action_preds = self.pred_arm_act(action_decoder_output_embeddings) # (b, l, act_len, act_dim - 1) + gripper_action_preds = self.pred_gripper_act(action_decoder_output_embeddings) # (b, l, act_len, 1) + + # forward prediction: predict next obs + if self.fwd_pred: + mask_token = self.mask_token # (1, 1, 1, h) + mask_tokens = mask_token.repeat(batch_size, seq_length, (self.image_size//self.patch_size)**2, 1) # (b, l, n_patches, h) + mask_tokens = mask_tokens + self.decoder_pos_embed.unsqueeze(0).repeat(batch_size, seq_length, 1, 1) # (b, l, n_patch, h) + + obs_pred = self.decoder_embed(x[:, :, obs_pred_token_i:(obs_pred_token_i + n_fwd_pred_tokens)]) # (b, l, n_patch_latents + 1, h) + obs_pred_ = torch.cat([obs_pred, mask_tokens], dim=2) # (b, l, n_patch + n_patch_latens + 1, h) + obs_pred_ = obs_pred_.reshape(-1, obs_pred_.shape[-2], obs_pred_.shape[-1]) # (b * l, n_patch + n_patch_latens + 1, h) + for blk in self.decoder_blocks: + obs_pred_ = blk(obs_pred_) + obs_pred_ = self.decoder_norm(obs_pred_) + obs_preds = self.decoder_pred(obs_pred_) # (b * l, n_patch + n_patch_latens + 1, h) + obs_preds = obs_preds.reshape(batch_size, seq_length, -1, obs_preds.shape[-1]) # (b, l, n_patch + n_patch_latens + 1, h) + obs_preds = obs_preds[:, :, n_fwd_pred_tokens:] # (b, len, n_patch, h) + + if self.fwd_pred_hand: + obs_pred_hand = self.decoder_embed(x[:, :, obs_pred_hand_token_i:(obs_pred_hand_token_i + n_fwd_pred_hand_tokens)]) # (b, l, n_patch_latents + 1, h) + obs_pred_hand_ = torch.cat([obs_pred_hand, mask_tokens], dim=2) # (b, l, n_patch + n_patch_latens + 1, h) + obs_pred_hand_ = obs_pred_hand_.reshape(-1, obs_pred_hand_.shape[-2], obs_pred_hand_.shape[-1]) # (b * l, n_patch + n_patch_latens + 1, h) + for blk in self.decoder_blocks: + obs_pred_hand_ = blk(obs_pred_hand_) + obs_pred_hand_ = self.decoder_norm(obs_pred_hand_) + obs_hand_preds = self.decoder_pred(obs_pred_hand_) # (b * l, n_patch + n_patch_latens + 1, h) + obs_hand_preds = 
obs_hand_preds.reshape(batch_size, seq_length, -1, obs_hand_preds.shape[-1]) # (b, l, n_patch + n_patch_latens + 1, h) + obs_hand_preds = obs_hand_preds[:, :, n_fwd_pred_hand_tokens:] # (b, l, n_patch, h) + + # progress prediction + # format sequence + if self.progress_pred: + progress_output_embedding = x[:, :, progress_pred_token_i] # (b, l, h) + progress_output_embedding = self.proj_progress_output_embed(progress_output_embedding) # (b, l, progress_decoder_dim) + progress_mask_token = self.progress_mask_token_embedding.weight # (1, progress_decoder_dim) + progress_mask_token = progress_mask_token.unsqueeze(0).unsqueeze(0).repeat(batch_size, seq_length, 1, 1) # (b, l, 1, progress_decoder_dim) + progress_output_embedding = progress_output_embedding.view(batch_size, seq_length, 1, self.progress_decoder_dim) + y = torch.cat((progress_output_embedding, progress_mask_token), dim=2) # (b, l, 1 + 1, progress_decoder_dim) + y = y.reshape(batch_size * seq_length, 1 + 1, self.act_decoder_dim) + + # forward transformer + for blk in self.decode_progress: + y = blk(y) + + # get output + progress_decoder_output_embeddings = y[:, 1:] # (b * l, 1, progress_decoder_dim) + progress_decoder_output_embeddings = progress_decoder_output_embeddings.reshape(batch_size, seq_length, 1, self.progress_decoder_dim) + progress_preds = self.pred_progress(progress_decoder_output_embeddings).squeeze() # (b, l) + progress_preds=self.sigmoid_progress(progress_preds) + + + + + prediction = { + 'obs_preds': obs_preds, + 'obs_target': obs_target, + 'obs_hand_preds': obs_hand_preds, + 'obs_hand_target': obs_hand_target, + 'arm_action_preds': arm_action_preds, + 'gripper_action_preds': gripper_action_preds, + 'action_mu_preds': action_mu_preds, + 'action_logvar_preds': action_logvar_preds, + 'progress_preds':progress_preds, + 'progress_targets':progress_targets, + } + return prediction + + + def evaluate(self, input_dict,original_gripper=False,return_progress=False): + + attention_mask = 
input_dict['attention_mask'] + prediction = self.forward(input_dict, is_training=False) + + arm_action_preds = prediction['arm_action_preds'] # (1, len, act_len, act_dim-1) + gripper_action_preds = prediction['gripper_action_preds'] # (1, len, act_len, 1) + + + + arm_action_preds = arm_action_preds.squeeze(0) # (len, act_len, act_dim-1) + gripper_action_preds = gripper_action_preds.squeeze() # (len, act_len) + arm_action_preds = arm_action_preds[attention_mask.flatten() > 0] + gripper_action_preds = gripper_action_preds[attention_mask.flatten() > 0] + + + # Take the last action + arm_action_pred = arm_action_preds[-1].cpu() # (act_len, act_dim-1) + arm_action_pred = arm_action_pred[0] # (act_dim-1, ) + gripper_action_pred = gripper_action_preds[-1:].cpu() # (1, act_len) + gripper_action_pred = gripper_action_pred[:, 0] # (1, 1) + + if original_gripper: + gripper_action_pred = torch.nn.Sigmoid()(gripper_action_pred) + else: + gripper_action_pred = torch.nn.Sigmoid()(gripper_action_pred) + gripper_action_pred = gripper_action_pred > 0.5 + gripper_action_pred = gripper_action_pred.int().float() + gripper_action_pred = gripper_action_pred * 2.0 - 1.0 + action_pred = torch.cat((arm_action_pred, gripper_action_pred), dim=0) # (act_dim,) + if return_progress: + progress=prediction["progress_preds"] + progress=progress[-1] + return action_pred,progress + else: + return action_pred \ No newline at end of file diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/model/vision_transformer.py b/ACT_DP_multitask/detr/models/mr_mg/policy/model/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7fdd05befd1ce515fc8d7b468be798189c2e46f7 --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/policy/model/vision_transformer.py @@ -0,0 +1,381 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Mostly copy-paste from timm library. +https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +""" +import math +from functools import partial + +import torch +import torch.nn as nn + +import numpy as np + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # 
here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + +# from utils import trunc_normal_ + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. 
or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return 
x, attn + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, return_attention=False): + y, attn = self.attn(self.norm1(x)) + if return_attention: + return attn + x = x + self.drop_path(y) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + num_patches = (img_size // patch_size) * (img_size // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): + super().__init__() + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed( + img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + 
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def interpolate_pos_encoding(self, x, w, h): + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + class_pos_embed = self.pos_embed[:, 0] + patch_pos_embed = self.pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_embed.patch_size + h0 = h // self.patch_embed.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode='bicubic', + ) + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def 
prepare_tokens(self, x): + B, nc, w, h = x.shape + x = self.patch_embed(x) # patch linear embedding + + # add the [CLS] token to the embed patch tokens + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # add positional encoding to each token + pos = self.interpolate_pos_encoding(x, w, h) + x = x + pos + + return self.pos_drop(x), pos + + def forward(self, x): + x, pos = self.prepare_tokens(x) + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + return x[:, 0], x[:, 1:], pos + + def get_last_selfattention(self, x): + x = self.prepare_tokens(x) + for i, blk in enumerate(self.blocks): + if i < len(self.blocks) - 1: + x = blk(x) + else: + # return attention of the last block + return blk(x, return_attention=True) + + def get_intermediate_layers(self, x, n=1): + x = self.prepare_tokens(x) + # we return the output tokens from the `n` last blocks + output = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if len(self.blocks) - i <= n: + output.append(self.norm(x)) + return output + + +def vit_tiny(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_small(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_base(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +class DINOHead(nn.Module): + def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): + super().__init__() + nlayers = max(nlayers, 1) + if nlayers == 1: + self.mlp = 
nn.Linear(in_dim, bottleneck_dim) + else: + layers = [nn.Linear(in_dim, hidden_dim)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim)) + self.mlp = nn.Sequential(*layers) + self.apply(self._init_weights) + self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + if norm_last_layer: + self.last_layer.weight_g.requires_grad = False + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + x = nn.functional.normalize(x, dim=-1, p=2) + x = self.last_layer(x) + return x \ No newline at end of file diff --git a/ACT_DP_multitask/detr/models/mr_mg/policy/training/trainer.py b/ACT_DP_multitask/detr/models/mr_mg/policy/training/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..369cdbec815e8d121e1cfbfb88a10087e27ddaea --- /dev/null +++ b/ACT_DP_multitask/detr/models/mr_mg/policy/training/trainer.py @@ -0,0 +1,506 @@ +# Copyright (2024) Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import torch +import torch.nn.functional as F +from functools import partial +import math +import lightning.pytorch as pl +from utils.dist_train import get_rank +import model.vision_transformer as vits +import json +from model.model import GR_MG +import clip +def adjust_learning_rate(iter, configs): + """Decay the learning rate with half-cycle cosine after warmup""" + warmup_iters = configs['warmup_iters'] + total_iters = configs['iters'] + min_lr_scale = configs['min_lr_scale'] + + if iter < configs['warmup_iters']: + lr_scaler = 1.0 * iter / warmup_iters + else: + lr_scaler = min_lr_scale + (1.0 - min_lr_scale) * 0.5 * \ + (1.0 + math.cos(math.pi * (iter - warmup_iters) / (total_iters - warmup_iters))) + return lr_scaler + +def compute_kl_divergence( mu, logvar): + if torch.isnan(mu).any() or torch.isnan(logvar).any(): + raise ValueError("Input contains NaN values") + if torch.isinf(mu).any() or torch.isinf(logvar).any(): + raise ValueError("Input contains infinite values") + type=mu.dtype + latent_dim = mu.shape[-1] + # å°† mu å’Œ logvar 转æ¢ä¸º float32 + mu = mu.float().view(-1, latent_dim) + logvar = logvar.float().view(-1, latent_dim) + + klds = -0.5 * (1 + logvar- mu.pow(2) - logvar.exp()) + klds = klds.sum(1).mean(0, True).squeeze() + + # transform back to bf16 + return klds.to(type) + + +class Policy_Trainer(pl.LightningModule): + def __init__(self, configs): + super().__init__() + self._main_rank_print('--------------- model configs ---------------') + self._main_rank_print(configs) + self.configs = configs + self._initialize() + self.save_hyperparameters() + self.val_set_names = "calvin" + + @staticmethod + def _main_rank_print(*args, **kwargs): + if get_rank() == 0: + print(*args, **kwargs) + + @property + def num_gpus(self): + return self.trainer.num_devices * self.trainer.num_nodes + + def _initialize(self): + training_target = [] + if self.configs["trainer"]['act_pred']: + training_target.append('act_pred') + if 
self.configs["trainer"]['fwd_pred']: # predict future static image + training_target.append('fwd_pred') + if self.configs["trainer"]['fwd_pred_hand']: # predict future hand image + training_target.append('fwd_pred_hand') + if self.configs["trainer"]['progress_pred']: # predict progress information + training_target.append('progress_pred') + + # mae model + model_mae = vits.__dict__['vit_base'](patch_size=16, num_classes=0) + model_mae.to(self.device) + mae_ckpt = '/PATH_TO/resources/MAE/mae_pretrain_vit_base.pth' + checkpoint = torch.load(mae_ckpt, map_location='cpu') + model_mae.load_state_dict(checkpoint['model'], strict=True) + # freeze mae + for name, p in model_mae.named_parameters(): + p.requires_grad = False + #clip model + clip_name = "ViT-B/32" + clip_model, clip_preprocess = clip.load(clip_name) + # freeze clip + for _, param in clip_model.named_parameters(): + param.requires_grad = False + + # resampler parameters + resampler_params = dict() + resampler_params['depth'] = self.configs["policy"]['resampler_params']["depth"] + resampler_params['dim_head'] = self.configs["policy"]['resampler_params']['dim_head'] + resampler_params['heads'] = self.configs["policy"]['resampler_params']['heads'] + resampler_params['num_latents'] = self.configs["policy"]['resampler_params']['num_latents'] + resampler_params['num_media_embeds'] = self.configs["policy"]['resampler_params']['num_media_embeds'] + + # main model + self.model = GR_MG( + state_dim=self.configs["input"]['state_dim'], + act_dim=self.configs["input"]['act_dim'], + act_len=self.configs["policy"]['act_len'], + act_latent_dim=self.configs["policy"]['act_latent_dim'], + act_encoder_dim=self.configs["policy"]['act_encoder_dim'], + act_decoder_dim=self.configs["policy"]['act_decoder_dim'], + progress_decoder_dim=self.configs["policy"]["progress_decoder_dim"], + hidden_size=self.configs["policy"]['embed_dim'], + model_mae=model_mae, + clip_model=clip_model, + img_feat_dim=self.configs["policy"]["img_feat_dim"], 
+ lang_feat_dim = self.configs["policy"]["lang_feat_dim"], + patch_feat_dim=self.configs["policy"]["patch_feat_dim"], + resampler_params=resampler_params, + max_length=self.configs["policy"]['seq_len'], + training_target=training_target, + without_norm_pix_loss=self.configs["trainer"]['without_norm_pix_loss'], + use_hand_rgb=self.configs["input"]['use_hand_rgb'], + use_state=self.configs["input"]['use_state'], + use_resampler=self.configs["policy"]['use_resampler'], + n_layer=self.configs["policy"]['n_layer'], + n_head=self.configs["policy"]['n_head'], + n_inner=4*self.configs["policy"]['embed_dim'], + activation_function=self.configs["policy"]['activation_function'], + n_positions=1024, + resid_pdrop=self.configs["policy"]['dropout'], + attn_pdrop=self.configs["policy"]['dropout']) + + + # if finetune, we need to load the pretrained model + if self.configs["trainer"]["finetune"]: + if self.configs["trainer"]["use_pretrain"]: + trainer_config=self.configs["trainer"] + self._main_rank_print(f"Loading pretrained model from: {trainer_config['pretrained_model_path']}") + checkpoint = torch.load(self.configs["trainer"]['pretrained_model_path'], map_location='cpu') + state_dict = dict() + # Exclude action and state related weights + for key, value in checkpoint['state_dict'].items(): + if key[:6]=="model.": + key = key[6:] # remove "model." 
from pl checkpoint + state_dict[key] = value + del checkpoint + msg = self.model.load_state_dict(state_dict, strict=False) + self._main_rank_print(msg) + del state_dict + + # save config + if get_rank() == 0: + with open(os.path.join(self.configs['ckpt_dir'], 'hyperparameters.json'), 'w') as f: + json.dump(self.configs, f) + + + # these variables are used to indicate what information will be used or predicted + self.act_pred = self.model.act_pred + self.fwd_pred = self.model.fwd_pred + self.fwd_pred_hand = self.model.fwd_pred_hand + self.use_state = self.model.use_state + self.use_hand_rgb = self.model.use_hand_rgb + self.progress_pred=self.model.progress_pred + + # loss + self.kl_loss_ratio =self.configs["trainer"]["kl_loss_ratio"] + self.gripper_loss_ratio =self.configs["trainer"]["gripper_loss_ratio"] + self.fwd_loss_ratio =self.configs["trainer"]["fwd_loss_ratio"] + self.progress_loss_ratio=self.configs["trainer"]["progress_loss_ratio"] + + + + def configure_optimizers(self): + lr = self.configs["trainer"]['learning_rate'] + eff_bsz = self.configs["trainer"]['batch_size'] * self.num_gpus + self._main_rank_print('-' * 40) + self._main_rank_print("LR SCHEDULER CONFIGS:") + self._main_rank_print(f"learning rate: {lr}, effective batch size: {eff_bsz}") + + optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=lr, + betas=(self.configs["trainer"]['betas_1'], self.configs["trainer"]['betas_2']), + weight_decay=self.configs["trainer"]['weight_decay'] + ) + assert self.trainer.max_epochs is not None + num_training_batches = self.trainer.estimated_stepping_batches + iter_per_epoch = num_training_batches / self.trainer.max_epochs + self.configs["trainer"]['warmup_steps']=self.configs["trainer"]['warmup_epochs'] * iter_per_epoch + lr_scheduler_configs = { + 'warmup_iters': self.configs["trainer"]['warmup_steps'], + 'iters': self.trainer.max_epochs * iter_per_epoch, + 'min_lr_scale': self.configs["trainer"]['min_learning_rate_scale'] + } + lr_lambda = 
partial(adjust_learning_rate, configs=lr_scheduler_configs) + self._main_rank_print(lr_scheduler_configs) + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) + return { + 'optimizer': optimizer, + 'lr_scheduler': + { + 'scheduler': scheduler, + 'interval': 'step', + 'frequency': 1 + } + } + + + def _log_output(self, output, phase, dataset=None, **kwargs): + for k, v in output.items(): + log_name = f"{phase}_{k}" + if dataset is not None: + log_name = f"{dataset}_{log_name}" + self.log(log_name, v, prog_bar=True, **kwargs) + + def validation_step(self, batch, batch_idx): + with torch.no_grad(): + goal_rgb = batch['goal_rgb'] + rgb = batch['rgb'] + hand_rgb = batch['hand_rgb'] + state = batch['rel_state'] + action = batch['action'] + action_mask = batch['action_mask'] + attention_mask = batch['attention_mask'] + text=batch["text"] + progress=batch["progress"] + + # Split arm and gripper action + arm_action = action[:, :, :, :6] # (b, l, act_len, act_dim - 1) + gripper_action = action[:, :, :, 6] # (b, l, act_len) + arm_action_target = torch.clone(arm_action) # (b, l, act_len, act_dim - 1) + gripper_action_target = torch.clone(gripper_action) # (b, l, act_len) + + # Split arm and gripper state + arm_state = state[:, :, :6] # (b, l, state_dim - 1) + gripper_state = state[:, :, 6].long() # (b, l) + gripper_state = F.one_hot(gripper_state, num_classes=2).type_as(arm_state) # (b, l, 2) + + seq_len = arm_action.size(1) + act_len = arm_action.size(2) + + input_dict = { + 'goal_rgb': goal_rgb, + 'rgb': rgb, + 'hand_rgb': hand_rgb, + 'arm_action': arm_action, + 'gripper_action': gripper_action, + 'arm_state': arm_state, + 'gripper_state': gripper_state, + 'attention_mask': attention_mask, #(b,l) + 'text':text, + 'progress':progress + } + + + prediction = self.model(input_dict, is_training=False) + + + obs_preds = prediction['obs_preds'] + obs_target = prediction['obs_target'] + arm_action_preds = prediction['arm_action_preds'] # (b, l, act_len, 
act_dim - 1) + gripper_action_preds = prediction['gripper_action_preds'] # (b, l, act_len, 1) + obs_hand_preds = prediction['obs_hand_preds'] + obs_hand_target = prediction['obs_hand_target'] + action_mu_preds = prediction['action_mu_preds'] # (b * l, act_latent_dim) + action_logvar_preds = prediction['action_logvar_preds'] # (b * l, act_latent_dim) + progress_preds=prediction["progress_preds"] + progress_targets=prediction["progress_targets"] + + + loss_act = 0 + loss_arm_act = 0 + loss_gripper_act = 0 + loss_kl_act = 0 + acc_gripper_act = 0 + loss_obs = 0 + loss_hand_obs = 0 + gripper_cnt = 0 + loss_progress=0 + + # action prediction + act_dim = self.model.act_dim + if self.act_pred: + # kl loss + loss_kl_act = compute_kl_divergence(action_mu_preds, action_logvar_preds) + + + # action smooth l1 loss + arm_action_preds = arm_action_preds.view(-1, act_len, act_dim-1)[attention_mask.flatten() > 0] # b,len,act_len,6 -> b*len,act_len,6 + arm_action_target = arm_action_target.view(-1, act_len, act_dim-1)[attention_mask.flatten() > 0] # b,len,act_len,6 -> b*len,act_len,6 + action_mask = action_mask.view(-1, act_len)[attention_mask.flatten() > 0] # b,len,act_len -> b*len,act_len + arm_action_preds = arm_action_preds.view(-1, act_dim-1)[action_mask.flatten() > 0] # b*len*act_len, 6 + arm_action_target = arm_action_target.view(-1, act_dim-1)[action_mask.flatten() > 0] # b*len*act_len, 6 + loss_arm_act = torch.nn.SmoothL1Loss()(arm_action_preds, arm_action_target) + + # gripper bce loss + gripper_action_preds = gripper_action_preds.view(-1, act_len)[attention_mask.flatten() > 0] # b,len,act_len -> b*len,act_len + gripper_action_target = gripper_action_target.view(-1, act_len)[attention_mask.flatten() > 0] # b,len,act_len -> b*len,act_len + gripper_action_preds = gripper_action_preds.flatten()[action_mask.flatten() > 0] # b*len*act_len + gripper_action_target = gripper_action_target.flatten()[action_mask.flatten() > 0] # b*len*act_len + bce_with_logits_loss = 
torch.nn.BCEWithLogitsLoss() + loss_gripper_act = bce_with_logits_loss(gripper_action_preds, gripper_action_target) + loss_act = loss_arm_act + loss_gripper_act * self.gripper_loss_ratio + loss_kl_act * self.kl_loss_ratio + + # Compute gripper action acc + gripper_action_preds = torch.nn.Sigmoid()(gripper_action_preds) # Sigmoid function + gripper_action_preds = (gripper_action_preds > 0.5).float() + acc_gripper_act = torch.eq(gripper_action_preds, gripper_action_target).sum().float() + gripper_cnt = gripper_action_preds.shape[0] + acc_gripper_act /= gripper_cnt + + # forward prediction + if self.fwd_pred: + fwd_pred_next_n = self.configs["trainer"]['fwd_pred_next_n'] + obs_preds = obs_preds[:, :seq_len-fwd_pred_next_n, :, :] + obs_target = obs_target[:, fwd_pred_next_n:, :, :] + obs_attention_mask = attention_mask[:, fwd_pred_next_n:] + loss_obs = (obs_preds - obs_target) ** 2 + loss_obs = loss_obs.mean(dim=-1).mean(dim=-1) + loss_obs = (loss_obs * obs_attention_mask).sum() / obs_attention_mask.sum() + if self.fwd_pred_hand: + obs_hand_preds = obs_hand_preds[:, :seq_len-fwd_pred_next_n, :, :] + obs_hand_target = obs_hand_target[:, fwd_pred_next_n:, :, :] + loss_hand_obs = (obs_hand_preds - obs_hand_target) ** 2 + loss_hand_obs = loss_hand_obs.mean(dim=-1).mean(dim=-1) + loss_hand_obs = (loss_hand_obs * obs_attention_mask).sum() / obs_attention_mask.sum() + if self.progress_pred: + diff = progress_preds - progress_targets + masked_diff = diff * attention_mask + squared_error = masked_diff ** 2 + loss_progress = squared_error.sum() / attention_mask.sum() + + # compute loss + loss = torch.tensor(0.0).to(self.device) + if self.act_pred: + loss += loss_act + if self.fwd_pred: + loss += self.fwd_loss_ratio * loss_obs + if self.fwd_pred_hand: + loss += self.fwd_loss_ratio * loss_hand_obs + if self.progress_pred: + loss+=loss_progress*self.progress_loss_ratio + output = { + 'loss': loss, + 'loss_act': loss_act, + 'loss_arm_act': loss_arm_act, + 'loss_gripper_act': 
loss_gripper_act, + 'loss_kl_act': loss_kl_act, + 'acc_gripper_act': acc_gripper_act, + 'loss_obs': loss_obs, + 'loss_hand_obs': loss_hand_obs, + 'loss_progress':loss_progress + } + self._log_output(output, phase="val", on_epoch=True, on_step=False) + return output['loss'] + + def training_step(self, batch, batch_idx): + goal_rgb = batch['goal_rgb'] + rgb = batch['rgb'] + hand_rgb = batch['hand_rgb'] + state = batch['rel_state'] + action = batch['action'] + action_mask = batch['action_mask'] + attention_mask = batch['attention_mask'] + text=batch["text"] + progress=batch["progress"] + # Split arm and gripper action + arm_action = action[:, :, :, :6] # (b, l, act_len, act_dim - 1) + gripper_action = action[:, :, :, 6] # (b, l, act_len) + arm_action_target = torch.clone(arm_action) # (b, l, act_len, act_dim - 1) + gripper_action_target = torch.clone(gripper_action) # (b, l, act_len) + + # Split arm and gripper state + arm_state = state[:, :, :6] # (b, l, state_dim - 1) + gripper_state = state[:, :, 6].long() # (b, l) + gripper_state = F.one_hot(gripper_state, num_classes=2).type_as(arm_state) # (b, l, 2) + + seq_len = arm_action.size(1) + act_len = arm_action.size(2) + + input_dict = { + 'goal_rgb': goal_rgb, + 'rgb': rgb, + 'hand_rgb': hand_rgb, + 'arm_action': arm_action, + 'gripper_action': gripper_action, + 'arm_state': arm_state, + 'gripper_state': gripper_state, + 'attention_mask': attention_mask, + 'text':text, + 'progress':progress + } + + + prediction = self.model(input_dict, is_training=True) + + + obs_preds = prediction['obs_preds'] + obs_target = prediction['obs_target'] + arm_action_preds = prediction['arm_action_preds'] # (b, l, act_len, act_dim - 1) + gripper_action_preds = prediction['gripper_action_preds'] # (b, l, act_len, 1) + obs_hand_preds = prediction['obs_hand_preds'] + obs_hand_target = prediction['obs_hand_target'] + action_mu_preds = prediction['action_mu_preds'] # (b * l, act_latent_dim) + action_logvar_preds = 
prediction['action_logvar_preds'] # (b * l, act_latent_dim) + progress_preds=prediction["progress_preds"] + progress_targets=prediction["progress_targets"] + + + + loss_act = 0 + loss_arm_act = 0 + loss_gripper_act = 0 + loss_kl_act = 0 + acc_gripper_act = 0 + loss_obs = 0 + loss_hand_obs = 0 + gripper_cnt = 0 + loss_progress= 0 + # action prediction + act_dim = self.model.act_dim + if self.act_pred: + # kl loss + + loss_kl_act = compute_kl_divergence(action_mu_preds, action_logvar_preds) + + + # action smooth l1 loss + arm_action_preds = arm_action_preds.view(-1, act_len, act_dim-1)[attention_mask.flatten() > 0] # b,len,act_len,6 -> b*len,act_len,6 + arm_action_target = arm_action_target.view(-1, act_len, act_dim-1)[attention_mask.flatten() > 0] # b,len,act_len,6 -> b*len,act_len,6 + action_mask = action_mask.view(-1, act_len)[attention_mask.flatten() > 0] # b,len,act_len -> b*len,act_len + arm_action_preds = arm_action_preds.view(-1, act_dim-1)[action_mask.flatten() > 0] # b*len*act_len, 6 + arm_action_target = arm_action_target.view(-1, act_dim-1)[action_mask.flatten() > 0] # b*len*act_len, 6 + loss_arm_act = torch.nn.SmoothL1Loss()(arm_action_preds, arm_action_target) + + # gripper bce loss + gripper_action_preds = gripper_action_preds.view(-1, act_len)[attention_mask.flatten() > 0] # b,len,act_len -> b*len,act_len + gripper_action_target = gripper_action_target.view(-1, act_len)[attention_mask.flatten() > 0] # b,len,act_len -> b*len,act_len + gripper_action_preds = gripper_action_preds.flatten()[action_mask.flatten() > 0] # b*len*act_len + gripper_action_target = gripper_action_target.flatten()[action_mask.flatten() > 0] # b*len*act_len + bce_with_logits_loss = torch.nn.BCEWithLogitsLoss() + loss_gripper_act = bce_with_logits_loss(gripper_action_preds, gripper_action_target) + loss_act = loss_arm_act + loss_gripper_act * self.gripper_loss_ratio + loss_kl_act * self.kl_loss_ratio + + # Compute gripper action acc + gripper_action_preds = 
torch.nn.Sigmoid()(gripper_action_preds) + gripper_action_preds = (gripper_action_preds > 0.5).float() + acc_gripper_act = torch.eq(gripper_action_preds, gripper_action_target).sum().float() + gripper_cnt = gripper_action_preds.shape[0] + acc_gripper_act /= gripper_cnt + + # predict future image + if self.fwd_pred: + fwd_pred_next_n = self.configs["trainer"]['fwd_pred_next_n'] + obs_preds = obs_preds[:, :seq_len-fwd_pred_next_n, :, :] + obs_target = obs_target[:, fwd_pred_next_n:, :, :] + obs_attention_mask = attention_mask[:, fwd_pred_next_n:] + loss_obs = (obs_preds - obs_target) ** 2 + loss_obs = loss_obs.mean(dim=-1).mean(dim=-1) + loss_obs = (loss_obs * obs_attention_mask).sum() / obs_attention_mask.sum() + if self.fwd_pred_hand: + obs_hand_preds = obs_hand_preds[:, :seq_len-fwd_pred_next_n, :, :] + obs_hand_target = obs_hand_target[:, fwd_pred_next_n:, :, :] + loss_hand_obs = (obs_hand_preds - obs_hand_target) ** 2 + loss_hand_obs = loss_hand_obs.mean(dim=-1).mean(dim=-1) + loss_hand_obs = (loss_hand_obs * obs_attention_mask).sum() / obs_attention_mask.sum() + + if self.progress_pred: + diff = progress_preds - progress_targets + masked_diff = diff * attention_mask + squared_error = masked_diff ** 2 + loss_progress = squared_error.sum() / attention_mask.sum() + + # compute loss + loss = torch.tensor(0.0).to(self.device) + if self.act_pred: + loss += loss_act + if self.fwd_pred: + loss += self.fwd_loss_ratio * loss_obs + if self.fwd_pred_hand: + loss += self.fwd_loss_ratio * loss_hand_obs + if self.progress_pred: + loss+=loss_progress*self.progress_loss_ratio + output = { + 'loss': loss, + 'loss_act': loss_act, + 'loss_arm_act': loss_arm_act, + 'loss_gripper_act': loss_gripper_act, + 'loss_kl_act': loss_kl_act, + 'acc_gripper_act': acc_gripper_act, + 'loss_obs': loss_obs, + 'loss_hand_obs': loss_hand_obs, + 'loss_progress':loss_progress + } + self._log_output(output, phase="train", on_epoch=False, on_step=True) + return output['loss'] + + + diff --git 
# Copyright (2024) Bytedance Ltd. and/or its affiliates

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os

import torch
import torch.distributed as dist


# Keep a handle on the builtin before this module shadows it below.
_print = print


def get_world_size():
    """Number of processes in the job (env WORLD_SIZE, default 1)."""
    return int(os.getenv('WORLD_SIZE', 1))


def get_rank():
    """Global rank of this process (env RANK, default 0)."""
    return int(os.getenv('RANK', 0))


def get_local_rank():
    """Rank of this process on its node (env LOCAL_RANK, default 0)."""
    return int(os.getenv('LOCAL_RANK', 0))


def is_dist():
    """True iff torch.distributed is available, initialised and multi-process."""
    return dist.is_available() and dist.is_initialized() and get_world_size() > 1


def print(*argc, all=False, **kwargs):
    """Rank-aware replacement for the builtin print.

    Outside distributed runs this behaves exactly like the builtin. In
    distributed runs only local-rank 0 prints unless `all=True`, and each
    message is prefixed with the global rank.
    """
    if not is_dist():
        _print(*argc, **kwargs)
        return

    if not all and get_local_rank() != 0:
        return

    # Render the message into a buffer first so the rank prefix can be
    # prepended to the whole message in a single write.
    output = io.StringIO()
    kwargs['end'] = ''
    kwargs['file'] = output
    kwargs['flush'] = True
    _print(*argc, **kwargs)

    s = output.getvalue()
    output.close()

    s = '[rank {}] {}'.format(dist.get_rank(), s)
    _print(s)


def _reducible_clone(value):
    """Return a tensor clone of *value* that all_reduce can operate on.

    Non-tensor scalars are wrapped on the current CUDA device; this path is
    only reached inside an initialised (CUDA) distributed run. Shared by
    reduce_mean/reduce_sum, which previously duplicated this logic.
    """
    if isinstance(value, torch.Tensor):
        return value.clone()
    device = torch.cuda.current_device()
    return torch.tensor(value, device=device)


def reduce_mean(tensor, nprocs=None):
    """All-reduce *tensor* across ranks and divide by *nprocs* (world size by default).

    Python scalars come back as Python scalars, tensors as tensors. This is
    a no-op when not running distributed.
    """
    if not is_dist():
        return tensor
    rt = _reducible_clone(tensor)
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    nprocs = nprocs if nprocs else dist.get_world_size()
    rt = rt / nprocs
    if not isinstance(tensor, torch.Tensor):
        rt = rt.item()
    return rt


def reduce_sum(tensor):
    """All-reduce-sum *tensor* across ranks; no-op when not distributed."""
    if not is_dist():
        return tensor
    rt = _reducible_clone(tensor)
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    if not isinstance(tensor, torch.Tensor):
        rt = rt.item()
    return rt
+ """ + + def __init__( + self, num_pos_feats=64, temperature=10000, normalize=False, scale=None + ): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor): + x = tensor + # mask = tensor_list.mask + # assert mask is not None + # not_mask = ~mask + + not_mask = torch.ones_like(x[0, [0]]) + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos # 1 D H W + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. 
+ """ + + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = ( + torch.cat( + [ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], + dim=-1, + ) + .permute(2, 0, 1) + .unsqueeze(0) + .repeat(x.shape[0], 1, 1, 1) + ) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding in ("v2", "sine"): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif args.position_embedding in ("v3", "learned"): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/ACT_DP_multitask/detr/models/resnet_film.py b/ACT_DP_multitask/detr/models/resnet_film.py new file mode 100644 index 0000000000000000000000000000000000000000..bc58f38f17e94c6cd7eb6afba19e0e1e2d3b151f --- /dev/null +++ b/ACT_DP_multitask/detr/models/resnet_film.py @@ -0,0 +1,450 @@ +from typing import Type, Any, Callable, Union, List, Mapping, Optional + +import copy +import torch +import torch.nn as nn +from torch import Tensor + + +def is_torch_version_lower_than_17(): + major_version = float(torch.__version__.split('.')[0]) + minor_version = float(torch.__version__.split('.')[1]) + return major_version == 1 and minor_version < 7 + + +if not is_torch_version_lower_than_17(): + # TODO: Make sure the torchvision version is similarly updated. 
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride,
        padding=dilation, groups=groups, bias=False, dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """ResNet basic block (two 3x3 convs) with optional FiLM modulation."""

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # conv1 (and downsample, when present) carry the spatial stride.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor, film_features: Optional[Tensor] = None) -> Tensor:
        """Residual forward; *film_features* packs (gamma, beta) along dim 1."""
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        if film_features is not None:
            # Broadcast gamma/beta over the spatial dims as (B, planes, 1, 1).
            gamma, beta = (
                part.squeeze().view(x.size(0), -1, 1, 1)
                for part in torch.split(film_features, 1, dim=1)
            )
            out = (1 + gamma) * out + beta

        out = out + shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3
    # convolution (self.conv2), while the original paper ("Deep residual
    # learning for image recognition", https://arxiv.org/abs/1512.03385)
    # places it at the first 1x1 convolution (self.conv1). This variant is
    # known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # conv2 (and downsample, when present) carry the spatial stride.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """1x1 -> 3x3 -> 1x1 residual forward."""
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + shortcut)
class ResNetWithExtraModules(nn.Module):
    """Update standard ResNet image classification models with FiLM."""
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        film_config: Optional[Mapping[str, Any]] = None,) -> None:
        super().__init__()
        self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer

        # Blocks per residual stage; needed to slice FiLM features later.
        self.layers = layers

        # FiLM only implemented for BasicBlock for now.
        self.use_film = film_config is not None and film_config['use']
        if self.use_film:
            self.film_config = film_config
            self.film_planes = film_config['film_planes']

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # One flag per stage 2-4: replace its 2x2 stride with dilation.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )

        # A 4th stem channel is used when an object mask is appended to the RGB input.
        in_channels_conv1 = 4 if (
            film_config is not None and
            film_config.get('append_object_mask', None) is not None) else 3

        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(in_channels_conv1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch so every block
        # starts as an identity mapping; improves the model by 0.2~0.3%
        # according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,) -> nn.Sequential:
        """Build one residual stage of *blocks* blocks starting at *planes* width."""
        norm_layer = self._norm_layer
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        stage = [
            block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer,)
        ]
        self.inplanes = planes * block.expansion
        stage.extend(
            block(
                self.inplanes,
                planes,
                groups=self.groups,
                base_width=self.base_width,
                dilation=self.dilation,
                norm_layer=norm_layer,
            )
            for _ in range(1, blocks)
        )

        # FiLM needs per-block access in forward, so use a ModuleList then.
        if self.use_film:
            return nn.ModuleList(stage)
        return nn.Sequential(*stage)

    def _forward_impl_film(self, x: Tensor, film_features: List[Optional[Tensor]], flatten: bool = True):
        """FiLM-conditioned trunk: each block gets its own (gamma, beta) slice."""
        assert self.use_film and film_features is not None

        def _per_block_features(layer_feat: Optional[Tensor], layer_idx: int):
            # layer_feat is the flat FiLM vector for one stage, or None.
            if film_features[layer_idx] is None:
                return [None] * self.layers[layer_idx]

            num_planes = self.film_planes[layer_idx]
            num_blocks = self.layers[layer_idx]
            layer_feat = layer_feat.view(-1, 2, num_blocks, num_planes)
            return torch.split(layer_feat, 1, dim=2)

        for layer_idx, layer in enumerate([self.layer1, self.layer2, self.layer3, self.layer4]):
            block_feats = _per_block_features(film_features[layer_idx], layer_idx)
            for block_idx, blk in enumerate(layer):
                if block_feats[block_idx] is not None:
                    assert x.shape[0] == block_feats[block_idx].shape[0], ('FiLM batch size does not match')
                x = blk(x, film_features=block_feats[block_idx])

        x = self.avgpool(x)
        # NOTE(review): fc is applied only on the flattened path, mirroring
        # the original layout — flatten=False returns pooled (B, C, 1, 1).
        if flatten:
            x = torch.flatten(x, 1)
            x = self.fc(x)
        return x

    def _forward_impl(self,
                      x: Tensor,
                      film_features: List[Optional[Tensor]],
                      flatten: bool = True) -> Tensor:
        # Shared stem: 7x7 conv -> BN -> ReLU -> maxpool.
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        if self.use_film:
            return self._forward_impl_film(x, film_features, flatten=flatten)

        x = self.layer4(self.layer3(self.layer2(self.layer1(x))))

        x = self.avgpool(x)
        if flatten:
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x

    def forward(self,
                x: Tensor,
                film_features: List[Optional[Tensor]], **kwargs) -> Tensor:
        return self._forward_impl(x, film_features, **kwargs)
def _resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    weights,
    progress: bool,
    **kwargs: Any,
) -> ResNetWithExtraModules:
    """Instantiate ``ResNetWithExtraModules`` and optionally load pretrained weights.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks per stage.
        weights: torchvision Weights enum member (torch >= 1.7 path), or None.
        progress: show a download progress bar.
        **kwargs: forwarded to the model constructor; ``pretrained``/``arch``
            are consumed here (legacy torch < 1.7 weight-loading path) and
            never reach the constructor.
    """
    model_kwargs = copy.deepcopy(kwargs)
    # These two kwargs steer weight loading only.
    model_kwargs.pop('pretrained', None)
    model_kwargs.pop('arch', None)
    model = ResNetWithExtraModules(block, layers, **model_kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))
    elif kwargs.get('pretrained', False) and kwargs.get('arch') is not None:
        # Fixed: the previous guard checked only the minor version
        # (float(torch.__version__.split('.')[1]) < 7), which also matched
        # e.g. torch 2.0. Restrict the legacy URL path to torch 1.x < 1.7,
        # consistent with is_torch_version_lower_than_17().
        major, minor = (float(v) for v in torch.__version__.split('.')[:2])
        if major == 1 and minor < 7:
            # Copied from https://pytorch.org/vision/0.11/_modules/torchvision/models/resnet.html#resnet18
            model_urls = {
                'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
                'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
                'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
                'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
                'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
                'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
                'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
                'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
                'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
            }

            state_dict = torch.hub.load_state_dict_from_url(model_urls[kwargs.get('arch')],
                                                            progress=progress)
            model.load_state_dict(state_dict)

    return model


def resnet18(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNetWithExtraModules:
    """ResNet-18 from "Deep Residual Learning for Image Recognition"
    (https://arxiv.org/abs/1512.03385), built as ``ResNetWithExtraModules``.

    Args:
        weights: optional ``torchvision.models.ResNet18_Weights`` member to
            load. Ignored on torch < 1.7, where the legacy
            ``pretrained=True`` kwarg (handled by ``_resnet``) is used instead.
        progress: if True, display a download progress bar on stderr.
        **kwargs: forwarded to ``ResNetWithExtraModules``.
    """
    if is_torch_version_lower_than_17():
        # Old torchvision has no Weights enums; fall back to arch-name URLs.
        kwargs["arch"] = "resnet18"
        weights = None
    else:
        weights = ResNet18_Weights.verify(weights)

    return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)


def resnet34(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNetWithExtraModules:
    """ResNet-34 from "Deep Residual Learning for Image Recognition"
    (https://arxiv.org/abs/1512.03385), built as ``ResNetWithExtraModules``.

    Args:
        weights: optional ``torchvision.models.ResNet34_Weights`` member to
            load. Ignored on torch < 1.7, where the legacy
            ``pretrained=True`` kwarg (handled by ``_resnet``) is used instead.
        progress: if True, display a download progress bar on stderr.
        **kwargs: forwarded to ``ResNetWithExtraModules``.
    """
    if is_torch_version_lower_than_17():
        kwargs["arch"] = "resnet34"
        weights = None
    else:
        weights = ResNet34_Weights.verify(weights)

    return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)


def resnet101(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNetWithExtraModules:
    """ResNet-101 from "Deep Residual Learning for Image Recognition"
    (https://arxiv.org/abs/1512.03385), built as ``ResNetWithExtraModules``.

    Note:
        TorchVision's bottleneck places the downsampling stride on the second
        3x3 convolution instead of the first 1x1 convolution of the paper
        (the "ResNet V1.5" variant), which improves accuracy.

    Args:
        weights: optional ``torchvision.models.ResNet101_Weights`` member to
            load. Ignored on torch < 1.7, where the legacy
            ``pretrained=True`` kwarg (handled by ``_resnet``) is used instead.
        progress: if True, display a download progress bar on stderr.
        **kwargs: forwarded to ``ResNetWithExtraModules``.
    """
    if is_torch_version_lower_than_17():
        kwargs["arch"] = "resnet101"
        weights = None
    else:
        weights = ResNet101_Weights.verify(weights)
    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
+ +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional, List +import math +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +import IPython + +e = IPython.embed + + +class Transformer(nn.Module): + + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder( + encoder_layer, num_encoder_layers, encoder_norm + ) + + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, + num_decoder_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward( + self, + src, + mask, + query_embed, + pos_embed, + latent_input=None, + proprio_input=None, + additional_pos_embed=None, + ): + # TODO flatten only when input has H and W + if len(src.shape) == 4: # has H and W + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + # mask = mask.flatten(1) + + additional_pos_embed = 
additional_pos_embed.unsqueeze(1).repeat( + 1, bs, 1 + ) # seq, bs, dim + pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) + + addition_input = torch.stack([latent_input, proprio_input], axis=0) + src = torch.cat([addition_input, src], axis=0) + else: + assert len(src.shape) == 3 + # flatten NxHWxC to HWxNxC + bs, hw, c = src.shape + src = src.permute(1, 0, 2) + pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + shape_len = tgt.shape[0] + tgt_mask = torch.zeros(shape_len, shape_len).to(tgt.device) + tgt_mask[:100, 100:] = float( + "-inf" + ) + tgt_mask[100:, :100] = float("-inf") # + hs = self.decoder( + tgt, + memory, + tgt_mask, + memory_key_padding_mask=mask, + pos=pos_embed, + query_pos=query_embed, + ) + hs = hs.transpose(1, 2) + return hs + + +class Transformer_Denoise(nn.Module): + + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + causal_mask=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder( + encoder_layer, num_encoder_layers, encoder_norm + ) + + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, + num_decoder_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + + self.time_embed = nn.Sequential( + nn.Linear(d_model, d_model), + nn.SiLU(), + nn.Linear(d_model, d_model), + ) + self._reset_parameters() + + 
        # Projects a (noisy) 14-dim action vector into the model width.
        # NOTE(review): the action dimension 14 is hard-coded — confirm it
        # matches the robot action space used by callers.
        # NOTE(review): this module is created *after* self._reset_parameters()
        # was already called in __init__ (see above), so it keeps PyTorch's
        # default Linear init rather than the Xavier init applied to earlier
        # submodules — confirm this ordering is intentional.
        self.action_embed = nn.Sequential(
            nn.Linear(14, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        self.d_model = d_model
        self.nhead = nhead
        # Learned positional embedding for the single denoise-timestep token
        # appended to the encoder memory in forward().
        self.denoise_step_pos_embed = nn.Embedding(1, d_model)
        self.causal_mask = causal_mask
        print("apply causal_mask:", causal_mask)

    def _reset_parameters(self):
        # Xavier-initialize every parameter tensor with more than one
        # dimension (weight matrices); biases are left at their defaults.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    # self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight,actions, denoise_steps)

    def forward(
        self,
        src,
        mask,
        query_embed,
        pos_embed,
        latent_input=None,
        proprio_input=None,
        additional_pos_embed=None,
        noisy_actions=None,
        denoise_steps=None,
        task_emb=None,  # B D
    ):
        """Run one denoising step: encode the conditioning inputs, then decode
        the embedded noisy actions against the memory plus a timestep token.

        NOTE(review): the shape annotations below ("B T D", "h*w, bs, c", …)
        are carried over from the original comments — verify against callers.
        """
        # TODO flatten only when input has H and W
        # encoder don't need change
        # src: image embedding mask: mask
        # query_embed: decoder PE
        # pos_embed: encoder PE
        # latent_input: vae latent or tacile latent
        # proprio_input: proprio
        # additional_pos_embed: proprio + proprio
        # denoise_embed: denoise timestep embedding
        # noisy_actions: noisy actions

        if len(src.shape) == 4:  # has H and W b d h (w n_view)
            # flatten NxCxHxW to HWxNxC
            bs, c, h, w = src.shape
            src = src.flatten(2).permute(2, 0, 1)  # h*w, bs, c
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)  # N bs dim
            # mask = mask.flatten(1)
            # N_add Dim
            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
                1, bs, 1
            )  # seq, bs, dim
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

            # Accept latent either as (B, D) or already (B, 1, D).
            latent_input = latent_input.unsqueeze(1) if len(latent_input.shape) ==2 else latent_input  # B 1 D
            # print('latent_input', latent_input.shape, 'proprio_input', proprio_input.shape)
            addition_input = torch.cat([latent_input, proprio_input], axis=1).permute(
                1, 0, 2
            )  # B T+1 D -> T+1 B D
            if task_emb is not None:
                addition_input = 
torch.cat([addition_input, task_emb.unsqueeze(0)], axis=0) ## concat task embedding to encoder T+2 B D + src = torch.cat([addition_input, src], axis=0) + else: + assert len(src.shape) == 3 + # flatten NxHWxC to HWxNxC + bs, hw, c = src.shape + src = src.permute(1, 0, 2) + pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + + tgt = self.action_embed(noisy_actions).permute( + 1, 0, 2 + ) # TODO Change to noise tgt B T D -> T B D + denoise_embed = get_timestep_embedding(denoise_steps, self.d_model) # B -> B D + # print(denoise_embed.shape) + denoise_embed = self.time_embed(denoise_embed).unsqueeze(0) # B D -> 1 B D + memory = self.encoder( + src, src_key_padding_mask=mask, pos=pos_embed + ) # cross attention + denoise_step_pos_embed = self.denoise_step_pos_embed.weight.unsqueeze(1).repeat( + 1, bs, 1 + ) # 1 D -> 1 B D + memory = torch.cat([memory, denoise_embed], axis=0) + pos_embed = torch.cat([pos_embed, denoise_step_pos_embed], axis=0) + seq_len = tgt.shape[0] + if self.causal_mask: + tgt_mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf")), diagonal=1 + ).to(tgt.device) + else: + tgt_mask = torch.zeros(seq_len, seq_len).to(tgt.device) + hs = self.decoder( + tgt, + memory, + tgt_mask, + memory_key_padding_mask=mask, + pos=pos_embed, + query_pos=query_embed, + ) # TODO + hs = hs.transpose(1, 2) # 1 T B D -> 1 B T D + return hs + + +class Transformer_Denoise_AdLN(nn.Module): + + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + causal_mask=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder( + encoder_layer, 
num_encoder_layers, encoder_norm + ) + + decoder_layer = TransformerEncoderLayer_AdLN( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerEncoder_AdLN( + decoder_layer, + num_decoder_layers, + decoder_norm, + ) + + self.time_embed = nn.Sequential( + nn.Linear(d_model, d_model), + nn.SiLU(), + nn.Linear(d_model, d_model), + ) + self._reset_parameters() + + self.action_embed = nn.Sequential( + nn.Linear(14, d_model), + nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.d_model = d_model + self.nhead = nhead + self.denoise_step_pos_embed = nn.Embedding(1, d_model) + + self.global_1d_pool = nn.AdaptiveAvgPool1d(1) + self.norm_after_pool = nn.LayerNorm(d_model) + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward( + self, + src, # B D H W*num_view + mask, # None + query_embed, # H D + pos_embed, # 1 D H W*num_view + latent_input=None, # B 1 D + proprio_input=None, # B 1 D + additional_pos_embed=None, # 1+1 D + noisy_actions=None, # B H D + denoise_steps=None, # B + ): + + if len(src.shape) == 4: # has H and W b d h (w n_view) + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) # h*w, bs, c + pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # N bs dim + # N_add Dim + additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat( + 1, bs, 1 + ) # seq, bs, dim + pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) # pro. 
+ visual token PE

            # Accept latent either as (B, D) or already (B, 1, D).
            latent_input = latent_input.unsqueeze(1) if len(latent_input.shape) ==2 else latent_input  # B 1 D
            # Prepend [latent, proprio] tokens to the flattened visual tokens.
            addition_input = torch.cat([latent_input, proprio_input], axis=1).permute(
                1, 0, 2
            )  # B T+1 D -> T+1 B D
            src = torch.cat([addition_input, src], axis=0)
        else:
            assert len(src.shape) == 3
            # flatten NxHWxC to HWxNxC
            bs, hw, c = src.shape
            src = src.permute(1, 0, 2)
            pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)

        memory = self.encoder(
            src, src_key_padding_mask=mask, pos=pos_embed
        )  # cross attention N B D
        # Collapse the encoder output sequence into one global vector:
        # N B D => B D N => B D 1 => B D
        memory = self.global_1d_pool(memory.permute(1, 2, 0)).squeeze(-1)  # B D
        memory = self.norm_after_pool(memory)
        # Sinusoidal timestep features projected through the time MLP.
        denoise_embed = get_timestep_embedding(denoise_steps, self.d_model)  # B -> B D
        denoise_embed = self.time_embed(denoise_embed)  # B -> B D

        # Pooled scene summary + timestep embedding become the single vector
        # the AdLN decoder uses to modulate its normalization layers (there is
        # no cross-attention to the encoder sequence in this variant).
        condition = memory + denoise_embed  # B D as condition for modulation


        tgt = self.action_embed(noisy_actions).permute(
            1, 0, 2
        )

        hs = self.decoder(
            tgt,
            condition,
            pos=query_embed
        )  # Actually need is_pad? 
+ + hs = hs.transpose(1, 2) # 1 T B D -> 1 B T D + return hs + +class Transformer_Denoise_Tactile(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + causal_mask=False, + ): + super().__init__() + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder( + encoder_layer, num_encoder_layers, encoder_norm + ) + + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder_alter( + decoder_layer, + num_decoder_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + + self.time_embed = nn.Sequential( + nn.Linear(d_model, d_model), + nn.SiLU(), + nn.Linear(d_model, d_model), + ) + self._reset_parameters() + + self.action_embed = nn.Sequential( + nn.Linear(14, d_model), + nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.d_model = d_model + self.nhead = nhead + self.denoise_step_pos_embed = nn.Embedding(1, d_model) + self.causal_mask = causal_mask + print("apply causal_mask:", causal_mask) + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + def forward( + self, + src, + mask, + query_embed, + pos_embed, + tactile_input, + tactile_pos, + proprio_input=None, + additional_pos_embed=None, + noisy_actions=None, + denoise_steps=None,): + + if len(src.shape) == 4: # has H and W b d h (w n_view) + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) # h*w, bs, c + pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 
1) # N bs dim + + tactile_input = tactile_input.flatten(2).permute(2, 0, 1) # h*w*4, bs, c + tactile_pos = tactile_pos.flatten(2).permute(2, 0, 1).repeat(1, bs, 1) # h*w, bs, c + # N_add Dim + additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat( + 1, bs, 1 + ) # seq, bs, dim + proprio_input = proprio_input.permute(1, 0, 2) # B 1 D -> 1 B D + # vision-state input + pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) + src = torch.cat([proprio_input, src], axis=0) + # tactile-state input + # print('proprio_input', proprio_input.shape, 'tactile_input', tactile_input.shape) + src_tactile = torch.cat([proprio_input,tactile_input], axis=0) + src_tactile_pos = torch.cat([additional_pos_embed,tactile_pos], axis=0) + + tgt = self.action_embed(noisy_actions).permute( + 1, 0, 2 + ) # TODO Change to noise tgt B T D -> T B D + denoise_embed = get_timestep_embedding(denoise_steps, self.d_model) # B -> B D + denoise_embed = self.time_embed(denoise_embed).unsqueeze(0) # B D -> 1 B D + denoise_step_pos_embed = self.denoise_step_pos_embed.weight.unsqueeze(1).repeat( + 1, bs, 1 + ) # 1 D -> 1 B D + + # encoder vision-state information + memory = self.encoder( + src, src_key_padding_mask=mask, pos=pos_embed + ) # cross attention + memory = torch.cat([memory, denoise_embed], axis=0) + pos_embed = torch.cat([pos_embed, denoise_step_pos_embed], axis=0) + seq_len = tgt.shape[0] + + # tactile-state information + memory_tactile = torch.cat([src_tactile, denoise_embed], axis=0) + tactile_pos_embed = torch.cat([src_tactile_pos, denoise_step_pos_embed], axis=0) + tgt_mask = torch.zeros(seq_len, seq_len).to(tgt.device) + hs = self.decoder( + tgt, + memory, + memory_tactile, + tgt_mask, + memory_key_padding_mask=mask, + pos = pos_embed, + pos_alter = tactile_pos_embed, + query_pos=query_embed, + ) + hs = hs.transpose(1, 2) + return hs + + +class Transformer_diffusion_prediction(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + 
num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + num_queries=100, + share_decoder=True, + patch_size=5, + diffusion_timestep_type="cat", + attention_type="v1", + predict_frame=16, + predict_only_last=False, + token_dim=6, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder( + encoder_layer, num_encoder_layers, encoder_norm + ) + + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, + num_decoder_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + if share_decoder == False: + self.decoder_token = copy.deepcopy(self.decoder) + + self.share_decoder = share_decoder + self.diffusion_timestep_type = diffusion_timestep_type + self.time_embed = nn.Sequential( + nn.Linear(d_model, d_model), + nn.SiLU(), + nn.Linear(d_model, d_model), + ) + self._reset_parameters() + + self.action_embed = nn.Sequential( + nn.Linear(14, d_model), # TODO action dim + nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.token_embed = nn.Sequential( + nn.Linear( + token_dim * patch_size * patch_size, d_model + ), # Hardcode patch size * path size * patch dim + nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.d_model = d_model + self.nhead = nhead + self.chunksize = num_queries + self.attention_type = attention_type + self.predict_frame = predict_frame + self.predict_only_last = predict_only_last + self.token_dim = token_dim + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + # hs_action, hs_token = self.transformer(src, None, self.query_embed.weight, 
self.query_embed_toekn.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight, noisy_actions, noise_tokens, denoise_steps, self.denoise_step_pos_embed.weight) + def forward( + self, + src, + mask, + query_embed, + query_embed_token, + pos_embed, + latent_input=None, + proprio_input=None, + additional_pos_embed=None, + noisy_actions=None, + noisy_tokens=None, + denoise_steps=None, + denoise_step_pos_embed=None, + is_pad=None, + is_pad_token=None, + ): + # src: B D H W*num_view + # mask:None + # query_embed: chunksize D for action query + # query_embed_token: Num_tokens D for token query + # pos_embed: 1 D H W*num_view for current frame token + # latent_input: B D + # proprio_input: B T' D + # additional_pos_embed: B T'+1 D, include proprio and latent + # noisy_actions: B chunksize D + # noisy_tokens: B T' N D H' W', T' = predict_frame / temporal_compression_rate + # denoise_steps: B + # denoise_step_pos_embed:1 D + # is_pad: B chunksize + # is_pad_token: B T' N H' W' + if len(src.shape) == 4: # has H and W + bs, c, h, w = src.shape # B D H W*num_view*T + src = src.flatten(2).permute( + 2, 0, 1 + ) # H*W*num_view*T, bs, c deal with visual features from resnet + pos_embed = ( + pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1) + ) # H*W*num_view*T, bs, c + + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + query_embed_token = query_embed_token.unsqueeze(1).repeat( + 1, bs, 1 + ) # TODO temporal-spatial position embedding FOR predicted frame token + # print('transformer query_embed_token', query_embed_token.shape,'query_embed', query_embed.shape) + query_embed_all = torch.cat( + [query_embed, query_embed_token], axis=0 + ) # chunksize + num_tokens, bs, c + additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat( + 1, bs, 1 + ) # seq, bs, dim + pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) + + latent_input = latent_input.unsqueeze(1) # B 1 D + addition_input = torch.cat([latent_input, proprio_input], 
axis=1).permute( + 1, 0, 2 + ) # B T+1 D -> T+1 B D + src = torch.cat([addition_input, src], axis=0) + + denoise_step_pos_embed = denoise_step_pos_embed.unsqueeze(1).repeat( + 1, bs, 1 + ) # 1 D -> 1 B D + else: + assert len(src.shape) == 3 + + # encoder + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + + # add denoise timestep embedding for decoder denoise step + denoise_embed = self.time_embed( + get_timestep_embedding(denoise_steps, self.d_model) + ).unsqueeze( + 0 + ) # B D -> 1 B D + if ( + self.diffusion_timestep_type == "cat" + ): # TODO add tokenizer visual token & timestep embedding + memory = torch.cat([memory, denoise_embed], axis=0) + pos_embed = torch.cat([pos_embed, denoise_step_pos_embed], axis=0) + # elif self.diffusion_timestep_type == 'vis_cat': + # memory = torch.cat([memory, visual_token, denoise_embed], axis=0) # visual token map + # pos_embed = torch.cat([pos_embed, pos_embed_visual_token, denoise_step_pos_embed], axis=0) # pos_embed_visual_token + else: + memory = memory + denoise_embed + + tgt_action = self.action_embed(noisy_actions).permute(1, 0, 2) # H B D + + # if noisy_tokens is None: + + # tgt_key_padding_mask = torch.zeros_like(tgt_token).sum(-1).bool() # T'*N*H'*W' B + + if self.share_decoder and noisy_tokens is not None: + noisy_tokens = noisy_tokens.permute( + 0, 1, 2, 4, 5, 3 + ) # B T' N D H' W' -> B T' N H' W' D + tgt_token = ( + self.token_embed(noisy_tokens) + .reshape(bs, -1, self.d_model) + .permute(1, 0, 2) + ) # B T' N H' W' D -> T'*N*H'*W' B D + # print('transformer',tgt_action.shape,tgt_token.shape ) + tgt = torch.cat([tgt_action, tgt_token], axis=0) # H1+ H2 B D + seq_len = tgt.shape[0] + bs = tgt.shape[1] + # tgt_key_padding_mask = torch.cat([is_pad, is_pad_token], axis=1) # B N+M is + is_pad_token_zero = torch.zeros_like(is_pad_token).bool() # HARD CODE + if is_pad is None: + is_pad = torch.zeros(bs, self.chunksize).bool().to(src.device) + tgt_key_padding_mask = torch.cat( + [is_pad, 
is_pad_token_zero], axis=1 + ) # HARD CODE avoid nan when all token is pad + + seq_len = tgt.shape[0] + tgt_mask = torch.zeros(seq_len, seq_len).to( + tgt.device + ) # chunksize + num_pred_token_per_frame*predict_frame + # TODO design mask + if self.attention_type == "v1": + tgt_mask[0 : self.chunksize, self.chunksize :] = float( + "-inf" + ) # action prediction cannot attend to token prediction + tgt_mask[self.chunksize :, self.predict_frame : self.chunksize] = float( + "-inf" + ) # token prediction cannot attend to action token after token prediction + elif self.attention_type == "v2": + tgt_mask[0 : self.chunksize, self.chunksize :] = float( + "-inf" + ) # action prediction cannot attend to token prediction + tgt_mask[self.chunksize :, 0 : self.chunksize] = float( + "-inf" + ) # token prediction cannot attend to action prediction + elif self.attention_type == "v3": # v1= v3 sad + tgt_mask[0 : self.chunksize, self.chunksize :] = float( + "-inf" + ) # action prediction cannot attend to token prediction + if ( + self.predict_frame < self.chunksize + ): # don't need attent action after target frame + tgt_mask[self.chunksize :, self.predict_frame : self.chunksize] = ( + float("-inf") + ) + elif self.attention_type == "causal": + predict_frame = noisy_tokens.shape[1] + num_pred_token_per_frame = tgt_token.shape[0] // predict_frame + chunk_size = query_embed.shape[0] + temporal_compression_rate = chunk_size // predict_frame + tgt_mask = torch.full((seq_len, seq_len), float("-inf")).to( + tgt.device + ) # seq_len = chunksize + num_pred_token_per_frame*predict_frame + # tgt_mask[:chunk_size, :chunk_size] = torch.triu(torch.full((chunk_size, seq_len), float('-inf')), diagonal=1).to(tgt.device) + tgt_mask[:chunk_size, :chunk_size] = torch.zeros( + (chunk_size, chunk_size) + ).to(tgt.device) + for t in range(predict_frame): + tgt_mask[ + chunk_size + + t * num_pred_token_per_frame : chunk_size + + (t + 1) * num_pred_token_per_frame, + 0 : t * temporal_compression_rate, + ] 
= 0 + tgt_mask[ + chunk_size + + t * num_pred_token_per_frame : chunk_size + + (t + 1) * num_pred_token_per_frame, + chunk_size : chunk_size + (t + 1) * num_pred_token_per_frame, + ] = 0 + # print(tgt.shape, memory.shape, tgt_key_padding_mask.shape, pos_embed.shape, query_embed_all.shape) + hs = self.decoder( + tgt, + memory, + tgt_mask, + memory_key_padding_mask=mask, + tgt_key_padding_mask=tgt_key_padding_mask, + pos=pos_embed, + query_pos=query_embed_all, + ).transpose(1, 2)[ + 0 + ] # TODO 1 H B D + hs_action = hs[:, : self.chunksize] # B H D + hs_token = hs[:, self.chunksize :] # B H D -> None + # simplify version + # tgt_mask = torch.full((seq_len, seq_len), float('-inf')).to(tgt.device) + # # Action-to-Action mask + # tgt_mask[:chunk_size, :chunk_size] = torch.triu(torch.full((chunk_size, chunk_size), float('-inf')), diagonal=1) + + # # Frame-to-All mask + # frame_indices = torch.arange(chunk_size, seq_len).view(predict_frame, num_pred_token_per_frame).to(tgt.device) + # action_indices = torch.arange(chunk_size).to(tgt.device) + + # for t in range(predict_frame): + # tgt_mask[frame_indices[t], action_indices[:t * temporal_compression_rate]] = 0 + # tgt_mask[frame_indices[t], frame_indices[:t + 1].flatten()] = 0 + + elif self.share_decoder and noisy_tokens is None: + hs = self.decoder( + tgt_action, + memory, + memory_key_padding_mask=mask, + tgt_key_padding_mask=is_pad, + pos=pos_embed, + query_pos=query_embed, + ).transpose(1, 2)[0] + hs_action = hs + hs_token = None + + else: + hs_action = self.decoder( + tgt_action, + memory, + memory_key_padding_mask=mask, + tgt_key_padding_mask=is_pad, + pos=pos_embed, + query_pos=query_embed, + ).transpose(1, 2)[0] + hs_token = self.decoder_token( + tgt_token, + memory, + memory_key_padding_mask=mask, # tgt_key_padding_mask = is_pad_token, # is_pad_token is nonetype ! 
fix + pos=pos_embed, + query_pos=query_embed_token, + ).transpose(1, 2)[0] + # hs_action B chunksize D + # hs_token B T'*N*H'*W' D + # print("tgt_token has NaN:", torch.isnan(tgt_token).any()) + # print("memory has NaN:", torch.isnan(memory).any()) + # print('tgt_key_padding_mask has NaN:', torch.isnan(is_pad_token).any()) + # print("pos_embed has NaN:", torch.isnan(pos_embed).any()) + # print("query_embed_token has NaN:", torch.isnan(query_embed_token).any()) + # print("hs_token has NaN:", torch.isnan(hs_token).any()) # has nan + return hs_action, hs_token + + +class Transformer_diffusion_prediction_with_dual_visual_token(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + num_queries=100, + share_decoder=True, + patch_size=5, + diffusion_timestep_type="cat", + attention_type="v1", + predict_frame=16, + predict_only_last=False, + token_dim=6, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder( + encoder_layer, num_encoder_layers, encoder_norm + ) + + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, + num_decoder_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + if share_decoder == False: + self.decoder_token = copy.deepcopy(self.decoder) + + self.share_decoder = share_decoder + self.diffusion_timestep_type = diffusion_timestep_type + self.time_embed = nn.Sequential( + nn.Linear(d_model, d_model), + nn.SiLU(), + nn.Linear(d_model, d_model), + ) + self._reset_parameters() + # TODO change action dim + 
token_dim = token_dim + self.action_embed = nn.Sequential( + nn.Linear(14, d_model), + nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.token_embed = nn.Sequential( + nn.Linear( + token_dim * patch_size * patch_size, d_model + ), # Hardcode patch size * path size * patch dim + nn.SiLU(), + nn.Linear(d_model, d_model), + ) + + self.d_model = d_model + self.nhead = nhead + self.chunksize = num_queries + self.attention_type = attention_type + self.predict_frame = predict_frame + self.predict_only_last = predict_only_last + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + # hs_action, hs_token = self.transformer(src, None, self.query_embed.weight, self.query_embed_toekn.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight, noisy_actions, noise_tokens, denoise_steps, self.denoise_step_pos_embed.weight) + def forward( + self, + src, + mask, + query_embed, + query_embed_token, + pos_embed, + latent_input=None, + proprio_input=None, + additional_pos_embed=None, + noisy_actions=None, + noisy_tokens=None, + denoise_steps=None, + denoise_step_pos_embed=None, + is_pad=None, + is_pad_token=None, + addition_visual_token=None, + addition_visual_token_pos=None, + ): + # src: B D H W*num_view*T + # mask:None + # query_embed: chunksize D for action query + # query_embed_token: Num_tokens D for token query + # pos_embed: 1 D H W*num_view*T for current frame token + # latent_input: B D + # proprio_input: B T' D + # additional_pos_embed: T'+1 D, include proprio and latent + # noisy_actions: B chunksize D + # noisy_tokens: B T' N D H' W', T' = predict_frame / temporal_compression_rate + # denoise_steps: B + # denoise_step_pos_embed:1 D + # is_pad: B chunksize + # is_pad_token: B T' N H' W' + # addition_visual_token: B D H' W'*num_view + # addition_visual_token_pos: H'*W'*num_view D for visual token position embedding + + if len(src.shape) == 4: # has H and W + bs, c, h, w = src.shape # B D H 
W*num_view*T + # For encoder + src = src.flatten(2).permute( + 2, 0, 1 + ) # H*W*num_view*T, bs, c deal with visual features from resnet + addition_visual_token = addition_visual_token.flatten(2).permute( + 2, 0, 1 + ) # H*W*num_view, bs, c visual token from tokenzier + latent_input = latent_input.unsqueeze(1) # B 1 D + addition_input = torch.cat([latent_input, proprio_input], axis=1).permute( + 1, 0, 2 + ) # B T+1 D -> T+1 B D + + pos_embed = ( + pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1) + ) # H*W*num_view*T, bs, c + addition_visual_token_pos = addition_visual_token_pos.unsqueeze(1).repeat( + 1, bs, 1 + ) # H*W*num_view, bs, c + additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat( + 1, bs, 1 + ) # seq, bs, dim + + # only consider the visual token from resnet for encoder, robot state + visual token + pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0) + src = torch.cat([addition_input, src], axis=0) + + # For decoder + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + query_embed_token = query_embed_token.unsqueeze(1).repeat( + 1, bs, 1 + ) # TODO temporal-spatial position embedding FOR predicted frame token + query_embed_all = torch.cat( + [query_embed, query_embed_token], axis=0 + ) # chunksize + num_tokens, bs, c + + denoise_step_pos_embed = denoise_step_pos_embed.unsqueeze(1).repeat( + 1, bs, 1 + ) # 1 D -> 1 B D + else: + assert len(src.shape) == 3 + + # encoder TODO add visual token from tokenizer here? 
+ memory = self.encoder( + src, src_key_padding_mask=mask, pos=pos_embed + ) # resnet visual feature + proprio + latent + + # add denoise timestep embedding for decoder denoise step or current visual token + denoise_embed = self.time_embed( + get_timestep_embedding(denoise_steps, self.d_model) + ).unsqueeze( + 0 + ) # B D -> 1 B D + if ( + self.diffusion_timestep_type == "cat" + ): # only resnet visual token + proprio + latent + timestep embedding + memory = torch.cat([memory, denoise_embed], axis=0) + pos_embed = torch.cat([pos_embed, denoise_step_pos_embed], axis=0) + elif ( + self.diffusion_timestep_type == "vis_cat" + ): # add visual token from tokenizer + memory = torch.cat( + [memory, addition_visual_token, denoise_embed], axis=0 + ) # visual token map + pos_embed = torch.cat( + [pos_embed, addition_visual_token_pos, denoise_step_pos_embed], axis=0 + ) # pos_embed_visual_token + elif self.diffusion_timestep_type == "add": + memory = memory + denoise_embed + + # Noisy action and token + tgt_action = self.action_embed(noisy_actions).permute(1, 0, 2) # H B D + noisy_tokens = noisy_tokens.permute( + 0, 1, 2, 4, 5, 3 + ) # B T' N D H' W' -> B T' N H' W' D + tgt_token = ( + self.token_embed(noisy_tokens) + .reshape(bs, -1, self.d_model) + .permute(1, 0, 2) + ) # B T' N H' W' D -> T'*N*H'*W' B D + + if self.share_decoder: + tgt = torch.cat([tgt_action, tgt_token], axis=0) # H1+ H2 B D + seq_len = tgt.shape[0] + bs = tgt.shape[1] + is_pad_token_zero = torch.zeros_like(is_pad_token).bool() # HARD CODE + if is_pad is None: # if action is pad + is_pad = torch.zeros(bs, self.chunksize).bool().to(src.device) + tgt_key_padding_mask = torch.cat( + [is_pad, is_pad_token_zero], axis=1 + ) # HARD CODE avoid nan when all token is pad add causal mechanism + seq_len = tgt.shape[0] + tgt_mask = torch.zeros(seq_len, seq_len).to(tgt.device) + # TODO design mask + if self.attention_type == "v1": + tgt_mask[0 : self.chunksize, self.chunksize :] = float( + "-inf" + ) # action 
prediction cannot attend to token prediction + tgt_mask[self.chunksize :, self.predict_frame : self.chunksize] = float( + "-inf" + ) # token prediction cannot attend to action token after token prediction + elif self.attention_type == "v2": + tgt_mask[0 : self.chunksize, self.chunksize :] = float( + "-inf" + ) # action prediction cannot attend to token prediction + tgt_mask[self.chunksize :, 0 : self.chunksize] = float( + "-inf" + ) # token prediction cannot attend to action prediction + elif self.attention_type == "v3": + tgt_mask[0 : self.chunksize, self.chunksize :] = float( + "-inf" + ) # action prediction cannot attend to token prediction + + hs = self.decoder( + tgt, + memory, + tgt_mask, + memory_key_padding_mask=mask, + tgt_key_padding_mask=tgt_key_padding_mask, + pos=pos_embed, + query_pos=query_embed_all, + ).transpose(1, 2)[ + 0 + ] # TODO 1 H B D + hs_action = hs[:, : self.chunksize] # B H D + hs_token = hs[:, self.chunksize :] # B H D + else: + hs_action = self.decoder( + tgt_action, + memory, + memory_key_padding_mask=mask, + tgt_key_padding_mask=is_pad, + pos=pos_embed, + query_pos=query_embed, + ).transpose(1, 2)[0] + hs_token = self.decoder_token( + tgt_token, + memory, + memory_key_padding_mask=mask, # tgt_key_padding_mask = is_pad_token, # is_pad_token is nonetype ! 
class Transformer_diffusion_prediction_pixel(nn.Module):
    """Diffusion transformer denoising an action chunk plus raw-pixel patches.

    Identical in structure to the token-prediction variant, but the predicted
    targets are 3-channel pixel patches (token dim fixed at 3) instead of
    tokenizer codes.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        num_queries=100,
        share_decoder=True,
        patch_size=5,
        diffusion_timestep_type="cat",
        attention_type="v1",
        predict_frame=16,
        predict_only_last=False,
    ):
        super().__init__()

        enc_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        self.encoder = TransformerEncoder(
            enc_layer,
            num_encoder_layers,
            nn.LayerNorm(d_model) if normalize_before else None,
        )

        dec_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        self.decoder = TransformerDecoder(
            dec_layer,
            num_decoder_layers,
            nn.LayerNorm(d_model),
            return_intermediate=return_intermediate_dec,
        )
        if not share_decoder:
            # Separate decoder for the pixel-token stream.
            self.decoder_token = copy.deepcopy(self.decoder)

        self.share_decoder = share_decoder
        self.diffusion_timestep_type = diffusion_timestep_type
        self.time_embed = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        # NOTE(review): _reset_parameters() is invoked BEFORE action_embed /
        # token_embed exist, so those two heads keep PyTorch's default init.
        # The sibling classes (e.g. Transformer_diffusion) reset after building
        # all embeds — confirm whether this ordering is intentional.
        self._reset_parameters()

        self.action_embed = nn.Sequential(
            nn.Linear(14, d_model),  # hard-coded 14-dim action vector
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        # One flattened RGB patch (3 * patch_size^2) -> model width.
        self.token_embed = nn.Sequential(
            nn.Linear(3 * patch_size * patch_size, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        self.d_model = d_model
        self.nhead = nhead
        self.chunksize = num_queries
        self.attention_type = attention_type
        self.predict_frame = predict_frame
        self.predict_only_last = predict_only_last

    def _reset_parameters(self):
        """Xavier-initialise every parameter tensor with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src,
        mask,
        query_embed,
        query_embed_token,
        pos_embed,
        latent_input=None,
        proprio_input=None,
        additional_pos_embed=None,
        noisy_actions=None,
        noisy_tokens=None,
        denoise_steps=None,
        denoise_step_pos_embed=None,
        is_pad=None,
        is_pad_token=None,
    ):
        """Denoise actions and pixel patches; see the token variant for shapes.

        Returns (hs_action, hs_token): B chunksize D and B T'*N*H'*W' D.
        """
        if len(src.shape) == 4:  # spatial map: flatten to sequence-first layout
            bs, c, h, w = src.shape  # B D H W*num_view*T
            src = src.flatten(2).permute(2, 0, 1)  # -> H*W*num_view*T, B, D
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)

            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
            query_embed_token = query_embed_token.unsqueeze(1).repeat(
                1, bs, 1
            )  # TODO temporal-spatial position embedding for predicted frames
            query_embed_all = torch.cat(
                [query_embed, query_embed_token], axis=0
            )  # chunksize+num_tokens, B, D
            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1)
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

            latent_input = latent_input.unsqueeze(1)  # B 1 D
            addition_input = torch.cat([latent_input, proprio_input], axis=1).permute(
                1, 0, 2
            )  # -> (T'+1), B, D
            src = torch.cat([addition_input, src], axis=0)

            denoise_step_pos_embed = denoise_step_pos_embed.unsqueeze(1).repeat(
                1, bs, 1
            )
        else:
            # NOTE(review): `bs` is unbound on this branch but used below.
            assert len(src.shape) == 3

        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)

        # Diffusion timestep conditioning.
        denoise_embed = self.time_embed(
            get_timestep_embedding(denoise_steps, self.d_model)
        ).unsqueeze(0)  # B D -> 1 B D
        if self.diffusion_timestep_type == "cat":
            memory = torch.cat([memory, denoise_embed], axis=0)
            pos_embed = torch.cat([pos_embed, denoise_step_pos_embed], axis=0)
        else:
            # Any non-"cat" type falls through to additive conditioning here.
            memory = memory + denoise_embed

        tgt_action = self.action_embed(noisy_actions).permute(1, 0, 2)
        noisy_tokens = noisy_tokens.permute(
            0, 1, 2, 4, 5, 3
        )  # B T' N D H' W' -> B T' N H' W' D
        tgt_token = (
            self.token_embed(noisy_tokens)
            .reshape(bs, -1, self.d_model)
            .permute(1, 0, 2)
        )  # -> T'*N*H'*W', B, D

        if self.share_decoder:
            tgt = torch.cat([tgt_action, tgt_token], axis=0)
            seq_len = tgt.shape[0]
            bs = tgt.shape[1]
            # All-padded key rows would produce NaNs in softmax, so the token
            # half of the padding mask is forced to all-False (HARD CODE).
            is_pad_token_zero = torch.zeros_like(is_pad_token).bool()
            if is_pad is None:
                is_pad = torch.zeros(bs, self.chunksize).bool().to(src.device)
            tgt_key_padding_mask = torch.cat([is_pad, is_pad_token_zero], axis=1)

            seq_len = tgt.shape[0]
            tgt_mask = torch.zeros(seq_len, seq_len).to(tgt.device)
            if self.attention_type == "v1":
                tgt_mask[0 : self.chunksize, self.chunksize :] = float("-inf")
                tgt_mask[self.chunksize :, self.predict_frame : self.chunksize] = float(
                    "-inf"
                )
            elif self.attention_type == "v2":
                tgt_mask[0 : self.chunksize, self.chunksize :] = float("-inf")
                tgt_mask[self.chunksize :, 0 : self.chunksize] = float("-inf")
            elif self.attention_type == "v3":
                tgt_mask[0 : self.chunksize, self.chunksize :] = float("-inf")

            hs = self.decoder(
                tgt,
                memory,
                tgt_mask,
                memory_key_padding_mask=mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                pos=pos_embed,
                query_pos=query_embed_all,
            ).transpose(1, 2)[0]
            hs_action = hs[:, : self.chunksize]
            hs_token = hs[:, self.chunksize :]
        else:
            hs_action = self.decoder(
                tgt_action,
                memory,
                memory_key_padding_mask=mask,
                tgt_key_padding_mask=is_pad,
                pos=pos_embed,
                query_pos=query_embed,
            ).transpose(1, 2)[0]
            # No tgt_key_padding_mask: is_pad_token may be None on this path.
            hs_token = self.decoder_token(
                tgt_token,
                memory,
                memory_key_padding_mask=mask,
                pos=pos_embed,
                query_pos=query_embed_token,
            ).transpose(1, 2)[0]
        return hs_action, hs_token
class Transformer_diffusion(nn.Module):
    """Single-decoder diffusion transformer over actions plus JPEG latents.

    The decoder target is the concatenation [noisy actions ; noisy JPEG
    features]; a block-diagonal self-attention mask keeps the two streams
    from attending to each other.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        jpeg_dim=400,
        num_jpeg=80,
    ):
        super().__init__()

        enc_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        self.encoder = TransformerEncoder(
            enc_layer,
            num_encoder_layers,
            nn.LayerNorm(d_model) if normalize_before else None,
        )

        dec_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        self.decoder = TransformerDecoder(
            dec_layer,
            num_decoder_layers,
            nn.LayerNorm(d_model),
            return_intermediate=return_intermediate_dec,
        )
        # Sinusoidal timestep -> model width.
        self.time_embed = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        self.action_embed = nn.Sequential(
            nn.Linear(14, d_model),  # hard-coded 14-dim action vector
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        self.jpeg_embed = nn.Sequential(
            nn.Linear(jpeg_dim, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead
        self.num_jpeg = num_jpeg  # number of trailing JPEG slots in the target

    def _reset_parameters(self):
        """Xavier-initialise every parameter tensor with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src,
        mask,
        query_embed,
        pos_embed,
        latent_input=None,
        proprio_input=None,
        additional_pos_embed=None,
        noisy_actions=None,
        noisy_jpegs=None,
        denoise_steps=None,
    ):
        """Run encoder + masked decoder; returns hs of shape 1/stack, B, H, D."""
        if len(src.shape) == 4:  # has H and W
            bs, c, h, w = src.shape
            # NxCxHxW -> HWxNxC (sequence-first).
            src = src.flatten(2).permute(2, 0, 1)
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)

            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1)
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

            # Prefix the feature sequence with [latent, proprio].
            addition_input = torch.stack([latent_input, proprio_input], axis=0)
            src = torch.cat([addition_input, src], axis=0)
        else:
            assert len(src.shape) == 3

        tgt_action = self.action_embed(noisy_actions).permute(1, 0, 2)  # seq-first
        tgt_jpeg = self.jpeg_embed(noisy_jpegs).permute(1, 0, 2)
        tgt = torch.cat([tgt_action, tgt_jpeg], axis=0)

        denoise_embed = self.time_embed(
            get_timestep_embedding(denoise_steps, self.d_model)
        )  # B -> B D
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        # Additive timestep conditioning broadcast over the sequence.
        memory = memory + denoise_embed.unsqueeze(0)

        # Block-diagonal mask: actions and JPEG slots cannot see each other.
        seq_len = tgt.shape[0]
        tgt_mask = torch.zeros(seq_len, seq_len).to(tgt.device)
        tgt_mask[: -self.num_jpeg, -self.num_jpeg :] = float("-inf")
        tgt_mask[-self.num_jpeg :, : -self.num_jpeg] = float("-inf")

        hs = self.decoder(
            tgt,
            memory,
            tgt_mask,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            query_pos=query_embed,
        )
        hs = hs.transpose(1, 2)  # -> layers, B, H, D
        return hs
class Transformer_diffusion_seperate(nn.Module):
    """Diffusion transformer with separate decoders for actions and JPEG latents.

    The action decoder denoises `noisy_actions` conditioned on the timestep;
    the JPEG decoder regresses from zero-initialised queries against the raw
    (un-conditioned) encoder memory.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        share_decoder=False,
    ):
        super().__init__()

        enc_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        self.encoder = TransformerEncoder(
            enc_layer,
            num_encoder_layers,
            nn.LayerNorm(d_model) if normalize_before else None,
        )

        dec_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        self.decoder_action = TransformerDecoder(
            dec_layer,
            num_decoder_layers,
            nn.LayerNorm(d_model),
            return_intermediate=return_intermediate_dec,
        )
        if share_decoder:
            # Same module object: both streams share weights.
            self.decoder_jpeg = self.decoder_action
        else:
            self.decoder_jpeg = copy.deepcopy(self.decoder_action)

        self.time_embed = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        self.action_embed = nn.Sequential(
            nn.Linear(14, d_model),  # hard-coded 14-dim action vector
            nn.SiLU(),
            nn.Linear(d_model, d_model),
        )

        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-initialise every parameter tensor with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src,
        mask,
        query_embed_action=None,
        query_embed_jpeg=None,
        pos_embed=None,
        latent_input=None,
        proprio_input=None,
        additional_pos_embed=None,
        noisy_actions=None,
        denoise_steps=None,
    ):
        """Returns (hs_action, hs_jpeg), each B, num_queries, D."""
        if len(src.shape) == 4:  # has H and W
            bs, c, h, w = src.shape
            # NxCxHxW -> HWxNxC (sequence-first).
            src = src.flatten(2).permute(2, 0, 1)
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
            query_embed_action = query_embed_action.unsqueeze(1).repeat(1, bs, 1)
            query_embed_jpeg = query_embed_jpeg.unsqueeze(1).repeat(1, bs, 1)

            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1)
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

            addition_input = torch.stack([latent_input, proprio_input], axis=0)
            src = torch.cat([addition_input, src], axis=0)
        else:
            assert len(src.shape) == 3

        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        denoise_embed = self.time_embed(
            get_timestep_embedding(denoise_steps, self.d_model)
        )  # B -> B D
        # Timestep conditioning is applied only to the action stream's memory.
        memory_action = memory + denoise_embed.unsqueeze(0)

        tgt_action = self.action_embed(noisy_actions).permute(1, 0, 2)  # seq-first
        hs_action = self.decoder_action(
            tgt_action,
            memory_action,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            query_pos=query_embed_action,
        )
        hs_action = hs_action.transpose(1, 2)  # layers, B, H, D

        # FIX: the original called .cuda() here, which crashed on CPU and was a
        # no-op on GPU — zeros_like already allocates on query_embed_jpeg's device.
        tgt_jpeg = torch.zeros_like(query_embed_jpeg)
        hs_jpeg = self.decoder_jpeg(
            tgt_jpeg,
            memory,  # NOTE: deliberately the un-conditioned memory
            memory_key_padding_mask=mask,
            pos=pos_embed,
            query_pos=query_embed_jpeg,
        )
        hs_jpeg = hs_jpeg.transpose(1, 2)

        return hs_action[0], hs_jpeg[0]

    def inference_only_actin(
        self,
        src,
        mask,
        query_embed_action=None,
        pos_embed=None,
        latent_input=None,
        proprio_input=None,
        additional_pos_embed=None,
        noisy_actions=None,
        denoise_steps=None,
    ):
        """Action-only inference path; returns (hs_action, None).

        NOTE: unlike forward(), the result is NOT transposed or indexed with [0]
        — callers must handle the raw decoder output shape.
        """
        if len(src.shape) == 4:  # has H and W
            bs, c, h, w = src.shape
            src = src.flatten(2).permute(2, 0, 1)
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
            query_embed_action = query_embed_action.unsqueeze(1).repeat(1, bs, 1)
            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1)
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

            addition_input = torch.stack([latent_input, proprio_input], axis=0)
            src = torch.cat([addition_input, src], axis=0)
        else:
            assert len(src.shape) == 3

        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        denoise_embed = self.time_embed(
            get_timestep_embedding(denoise_steps, self.d_model)
        )
        memory_action = memory + denoise_embed.unsqueeze(0)

        tgt_action = self.action_embed(noisy_actions).permute(1, 0, 2)

        hs_action = self.decoder_action(
            tgt_action,
            memory_action,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            query_pos=query_embed_action,
        )
        return hs_action, None
class TransformerEncoder(nn.Module):
    """Stack of N identical encoder layers with an optional final norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        out = src
        for layer in self.layers:
            out = layer(
                out,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )
        if self.norm is not None:
            out = self.norm(out)
        return out


class TransformerEncoder_AdLN(nn.Module):
    """Encoder stack whose layers take an extra conditioning vector (adaLN).

    NOTE: unlike TransformerEncoder, the output gains a leading singleton dim
    (1, T, B, D) to mimic the decoder's layer-stacked shape.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        condition,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        out = src
        for layer in self.layers:
            out = layer(
                out,
                condition,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )
        if self.norm is not None:
            out = self.norm(out)
        return out.unsqueeze(0)  # 1 T B D


class TransformerDecoder(nn.Module):
    """Stack of N decoder layers; optionally returns all intermediate outputs."""

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        out = tgt
        intermediate = []

        for layer in self.layers:
            out = layer(
                out,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(out))

        if self.norm is not None:
            out = self.norm(out)
            if self.return_intermediate:
                # Replace the last entry with the final-normed output.
                intermediate.pop()
                intermediate.append(out)

        if self.return_intermediate:
            return torch.stack(intermediate)
        return out.unsqueeze(0)
class TransformerDecoder_alter(nn.Module):
    """Decoder stack that alternates cross-attention between two memories.

    Even-indexed layers attend to (memory, pos); odd-indexed layers attend to
    (memory_alter, pos_alter).
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        memory_alter,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        pos_alter: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt
        memory_list = [memory, memory_alter]
        pos_list = [pos, pos_alter]
        intermediate = []
        # FIX: the original initialised num_layer = 0 but never incremented it,
        # so every layer used memory_list[0] and memory_alter was dead code.
        # enumerate() restores the intended even/odd alternation.
        for num_layer, layer in enumerate(self.layers):
            index = num_layer % 2
            output = layer(
                output,
                memory_list[index],
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos_list[index],
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):
    """Standard DETR encoder layer: self-attention + FFN, pre- or post-norm."""

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feedforward sublayer.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add positional embedding when provided; identity otherwise."""
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        # Post-norm: residual first, then LayerNorm.
        q = k = self.with_pos_embed(src, pos)
        attn_out = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = self.norm1(src + self.dropout1(attn_out))
        ffn_out = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src + self.dropout2(ffn_out))
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        # Pre-norm: LayerNorm before each sublayer.
        normed = self.norm1(src)
        q = k = self.with_pos_embed(normed, pos)
        attn_out = self.self_attn(
            q, k, value=normed, attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]
        src = src + self.dropout1(attn_out)
        normed = self.norm2(src)
        ffn_out = self.linear2(self.dropout(self.activation(self.linear1(normed))))
        src = src + self.dropout2(ffn_out)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


def modulate(x, shift, scale):
    """adaLN modulation: scale/shift (B, D) broadcast over a (T, B, D) input."""
    return x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)
class TransformerEncoderLayer_AdLN(nn.Module):
    """Pre-norm encoder layer conditioned via adaptive LayerNorm (DiT-style).

    A SiLU+Linear head maps the condition vector to six (B, D) chunks:
    shift/scale/gate for the attention sublayer and for the MLP sublayer.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=True,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feedforward sublayer.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        # Condition -> 6 modulation vectors (necessary for adaLN).
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, 6 * d_model, bias=True),
        )

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add positional embedding when provided; identity otherwise."""
        return tensor if pos is None else tensor + pos

    def forward_pre(
        self,
        src,        # T B D
        condition,  # B D
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(condition).chunk(6, dim=1)  # each B D

        # Attention sublayer with modulated pre-norm and gated residual.
        h = modulate(self.norm1(src), shift_msa, scale_msa)
        q = k = self.with_pos_embed(h, pos)
        h = self.self_attn(
            q, k, value=h, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(gate_msa.unsqueeze(0) * h)

        # MLP sublayer, same modulation scheme.
        h = modulate(self.norm2(src), shift_mlp, scale_mlp)
        h = self.linear2(self.dropout(self.activation(self.linear1(h))))
        src = src + self.dropout2(gate_mlp.unsqueeze(0) * h)
        return src

    def forward(
        self,
        src,
        condition,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        # Only the pre-norm path is implemented for adaLN.
        return self.forward_pre(src, condition, src_mask, src_key_padding_mask, pos)
+ + + +class TransformerDecoderLayer(nn.Module): + + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: 
Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn( + q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + return self.forward_post( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + + +def build_transformer_denoise(args): + print(f"Using {args.condition_type} for condition") + print(f'enc_layers: {args.enc_layers}, 
dec_layers: {args.dec_layers}') + if args.condition_type == "adaLN": + return Transformer_Denoise_AdLN( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + return Transformer_Denoise( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + causal_mask=args.causal_mask, + ) + +def build_transformer_denoise_tactile(args): + return Transformer_Denoise_Tactile( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + causal_mask=args.causal_mask, + ) + + +def build_transformer_diffusion_prediction(args): + return Transformer_diffusion_prediction( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + num_queries=args.num_queries, + share_decoder=args.share_decoder, + patch_size=args.patch_size, + diffusion_timestep_type=args.diffusion_timestep_type, + attention_type=args.attention_type, + predict_frame=args.predict_frame, + predict_only_last=args.predict_only_last, + token_dim=args.token_dim, + ) + + +def build_transformer_diffusion_prediction_with_dual_visual_token(args): + return Transformer_diffusion_prediction_with_dual_visual_token( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + 
num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + num_queries=args.num_queries, + share_decoder=args.share_decoder, + patch_size=args.patch_size, + diffusion_timestep_type=args.diffusion_timestep_type, + attention_type=args.attention_type, + predict_frame=args.predict_frame, + predict_only_last=args.predict_only_last, + ) + + +def build_transformer_diffusion(args): + return Transformer_diffusion( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + jpeg_dim=args.jpeg_dim, # Hard codes + num_jpeg=args.jpeg_token_num * args.predict_frame, # Hard codes + return_intermediate_dec=True, + ) + + +def build_transformer_diffusion_pixel_prediction(args): + return Transformer_diffusion_prediction_pixel( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + num_queries=args.num_queries, + share_decoder=args.share_decoder, + patch_size=args.patch_size, + diffusion_timestep_type=args.diffusion_timestep_type, + attention_type=args.attention_type, + predict_frame=args.predict_frame, + predict_only_last=args.predict_only_last, + ) + + +def build_transformer_diffusion_seperate(args): + return Transformer_diffusion_seperate( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + share_decoder=args.share_decoder, + return_intermediate_dec=True, + ) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == 
"relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos=False, + downscale_freq_shift=1, + scale=1, + max_period=10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + + return emb + +class Tactile_ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super(Tactile_ConvBlock, self).__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels//2, kernel_size=3, stride=3) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(out_channels//2, out_channels, kernel_size=5, stride=5) + self.pool2 = nn.AdaptiveAvgPool2d((4, 4)) + self.activation = nn.SiLU() + self.bn1 = nn.BatchNorm2d(out_channels//2) # LN for conv1 output + 
self.bn2 = nn.BatchNorm2d(out_channels) # LN for conv2 output + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.activation(x) + x = self.pool(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.pool2(x) + + return x + +class Tactile_Encoder(nn.Module): + def __init__(self,tactile_dim,dropout): + super().__init__() + self.conv_ll = Tactile_ConvBlock(3, tactile_dim) + self.conv_lr = Tactile_ConvBlock(3, tactile_dim) + self.conv_rl = Tactile_ConvBlock(3, tactile_dim) + self.conv_rr = Tactile_ConvBlock(3, tactile_dim) + self.feature_extractor = nn.ModuleList([ + self.conv_ll, + self.conv_lr, + self.conv_rl, + self.conv_rr + ]) + self.input_proj_tacile = nn.Conv2d( + tactile_dim, + tactile_dim, + kernel_size=1, + ) + + # self.query_tokens = nn.Parameter(torch.randn(16, tactile_dim)) + # self.attn = nn.MultiheadAttention(embed_dim=tactile_dim, num_heads=8, dropout=dropout,batch_first=True) + self.tactile_dim = tactile_dim + + def forward(self,tactile_data): + #B 4 C H W + tactile_data = tactile_data[:,0] + B = tactile_data.shape[0] + tactile_feature_list = [] + for i in range(tactile_data.shape[1]): + tactile_feature = self.feature_extractor[i](tactile_data[:,i]) + tactile_feature = self.input_proj_tacile(tactile_feature) + tactile_feature_list.append(tactile_feature) + tactile_features_raw = torch.stack(tactile_feature_list, dim=2) # B C 4 H W + # print('before attention tactile feature raw shape', tactile_features_raw.shape) + # tactile_features = tactile_features_raw.view(B, tactile_features_raw.size(1), -1) # B C 4*H*W + # print('before attention tactile feature shape', tactile_features.shape) + # tactile_features = tactile_features.permute(0, 2, 1) # B H*W*4 C + return tactile_features_raw # b d 4 h w + # print('before attention tactile feature shape', tactile_features.shape) + # query = self.query_tokens.unsqueeze(0).repeat(B, 1, 1) + # # print('query shape', query.shape) + # attn_output, _ = self.attn(query, tactile_features, 
# tactile_features)  # continuation of the disabled attention path above
# return attn_output


if __name__ == "__main__":
    # Smoke test / parameter count for the tactile encoder (requires CUDA).
    tactile_data = torch.randn(2, 1, 4, 3, 960, 960).cuda()
    tactile_encoder = Tactile_Encoder(512, 0.1).cuda()
    tactile_feature = tactile_encoder(tactile_data)
    print(tactile_feature.shape)  # (B, D, 4, H', W')
    total_params = sum(p.numel() for p in tactile_encoder.parameters() if p.requires_grad)
    total_params_in_million = total_params / 1e6
    print(f"Total trainable parameters: {total_params_in_million:.2f} M")

# ---------------------------------------------------------------------------
# Original patch metadata (file boundary):
# diff --git a/ACT_DP_multitask/detr/models/vision_transformer.py
#          b/ACT_DP_multitask/detr/models/vision_transformer.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..a1b718ea1bb1e616f29437db33d9140273c5f806
# --- /dev/null
# +++ b/ACT_DP_multitask/detr/models/vision_transformer.py
# ---------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mostly copy-paste from timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import math
from functools import partial

import torch
import torch.nn as nn

import numpy as np


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # Frequencies decay geometrically from 1 to 1/10000 across half the dims.
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Split the embedding dimension evenly between the two grid axes."""
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    # Square-grid convenience wrapper: previously duplicated the rectangular
    # implementation verbatim; delegate instead to keep one code path.
    return get_2d_sincos_pos_embed_v2(embed_dim, (grid_size, grid_size), cls_token)


def get_2d_sincos_pos_embed_v2(embed_dim, grid_size, cls_token=False):
    """
    grid_size: (height, width) of the grid
    return:
    pos_embed: [H*W, embed_dim] or [1+H*W, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size[0], dtype=np.float32)
    grid_w = np.arange(grid_size[1], dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        # All-zero row prepended for the [CLS] token.
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed

# from
# utils import trunc_normal_  (continuation of the comment above)

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        # BUG FIX: `warnings` was never imported in this module, so reaching
        # this branch raised NameError instead of warning; import locally.
        import warnings
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    """In-place truncated-normal init of ``tensor`` within ``[a, b]``."""
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero whole samples with prob ``drop_prob`` (stochastic depth)."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    # Rescale kept samples so the expectation is unchanged.
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super().__init__()
        # Robustness fix: a default-constructed DropPath() previously stored
        # None and raised TypeError in training mode; treat None as "no drop".
        self.drop_prob = 0. if drop_prob is None else drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    """Two-layer MLP with activation and dropout, as used inside ViT blocks."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    """Multi-head self-attention returning both the output and the attention map."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # Single projection produces q, k, v: (3, B, heads, N, head_dim).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Block(nn.Module):
    """Standard pre-norm ViT block: attention + MLP, each with a residual."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            # Used by get_last_selfattention(); skips the residual update.
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        # Non-overlapping conv == linear projection of each patch.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer """
    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        """Bicubically resample the learned pos-embed for off-size inputs."""
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x):
        """Return ([CLS] feature, patch features) after the final norm."""
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0], x[:, 1:]

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output


def vit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_small(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_base(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


class DINOHead(nn.Module):
    """DINO projection head: MLP -> L2-normalize -> weight-normed linear."""

    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            # Freeze the weight-norm magnitude so only the direction trains.
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x

# ---------------------------------------------------------------------------
# Original patch metadata (file boundaries):
# diff --git a/ACT_DP_multitask/detr/setup.py b/ACT_DP_multitask/detr/setup.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..55d18c0db74e27a687b6d6e0fb236e1cbb801f20
# --- /dev/null
# +++ b/ACT_DP_multitask/detr/setup.py
# (separate file, reproduced here as a comment)
#   from distutils.core import setup
#   from setuptools import find_packages
#   setup(
#       name='detr',
#       version='0.0.0',
#       packages=find_packages(),
#       license='MIT License',
#       long_description=open('README.md').read(),
#   )
# diff --git a/ACT_DP_multitask/detr/util/__init__.py b/ACT_DP_multitask/detr/util/__init__.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..168f9979a4623806934b0ff1102ac166704e7dec
# --- /dev/null
# +++ b/ACT_DP_multitask/detr/util/__init__.py
# Copyright (c) Facebook, Inc. and its affiliates.
# All Rights Reserved  (continuation of the util/__init__.py header above)
# ---------------------------------------------------------------------------
# Original patch metadata: four binary .pyc cache files were added here
# (util/__pycache__/{__init__,misc}.cpython-{310,37}.pyc) — build artifacts
# that should arguably not be committed at all; omitted as binary.
# ---------------------------------------------------------------------------
# diff --git a/ACT_DP_multitask/detr/util/box_ops.py b/ACT_DP_multitask/detr/util/box_ops.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..9c088e5bacc88ff7217fc971f5db889f5bb45b39
# --- /dev/null
# +++ b/ACT_DP_multitask/detr/util/box_ops.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area


def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x0, y0, x1, y1) along the last dim."""
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
    """Convert boxes from (x0, y0, x1, y1) to (cx, cy, w, h) along the last dim."""
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2,
         (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    """Pairwise IoU between two xyxy box sets; returns (iou, union), both [N, M]."""
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    # Smallest enclosing box of each pair.
    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / area


def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks

    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.

    Returns a [N, 4] tensors, with the boxes in xyxy format
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    y = torch.arange(0, h, dtype=torch.float)
    x = torch.arange(0, w, dtype=torch.float)
    # FIX: make the matrix indexing explicit — this matches the current default
    # and silences the torch>=1.10 deprecation warning about implicit indexing.
    y, x = torch.meshgrid(y, x, indexing="ij")

    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)

# ---------------------------------------------------------------------------
# diff --git a/ACT_DP_multitask/detr/util/misc.py b/ACT_DP_multitask/detr/util/misc.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..dfa9fb5b8f9e44c98e42aa9bb7275f6fa151472d
# --- /dev/null
# +++ b/ACT_DP_multitask/detr/util/misc.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.

Mostly copy-paste from torchvision references.
"""
import os
import subprocess
import time
from collections import defaultdict, deque
import datetime
import pickle
from packaging import version
from typing import Optional, List

import torch
import torch.distributed as dist
from torch import Tensor

# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
if version.parse(torchvision.__version__) < version.parse('0.7'):
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        # NOTE(review): the deque stores `value` once regardless of `n`, so
        # `median`/`avg` are unweighted while `global_avg` is weighted — this
        # matches the torchvision reference but is easy to misread.
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because 
torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in 
self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, 
cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = 
tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def 
is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], 
Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if version.parse(torchvision.__version__) < version.parse('0.7'): + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/ACT_DP_multitask/detr/util/plot_utils.py b/ACT_DP_multitask/detr/util/plot_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0f24bed0d3fe4624aeb231ddd02633f2e58e4bff --- /dev/null +++ b/ACT_DP_multitask/detr/util/plot_utils.py @@ -0,0 +1,107 @@ +""" +Plotting utilities to visualize training logs. +""" +import torch +import pandas as pd +import numpy as np +import seaborn as sns +import matplotlib.pyplot as plt + +from pathlib import Path, PurePath + + +def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): + ''' + Function to plot specific fields from training log(s). Plots both training and test results. + + :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file + - fields = which results to plot from each log file - plots both training and test for each field. + - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots + - log_name = optional, name of log file if different than default 'log.txt'. + + :: Outputs - matplotlib plots of results in fields, color coded for each log file. + - solid lines are training results, dashed lines are test results. 
+ + ''' + func_name = "plot_utils.py::plot_logs" + + # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, + # convert single Path to list to avoid 'not iterable' error + + if not isinstance(logs, list): + if isinstance(logs, PurePath): + logs = [logs] + print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") + else: + raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ + Expect list[Path] or single Path obj, received {type(logs)}") + + # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir + for i, dir in enumerate(logs): + if not isinstance(dir, PurePath): + raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") + if not dir.exists(): + raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") + # verify log_name exists + fn = Path(dir / log_name) + if not fn.exists(): + print(f"-> missing {log_name}. 
Have you gotten to Epoch 1 in training?") + print(f"--> full path of missing log file: {fn}") + return + + # load log file(s) and plot + dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] + + fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) + + for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): + for j, field in enumerate(fields): + if field == 'mAP': + coco_eval = pd.DataFrame( + np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1] + ).ewm(com=ewm_col).mean() + axs[j].plot(coco_eval, c=color) + else: + df.interpolate().ewm(com=ewm_col).mean().plot( + y=[f'train_{field}', f'test_{field}'], + ax=axs[j], + color=[color] * 2, + style=['-', '--'] + ) + for ax, field in zip(axs, fields): + ax.legend([Path(p).name for p in logs]) + ax.set_title(field) + + +def plot_precision_recall(files, naming_scheme='iter'): + if naming_scheme == 'exp_id': + # name becomes exp_id + names = [f.parts[-3] for f in files] + elif naming_scheme == 'iter': + names = [f.stem for f in files] + else: + raise ValueError(f'not supported {naming_scheme}') + fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) + for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): + data = torch.load(f) + # precision is n_iou, n_points, n_cat, n_area, max_det + precision = data['precision'] + recall = data['params'].recThrs + scores = data['scores'] + # take precision for all classes, all areas and 100 detections + precision = precision[0, :, :, 0, -1].mean(1) + scores = scores[0, :, :, 0, -1].mean(1) + prec = precision.mean() + rec = data['recall'][0, :, 0, -1].mean() + print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + + f'score={scores.mean():0.3f}, ' + + f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' + ) + axs[0].plot(recall, precision, c=color) + axs[1].plot(recall, scores, c=color) + + axs[0].set_title('Precision / Recall') + axs[0].legend(names) + axs[1].set_title('Scores / Recall') + axs[1].legend(names) + return 
fig, axs diff --git a/ACT_DP_multitask/policy_anyrobot.py b/ACT_DP_multitask/policy_anyrobot.py new file mode 100644 index 0000000000000000000000000000000000000000..464f4e13d2bd3317c2b7068e44e89941ed73f46d --- /dev/null +++ b/ACT_DP_multitask/policy_anyrobot.py @@ -0,0 +1,1496 @@ +import math +import torch +import torch.nn as nn +from torch.nn import functional as F +import torchvision.transforms as transforms +from diffusers import DDIMScheduler, DDPMScheduler +from detr.main import * +import numpy as np +from collections import deque +from utils import normalize_data, tensor2numpy, kl_divergence, RandomShiftsAug +import os + +# from cosmos_tokenizer.video_lib import CausalVideoTokenizer +# from cosmos_tokenizer.image_lib import ImageTokenizer + + +def get_tokenizer(model_name): + print(f"Loading tokenizer {model_name}") + current_dir = os.path.dirname(__file__) + checkpoint_enc = ( + f"{current_dir}/Cosmos-Tokenizer/pretrained_ckpts/{model_name}/encoder.jit" + ) + checkpoint_dec = ( + f"{current_dir}/Cosmos-Tokenizer/pretrained_ckpts/{model_name}/decoder.jit" + ) + model_type = model_name[18] # I or V + if model_type == "I": + encoder = ImageTokenizer(checkpoint_enc=checkpoint_enc) + decoder = ImageTokenizer(checkpoint_dec=checkpoint_dec) + elif model_type == "V": + encoder = CausalVideoTokenizer(checkpoint_enc=checkpoint_enc) + decoder = CausalVideoTokenizer(checkpoint_dec=checkpoint_dec) + + for param in encoder.parameters(): # frozen + param.requires_grad = False + for param in decoder.parameters(): + param.requires_grad = False + return encoder, decoder + + +class ACTPolicy(nn.Module): + def __init__(self, args_override): + super().__init__() + model, optimizer = build_ACT_model_and_optimizer(args_override) + self.model = model # CVAE decoder + self.optimizer = optimizer + self.kl_weight = args_override["kl_weight"] + print(f"KL Weight {self.kl_weight}") + self.args = args_override + + def __call__( + self, qpos, image, actions=None, is_pad=None, 
is_training=False, instances=None + ): + env_state = None + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + if actions is not None: # training time + # image = self.aug(image) if is_training else image # disbale aug + image = normalize(image) + actions = actions[:, : self.model.num_queries] + is_pad = is_pad[:, : self.model.num_queries] + + if self.args["segmentation"]: + # ACT with segmentation + a_hat, is_pad_hat, (mu, logvar), (mask_classes, outputs_seg_masks) = ( + self.model(qpos, image, env_state, actions, is_pad) + ) + else: + # Vanilla ACT + a_hat, is_pad_hat, (mu, logvar) = self.model( + qpos, image, env_state, actions, is_pad + ) + + total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) + loss_dict = dict() + all_l1 = F.l1_loss(actions, a_hat, reduction="none") + + l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean() + loss_dict["l1"] = l1 + loss_dict["kl"] = total_kld[0] + + if self.args["segmentation"]: + targets = self.prepare_targets(instances, image) + losses_seg = self.criterion( + {"pred_logits": mask_classes, "pred_masks": outputs_seg_masks}, + targets, + ) + + for k in list(losses_seg.keys()): + if k in self.criterion.weight_dict: + losses_seg[k] *= ( + self.criterion.weight_dict[k] * self.segloss_weight + ) + else: + # remove this loss if not specified in `weight_dict` + losses_seg.pop(k) + loss_dict.update(losses_seg) + loss_dict["loss"] = ( + loss_dict["l1"] + + loss_dict["kl"] * self.kl_weight + + loss_dict["loss_ce"] + + loss_dict["loss_mask"] + + loss_dict["loss_dice"] + ) + else: + loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight + return loss_dict + else: # inference time + image = normalize(image) + if self.args["segmentation"]: + a_hat, is_pad_hat, (mu, logvar), (mask_classes, outputs_seg_masks) = ( + self.model(qpos, image, env_state, actions, is_pad) + ) + for mask_cls_result, mask_pred_result in zip( + mask_classes, outputs_seg_masks + ): + sem_seg = torch.zeros( 
+ (mask_pred_result.shape[1], mask_pred_result.shape[2], 3), + device=mask_pred_result.device, + ) + keep = mask_cls_result.softmax(-1)[:, 0] > 0.5 + for ii, mask in enumerate(mask_pred_result): + if keep[ii]: + sem_seg[mask.sigmoid() > 0.5, :] = torch.tensor( + self.colors[ii], + device=mask_pred_result.device, + dtype=sem_seg.dtype, + ) + self.seg = sem_seg.cpu().numpy() + # sem_seg = self.semantic_inference(mask_cls_result, mask_pred_result) + # import matplotlib.pyplot as plt + # plt.subplot(122) + # plt.imshow(sem_seg.cpu().numpy()/255) + # plt.savefig('seg.png') + # import pdb; pdb.set_trace() + else: + a_hat, _, (_, _) = self.model( + qpos, image, env_state + ) # no action, sample from prior + + return a_hat + + def configure_optimizers(self): + return self.optimizer + + def prepare_targets(self, targets, image): + # h, w = images.tensor.shape[-2:] + new_targets = [] + for targets_per_image in targets: + # pad gt + targets_per_image = targets_per_image.to(image.device) + gt_masks = targets_per_image.gt_masks + # padded_masks = torch.zeros((gt_masks.shape[0], h, w), dtype=gt_masks.dtype, device=gt_masks.device) + # padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks + new_targets.append( + { + "labels": targets_per_image.gt_classes, + "masks": gt_masks, + } + ) + return new_targets + + def semantic_inference(self, mask_cls, mask_pred): + mask_cls = F.softmax(mask_cls, dim=-1) + mask_pred = mask_pred.sigmoid() + semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred) + return semseg + + # for robotwin + def reset_obs(self, stats, norm_type): + self.stats = stats + self.norm_type = norm_type + + def update_obs(self, obs): + self.obs_image = ( + torch.from_numpy(obs["head_cam"]).unsqueeze(0).unsqueeze(0).float().cuda() + ) # 1 1 C H W 0~1 + obs_qpos = torch.from_numpy(obs["agent_pos"]).unsqueeze(0).float().cuda() + self.obs_qpos = normalize_data( + obs_qpos, self.stats, "gaussian", data_type="qpos" + ) # qpos mean std + + def get_action(self): 
+ a_hat = self(self.obs_qpos, self.obs_image).detach().cpu().numpy() # B T K + # unnormalize + if self.norm_type == "minmax": + a_hat = (a_hat + 1) / 2 * ( + self.stats["action_max"] - self.stats["action_min"] + ) + self.stats["action_min"] + elif self.norm_type == "gaussian": + a_hat = a_hat * self.stats["action_std"] + self.stats["action_mean"] + return a_hat[0] # chunksize 14 + + +class ACTDiffusionPolicy(nn.Module): + def __init__(self, args_override): + super().__init__() + model, optimizer = build_ACTDiffusion_model_and_optimizer(args_override) + self.model = model # CVAE decoder + self.optimizer = optimizer + self.kl_weight = args_override["kl_weight"] + self.aug = RandomShiftsAug(15, 20) # TODO acording to the task + # for robotwin env + self.history_steps = args_override["history_step"] + self.obs_image = deque(maxlen=self.history_steps + 1) + self.obs_qpos = deque(maxlen=self.history_steps + 1) + # diffusion setup + self.num_inference_steps = args_override["num_inference_steps"] + self.num_queries = args_override["num_queries"] + num_train_timesteps = args_override["num_train_timesteps"] + prediction_type = args_override["prediction_type"] + beta_schedule = args_override["beta_schedule"] + noise_scheduler = ( + DDIMScheduler if args_override["schedule_type"] == "DDIM" else DDPMScheduler + ) + noise_scheduler = noise_scheduler( + num_train_timesteps=num_train_timesteps, + beta_schedule=beta_schedule, + prediction_type=prediction_type, + ) + + self.noise_scheduler = noise_scheduler + self.loss_type = args_override["loss_type"] + print("num_train_timesteps", {args_override["num_train_timesteps"]}) + print("schedule_type", {args_override["schedule_type"]}) + print("beta_schedule", {args_override["beta_schedule"]}) + print("prediction_type", {args_override["prediction_type"]}) + print(f"Loss Type {self.loss_type}") + + def train_model(self, qpos, image, actions, is_pad=None,is_training=True, task_emb=None): + """ + qpos: B his+1 14 + image: B his+1 N_view 3 H 
W + """ + env_state = None + noise = torch.randn_like(actions).to(actions.device) + bsz = actions.shape[0] + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (bsz,), + device=actions.device, + ) + # print('action device', actions.device, 'noise device', noise.device, 'timesteps device', timesteps.device, ) + noisy_actions = self.noise_scheduler.add_noise(actions, noise, timesteps) + + pred, is_pad_hat, [mu, logvar] = self.model( + qpos, image, env_state, noisy_actions, is_pad, denoise_steps=timesteps, is_training=is_training, task_emb=task_emb + ) + + pred_type = self.noise_scheduler.config.prediction_type + + if pred_type == "epsilon": + target = noise + elif pred_type == "sample": + target = actions + elif pred_type == "v_prediction": + # https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py + # https://github.com/huggingface/diffusers/blob/v0.11.1-patch/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py + # sigma = self.noise_scheduler.sigmas[timesteps] + # alpha_t, sigma_t = self.noise_scheduler._sigma_to_alpha_sigma_t(sigma) + self.noise_scheduler.alpha_t = self.noise_scheduler.alpha_t.to(self.device) + self.noise_scheduler.sigma_t = self.noise_scheduler.sigma_t.to(self.device) + alpha_t, sigma_t = ( + self.noise_scheduler.alpha_t[timesteps], + self.noise_scheduler.sigma_t[timesteps], + ) + alpha_t = alpha_t.unsqueeze(-1).unsqueeze(-1) + sigma_t = sigma_t.unsqueeze(-1).unsqueeze(-1) + v_t = alpha_t * noise - sigma_t * actions + target = v_t + else: + raise ValueError(f"Unsupported prediction type {pred_type}") + + loss_dict = {} + if self.loss_type == "l2": + loss = F.mse_loss(pred, target, reduction="none") + elif self.loss_type == "l1": + loss = F.l1_loss(pred, target, reduction="none") + diffusion_loss = (loss * ~is_pad.unsqueeze(-1)).mean() + diffusion_loss_name = pred_type + "_diffusion_loss_" + self.loss_type + loss_dict[diffusion_loss_name] = 
diffusion_loss + + if mu is not None and logvar is not None: # for CVAE module + total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) + loss_dict["kl"] = total_kld[0] + loss_dict["loss"] = ( + loss_dict[diffusion_loss_name] + loss_dict["kl"] * self.kl_weight + ) + else: + loss_dict["loss"] = loss_dict[diffusion_loss_name] + return loss_dict + + # ===================inferece =============== + def conditional_sample(self, qpos, image, task_emb=None): + """ + diffusion process to generate actions + """ + if len(image.shape) == 5: # B N C H W + qpos = qpos.unsqueeze(1) + image = image.unsqueeze(1) + env_state = None + model = self.model + scheduler = self.noise_scheduler + batch = image.shape[0] + action_shape = (batch, self.num_queries, 14) + actions = torch.randn(action_shape, device=qpos.device, dtype=qpos.dtype) + scheduler.set_timesteps(self.num_inference_steps) + for t in scheduler.timesteps: + # print('diffusion timestep', t) + timesteps = torch.full((batch,), t, device=qpos.device, dtype=torch.long) + model_output, is_pad_hat, [mu, logvar] = model( + qpos, + image, + env_state, + actions, + None, + denoise_steps=timesteps, + is_training=False, + task_emb=task_emb, + ) + actions = scheduler.step(model_output, t, actions).prev_sample + return actions + + def __call__(self, qpos, image, actions=None, is_pad=None, is_training=True, task_emb=None): + # qpos: B D + # image: B Num_view C H W + # actions: B T K + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + if actions is not None: # training time + # image = self.aug(image) if is_training else image + image = normalize(image) + actions = actions[:, : self.model.num_queries] + is_pad = is_pad[:, : self.model.num_queries] + loss_dict = self.train_model(qpos, image, actions, is_pad, is_training,task_emb) + return loss_dict + else: # inference time + image = normalize(image) + a_hat = self.conditional_sample(qpos, image,task_emb) + return a_hat + + def 
configure_optimizers(self): + return self.optimizer + + def reset_obs(self, stats, norm_type): + self.stats = stats + self.norm_type = norm_type + + + def get_action(self, qpos, image, task_emb): + # normalize qpos + qpos = qpos.unsqueeze(0) # B D + image = image.unsqueeze(0) # B N C H W + task_emb = task_emb.unsqueeze(0) # B D ? + qpos = normalize_data(qpos, self.stats, "gaussian", data_type="qpos") + with torch.no_grad(): + a_hat = self(qpos, image, is_pad=None, is_training=False, task_emb=task_emb).detach().cpu().numpy() + # unnormalize action + if self.norm_type == "minmax": + a_hat = (a_hat + 1) / 2 * ( + self.stats["action_max"] - self.stats["action_min"] + ) + self.stats["action_min"] + elif self.norm_type == "gaussian": + a_hat = a_hat * self.stats["action_std"] + self.stats["action_mean"] + # detach and convert to numpy + a_hat = a_hat[0] + return a_hat # 50 14 + + + +class ACT_Flow_Matching(nn.Module): + def __init__(self, args_override): + super().__init__() + model, optimizer = build_ACTDiffusion_model_and_optimizer(args_override) + self.model = model + self.optimizer = optimizer + if "sim" in args_override["task_name"]: # for aloha env + self.aug = RandomShiftsAug(15, 20) # TODO acording to the task + else: + self.aug = RandomShiftsAug(8, 10) # for robotwin env + self.history_steps = args_override["history_step"] + self.obs_image = deque(maxlen=self.history_steps + 1) + self.obs_qpos = deque(maxlen=self.history_steps + 1) + self.num_queries = args_override["num_queries"] + # flow matching steps + self.num_inference_steps = args_override["num_inference_steps"] + self.noise_scheduler = torch.distributions.Beta(1.5, 1) + self.loss_type = args_override["loss_type"] + def train_model(self, qpos, image, actions, is_pad=None): + """ + qpos: B his+1 14 + image: B his+1 N_view 3 H W + """ + env_state = None + noise = torch.randn_like(actions).to(actions.device) + bsz = actions.shape[0] + # refer to openpi0 + timesteps = 
self.noise_scheduler.sample((bsz,)).to(actions.device) * 0.999 + 0.001 + timesteps_expand = timesteps[...,None,None] # B 1 1 + noisy_actions = timesteps_expand * noise + (1 - timesteps_expand) * actions + + pred, is_pad_hat, [mu, logvar] = self.model( + qpos, image, env_state, noisy_actions, is_pad, denoise_steps=timesteps + ) + + target = noise - actions # ut + loss_dict = {} + if self.loss_type == "l2": + loss = F.mse_loss(pred, target, reduction="none") + elif self.loss_type == "l1": + loss = F.l1_loss(pred, target, reduction="none") + diffusion_loss = (loss * ~is_pad.unsqueeze(-1)).mean() + pred_type = 'v_pred' + diffusion_loss_name = pred_type + "_flow_loss_" + self.loss_type + loss_dict[diffusion_loss_name] = diffusion_loss + + if mu is not None and logvar is not None: # for CVAE module + total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) + loss_dict["kl"] = total_kld[0] + loss_dict["loss"] = ( + loss_dict[diffusion_loss_name] + loss_dict["kl"] * self.kl_weight + ) + else: + loss_dict["loss"] = loss_dict[diffusion_loss_name] + return loss_dict + + # ===================inferece =============== + def conditional_sample(self, qpos, image, is_pad): + """ + diffusion process to generate actions + """ + if len(image.shape) == 5: # B N C H W + qpos = qpos.unsqueeze(1) + image = image.unsqueeze(1) + env_state = None + model = self.model + batch = image.shape[0] + action_shape = (batch, self.num_queries, 14) + actions = torch.randn(action_shape, device=qpos.device, dtype=qpos.dtype) + dt = -1.0 / self.num_inference_steps # -0.1, constant speed + timesteps = torch.ones(batch, device=qpos.device) + while timesteps[0].item() >= -dt/2: # inferen_step + model_output, is_pad_hat, [mu, logvar] = model( + qpos, + image, + env_state, + actions, + is_pad, + denoise_steps=timesteps, + is_training=False, + ) + actions += model_output * dt + timesteps += dt + return actions.clip(min=-1, max=1) + + def __call__(self, qpos, image, actions=None, is_pad=None, 
is_training=True): + # qpos: B D + # image: B Num_view C H W + # actions: B T K + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + if actions is not None: # training time + # image = self.aug(image) if is_training else image + image = normalize(image) + actions = actions[:, : self.model.num_queries] + is_pad = is_pad[:, : self.model.num_queries] + loss_dict = self.train_model(qpos, image, actions, is_pad) + return loss_dict + else: # inference time + image = normalize(image) + a_hat = self.conditional_sample(qpos, image, is_pad) + return a_hat + + def configure_optimizers(self): + return self.optimizer + + +class ACTDiffusionPolicy_Tactile(nn.Module): + def __init__(self, args_override): + super().__init__() + model, optimizer = build_ACTDiffusion_tactile_model_and_optimizer(args_override) + self.model = model # CVAE decoder + self.optimizer = optimizer + print(args_override.keys()) + self.aug = RandomShiftsAug(8, 10) # for robotwin env + self.history_steps = 0 + self.obs_image = deque(maxlen=self.history_steps + 1) + self.obs_qpos = deque(maxlen=self.history_steps + 1) + self.obs_tactile = deque(maxlen=self.history_steps + 1) + # diffusion setup + self.num_inference_steps = args_override["num_inference_steps"] + self.num_queries = args_override["num_queries"] + num_train_timesteps = args_override["num_train_timesteps"] + prediction_type = args_override["prediction_type"] + beta_schedule = args_override["beta_schedule"] + noise_scheduler = ( + DDIMScheduler if args_override["schedule_type"] == "DDIM" else DDPMScheduler + ) + noise_scheduler = noise_scheduler( + num_train_timesteps=num_train_timesteps, + beta_schedule=beta_schedule, + prediction_type=prediction_type, + ) + + self.noise_scheduler = noise_scheduler + self.loss_type = args_override["loss_type"] + print("num_train_timesteps", {args_override["num_train_timesteps"]}) + print("schedule_type", {args_override["schedule_type"]}) + print("beta_schedule", 
{args_override["beta_schedule"]}) + print("prediction_type", {args_override["prediction_type"]}) + print(f"Loss Type {self.loss_type}") + + def train_model(self, qpos, image,tactile, actions, is_pad=None): + """ + qpos: B his+1 14 + image: B his+1 N_view 3 H W + """ + env_state = None + noise = torch.randn_like(actions).to(actions.device) + bsz = actions.shape[0] + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (bsz,), + device=actions.device, + ) + noisy_actions = self.noise_scheduler.add_noise(actions, noise, timesteps) + + pred, is_pad_hat = self.model( + qpos, image, tactile, env_state, noisy_actions, is_pad, denoise_steps=timesteps + ) + + pred_type = self.noise_scheduler.config.prediction_type + + if pred_type == "epsilon": + target = noise + elif pred_type == "sample": + target = actions + elif pred_type == "v_prediction": + # https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py + # https://github.com/huggingface/diffusers/blob/v0.11.1-patch/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py + # sigma = self.noise_scheduler.sigmas[timesteps] + # alpha_t, sigma_t = self.noise_scheduler._sigma_to_alpha_sigma_t(sigma) + self.noise_scheduler.alpha_t = self.noise_scheduler.alpha_t.to(self.device) + self.noise_scheduler.sigma_t = self.noise_scheduler.sigma_t.to(self.device) + alpha_t, sigma_t = ( + self.noise_scheduler.alpha_t[timesteps], + self.noise_scheduler.sigma_t[timesteps], + ) + alpha_t = alpha_t.unsqueeze(-1).unsqueeze(-1) + sigma_t = sigma_t.unsqueeze(-1).unsqueeze(-1) + v_t = alpha_t * noise - sigma_t * actions + target = v_t # flow matching? 
+ else: + raise ValueError(f"Unsupported prediction type {pred_type}") + + loss_dict = {} + if self.loss_type == "l2": + loss = F.mse_loss(pred, target, reduction="none") + elif self.loss_type == "l1": + loss = F.l1_loss(pred, target, reduction="none") + diffusion_loss = (loss * ~is_pad.unsqueeze(-1)).mean() + diffusion_loss_name = pred_type + "_diffusion_loss_" + self.loss_type + loss_dict[diffusion_loss_name] = diffusion_loss + + loss_dict["loss"] = loss_dict[diffusion_loss_name] + return loss_dict + + # ===================inferece =============== + def conditional_sample(self, qpos, image,tactile, is_pad): + """ + diffusion process to generate actions + """ + if len(image.shape) == 5: # B N C H W + qpos = qpos.unsqueeze(1) + image = image.unsqueeze(1) + tactile = tactile.unsqueeze(1) + env_state = None + model = self.model + scheduler = self.noise_scheduler + batch = image.shape[0] + action_shape = (batch, self.num_queries, 14) + actions = torch.randn(action_shape, device=qpos.device, dtype=qpos.dtype) + scheduler.set_timesteps(self.num_inference_steps) + for t in scheduler.timesteps: + timesteps = torch.full((batch,), t, device=qpos.device, dtype=torch.long) + model_output, is_pad_hat, [mu, logvar] = model( + qpos, + image, + tactile, + env_state, + actions, + is_pad, + denoise_steps=timesteps, + is_training=False, + ) + actions = scheduler.step(model_output, t, actions).prev_sample + return actions + + def __call__(self, qpos, image, tactile, actions=None, is_pad=None, is_training=True): + # qpos: B D + # image: B Num_view C H W + # actions: B T K + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + if actions is not None: # training time + image = self.aug(image) if is_training else image + image = normalize(image) + actions = actions[:, : self.model.num_queries] + is_pad = is_pad[:, : self.model.num_queries] + loss_dict = self.train_model(qpos, image,tactile, actions, is_pad) + return loss_dict + else: # 
inference time + image = normalize(image) + a_hat = self.conditional_sample(qpos, image,tactile,is_pad) + return a_hat + + def configure_optimizers(self): + return self.optimizer + + +## use visual tokenization +""" +Input: + A unified dataset for all the datasets + image_data [0~1]: history_steps+1 Num_view C H W + qpos_data [normalized]: history_steps+1 D + action_data [raw]: chunk_size D + is_pad: chunk_size + future_imgs_data [0~1]: predict_frame Num_view C H W + is_pad_img : predict_frame +""" + + +class ACTPolicyDiffusion_with_Token_Prediction(nn.Module): + def __init__(self, args_override): + super().__init__() + model = build_diffusion_tp_model_and_optimizer(args_override) + self.model = model # decoder + self.camera_num = len(args_override["camera_names"]) + self.kl_weight = args_override["kl_weight"] + print(f"KL Weight {self.kl_weight}") + # memory buffer + self.history_steps = args_override["history_step"] + self.obs_image = deque(maxlen=self.history_steps + 1) + self.obs_qpos = deque(maxlen=self.history_steps + 1) + # visual tokenization + if "sim" in args_override["task_name"]: # for aloha env + self.aug = RandomShiftsAug(15, 20) + else: + self.aug = RandomShiftsAug(8, 10) # for robotwin env + # tokenizer model and shape + self.tokenizer_model_type = args_override["tokenizer_model_name"][17:19] # VI + self.tokenizer_enc, self.tokenizer_dec = get_tokenizer( + args_override["tokenizer_model_name"] + ) + self.token_dim = args_override["token_dim"] + self.num_temporal_token = self.model.num_temporal_token + self.token_h = ( + args_override["image_height"] + // args_override["image_downsample_rate"] + // args_override["tokenizer_model_spatial_rate"] + // args_override["resize_rate"] + ) + self.token_w = ( + args_override["image_width"] + // args_override["image_downsample_rate"] + // args_override["tokenizer_model_spatial_rate"] + // args_override["resize_rate"] + ) + print( + "token shape", + "token_h", + self.token_h, + "token_w", + self.token_w, + 
"token_dim", + self.token_dim, + ) + # video prediction hyperparameters + self.temporal_compression = args_override[ + "tokenizer_model_temporal_rate" + ] # temporal compression + self.predict_only_last = args_override["predict_only_last"] + self.prediction_weight = args_override["prediction_weight"] + self.imitate_weight = args_override["imitate_weight"] + self.predict_frame = args_override["predict_frame"] + self.temporal_downsample_rate = args_override[ + "temporal_downsample_rate" + ] # uniformly sample + self.resize_rate = args_override["resize_rate"] + print("tokenizer_model_type", self.tokenizer_model_type) + print("predict_frame", self.predict_frame) + print("prediction_weight", self.prediction_weight) + print("imitate_weight", self.imitate_weight) + + # diffusion hyperparameters + self.num_inference_steps = args_override["num_inference_steps"] + self.num_queries = args_override["num_queries"] + num_train_timesteps = args_override["num_train_timesteps"] + prediction_type = args_override["prediction_type"] + beta_schedule = args_override["beta_schedule"] + noise_scheduler = ( + DDIMScheduler if args_override["schedule_type"] == "DDIM" else DDPMScheduler + ) + noise_scheduler = noise_scheduler( + num_train_timesteps=num_train_timesteps, + beta_schedule=beta_schedule, + prediction_type=prediction_type, + ) + self.noise_scheduler = noise_scheduler + pred_type = self.noise_scheduler.config.prediction_type + self.loss_type = args_override["loss_type"] + self.diffusion_loss_name = pred_type + "_diffusion_loss_" + self.loss_type + + print("num_train_timesteps", {args_override["num_train_timesteps"]}) + print("schedule_type", {args_override["schedule_type"]}) + print("beta_schedule", {args_override["beta_schedule"]}) + print("prediction_type", {args_override["prediction_type"]}) + print(f"Loss Type {self.loss_type}") + + self.normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + self.vis_idx = 0 + + def train_model(self, 
qpos, image, env_state, actions, is_pad, is_image_pad=None): + # qpos: B T' D; T' = history_steps+1 + # image: B T+1 N C H W ,T = history_steps+1+predict_frame 0~1 + # actions: B H D, H = chunk_size + # is_pad: B H, is_valid + # is_image_pad: B predict_frame + env_state = None + bsz = actions.shape[0] + is_tokens_pad = torch.ones( + bsz, self.num_temporal_token, device=actions.device, dtype=torch.bool + ) # B T/t length after temporal compression + if self.predict_only_last: + is_tokens_pad = is_image_pad + else: + valid_is_tokens_pad = is_image_pad[ + :, :: self.temporal_compression + ] # avoid meaningless token + is_tokens_pad[:, : valid_is_tokens_pad.shape[-1]] = valid_is_tokens_pad + + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (bsz,), + device=actions.device, + ) + + # image tokenization ; another image normalization for resnet + current_image_norm = self.normalize( + image[:, 0 : self.history_steps + 1] # image[:, 0:1] + ) # B 1 N C H W TODO image[:, self.history_steps,self.history_steps+1] + # TOOD only compress current and future frames image[:, self.history_steps:,:] + image_tokens = self.get_visual_token( + image[ + :, self.history_steps :, 0:1, :, :: self.resize_rate, :: self.resize_rate + ] # resize future frame + ) # B T+1 N C H W + # image_tokens = self.get_visual_token( + # image[..., :: self.resize_rate, :: self.resize_rate] # resize future frame + # ) # B T/t + 1 N D H' W' including history_steps+1+predict_frame + current_image_tokens = image_tokens[:, 0:1] # B 1 N D H' W' useless + future_image_tokens = image_tokens[:, 1:] # B T N D H' W' + self.image_tokens_shape = future_image_tokens.shape # B T N D H' W' + # TODO can use differnent noise_scheduler for image token & actions + actions_noise = torch.randn_like(actions).to(actions.device) + token_noise = torch.randn_like(future_image_tokens).to( + future_image_tokens.device + ) + noisy_actions = self.noise_scheduler.add_noise( + actions, actions_noise, 
timesteps + ) + noise_tokens = self.noise_scheduler.add_noise( + future_image_tokens, token_noise, timesteps + ) # future image token + # use detr-diffusion model to predict actions & image tokens + a_hat, is_pad_hat, pred_token, (mu, logvar) = self.model( + qpos, # B his+1 D + (current_image_norm, current_image_tokens), # B 1 N C H W + env_state, + actions, # B T D + is_pad, + noisy_actions, # B T D + noise_tokens, # B T' N_view D H' W' + is_tokens_pad, + denoise_steps=timesteps, + ) + + # prediction type + pred_type = self.noise_scheduler.config.prediction_type + + if pred_type == "epsilon": + target_action = actions_noise + target_token = token_noise + elif pred_type == "sample": + target_action = actions + target_token = future_image_tokens + else: + raise ValueError(f"Unsupported prediction type {pred_type}") + + # calculate diffusion loss + loss_dict = {} + if self.loss_type == "l2": + loss = F.mse_loss(a_hat, target_action, reduction="none") + elif self.loss_type == "l1": + loss = F.l1_loss(a_hat, target_action, reduction="none") + diffusion_loss = (loss * ~is_pad.unsqueeze(-1)).mean() + loss_dict[self.diffusion_loss_name] = diffusion_loss + # just vis diffusion l2 loss + diffusion_l2 = F.mse_loss(a_hat, target_action, reduction="none") + diffusion_l2 = (diffusion_l2 * ~is_pad.unsqueeze(-1)).mean().detach() + loss_dict["diffusion_l2"] = diffusion_l2 + + tokens_loss = F.mse_loss( + pred_token, target_token, reduction="none" + ) # B T N D H' W' + is_tokens_pad = ( + is_tokens_pad.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + ) # B T N D H' W' + tokens_loss = (tokens_loss * ~is_tokens_pad).mean() + loss_dict["loss_prediction"] = tokens_loss + + if mu is not None and logvar is not None: + total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) + loss_dict["kl"] = total_kld[0] + loss_dict["loss"] = ( + loss_dict[self.diffusion_loss_name] * self.imitate_weight + + loss_dict["kl"] * self.kl_weight + + loss_dict["loss_prediction"] * 
self.prediction_weight + ) + else: + loss_dict["loss"] = ( + loss_dict[self.diffusion_loss_name] * self.imitate_weight + + loss_dict["loss_prediction"] * self.prediction_weight + ) + + return loss_dict, (current_image_tokens, target_token, pred_token) + + def conditional_sample(self, qpos, image): + # qpos: B 1 D or B D + # image: B 1 N C H W or B N C H W + if len(image.shape) == 5: # B N C H W single frame + qpos = qpos.unsqueeze(1) + image = image.unsqueeze(1) + env_state = None + model = self.model + scheduler = self.noise_scheduler + scheduler.set_timesteps(self.num_inference_steps) + # process image observation + current_image_norm = self.normalize(image[:, 0 : self.history_steps + 1]) # B 1 N C H W + # initial noise action & token + batch = image.shape[0] + action_shape = (batch, self.num_queries, 14) + actions = torch.randn(action_shape, device=qpos.device, dtype=qpos.dtype) + tokens = None # TODO discard token prediction while evaluation + for t in scheduler.timesteps: + timesteps = torch.full((batch,), t, device=qpos.device, dtype=torch.long) + model_action_output, is_pad_hat, model_token_output, (mu, logvar) = model( + qpos, + (current_image_norm, None), + env_state, + None, + None, + actions, + tokens, + None, + denoise_steps=timesteps, + ) + actions = scheduler.step(model_action_output, t, actions).prev_sample + return actions, tokens, mu, logvar + + def __call__( + self, + qpos, + image, + actions=None, + is_pad=None, + future_imgs=None, + is_pad_img=None, + is_training=True, + save_rec=False, + ): + env_state = None + if actions is not None: # training time + # print(image.shape, future_imgs.shape) B 1 N C H W & B H N C H W + all_image = torch.cat( + [image, future_imgs], dim=1 + ) # B H+1+T' N C H W same resize maybe just use image_sample_size + all_image = self.aug(all_image) if is_training else all_image + loss_dict, (current_image_tokens, target_token, pred_token) = ( + self.train_model( + qpos, all_image, env_state, actions, is_pad, is_pad_img + 
) + ) # B H + + if save_rec == True: # show the visulization result + # a_hat, pred_token, _, _ = self.conditional_sample(qpos, image) # generate bug? + # tokens_loss = F.mse_loss(pred_token, target_token) + # print('conditional tokens_loss', tokens_loss) # not predict from noise + print("is_image_pad rate ", is_pad_img[0].sum() / 20) + raw_videos = ( + all_image[:, :, 0].permute(0, 2, 1, 3, 4) * 2 - 1 + ) # B T C H W -> B C T H W -1,1 + rec_videos = self.generate_video_by_codes( + current_image_tokens, pred_token + ) + rec_gt_videos = self.generate_video_by_codes( + current_image_tokens, target_token + ) + error_videos = (rec_gt_videos - rec_videos).clip(-1, 1) + + raw_videos = tensor2numpy(raw_videos)[0] # C T H W + rec_videos = tensor2numpy(rec_videos)[0] # T H W C + rec_gt_videos = tensor2numpy(rec_gt_videos)[0] # T H W C + error_videos = tensor2numpy(error_videos)[0] # T H W C + vis_video = np.concatenate( + [raw_videos, rec_gt_videos, rec_videos, error_videos], axis=1 + ) # T H*N W C + return loss_dict, vis_video + else: + return loss_dict + + else: # inference time + qpos = qpos # B 1 D or B D + image = image # B 1 N C H W or B N C H W + # print(image.shape, image.max(), image.min()) + a_hat, pred_token, _, _ = self.conditional_sample(qpos, image) + # print('prediction action', a_hat.shape) + return a_hat # B H D + + # visual tokenization generate tokens & reconstruct video + def get_visual_token(self, input_tensor): + # input_tensor: B T N C H W range [0,1] -> [-1,1] + input_tensor = input_tensor * 2 - 1 # B T N C H W + input_tensor = input_tensor.permute(0, 2, 3, 1, 4, 5) # B N C T H W + horizon = input_tensor.shape[3] + C, H, W = input_tensor.shape[2], input_tensor.shape[4], input_tensor.shape[5] + self.num_view = input_tensor.shape[1] + codes_list = [] + # refer to Cosmos-Tokenizer/cosmos_tokenizer/video_lib.py Line 103 + for view_idx in range( + self.num_view + ): # deal with each view for video tokenization B C T H W + if self.tokenizer_model_type == 
"DV": # encoder for video tokenization + (indices, codes) = self.tokenizer_enc._enc_model( + input_tensor[:, view_idx] # B C T H W + )[ + :-1 + ] # B D T' H' W' + elif self.tokenizer_model_type == "DI": # encoder for image tokenization + input_tensor_image = ( + input_tensor[:, view_idx].permute(0, 2, 1, 3, 4).view(-1, C, H, W) + ) # B C T H W -> B*T C H W + (indices, codes) = self.tokenizer_enc._enc_model(input_tensor_image)[ + :-1 + ] # B*T D H' W' + codes = codes.view( + -1, horizon, codes.shape[1], codes.shape[2], codes.shape[3] + ).permute( + 0, 2, 1, 3, 4 + ) # B T+1 D H' W' -> B D T+1 H' W' + elif ( + self.tokenizer_model_type == "CV" + ): # encoder for video tokenization TODO check + (codes,) = self.tokenizer_enc._enc_model(input_tensor[:, view_idx])[ + :-1 + ] # B 16 T' H' W' + codes = codes / 15.0 # nomalize to [-1,1] HardCode + # TODO codes should normlize to [-1,1] + elif self.tokenizer_model_type == "CI": + input_tensor_image = ( + input_tensor[:, view_idx].permute(0, 2, 1, 3, 4).view(-1, C, H, W) + ) # B C T H W -> B*T C H W + (codes,) = self.tokenizer_enc._enc_model(input_tensor_image)[ + :-1 + ] # B*T D H' W' + codes = codes.view( + -1, horizon, codes.shape[1], codes.shape[2], codes.shape[3] + ).permute( + 0, 2, 1, 3, 4 + ) # B T+1 D H' W' -> B D T+1 H' W' + codes = codes / 15.0 # nomalize to [-1,1] HardCode + codes_list.append(codes.detach()) # Important + codes = torch.stack(codes_list, dim=1) # B N D T' H' W' -> # B T+1 N D H' W' + codes = codes.permute(0, 3, 1, 2, 4, 5).float() # B T+1 N D H' W' + return codes + + def generate_video_by_codes(self, current_codes, pred_codes): + # Input: B T'+1 D H' W' simgle view video tokens + # Output: B C T+1 H W range [-1,1] + codes = torch.cat([current_codes, pred_codes], dim=1)[ + :, :, 0 + ] # B T+1 N D H' W' -> B T'+1 D H' W'# HARD CODE + codes = codes.permute(0, 2, 1, 3, 4).to(dtype=torch.bfloat16) # B D T'+1 H' W' + # suitable for DV only + if self.tokenizer_model_type == "DV": + h = 
self.tokenizer_dec._dec_model.post_quant_conv( + codes + ) # problem shoudl fixed TODO + elif self.tokenizer_model_type == "CV": + h = codes * 15 # unnormalize to original + reconstructed_videos = self.tokenizer_dec._dec_model.decoder( + h + ).detach() # B C T+1 H W -1,1 + + return reconstructed_videos + + # For ROBOTWIN + def reset_obs(self, stats=None, norm_type="minmax"): + self.obs_image.clear() + self.obs_qpos.clear() + self.stats = stats + self.norm_type = norm_type + + def update_obs(self, obs): + image_data = ( + torch.from_numpy(obs["head_cam"]).unsqueeze(0).unsqueeze(0).float().cuda() + ) # B 1 C H W 0~1 + obs_qpos = torch.from_numpy(obs["agent_pos"]).unsqueeze(0).float().cuda() # B D + qpos_data = normalize_data( + obs_qpos, self.stats, "gaussian", data_type="qpos" + ) # qpos mean std + + if len(self.obs_image) == 0: + for _ in range(self.history_steps + 1): + self.obs_image.append(image_data) # B T N C H W + self.obs_qpos.append(qpos_data) + else: + self.obs_image.append(image_data) + self.obs_qpos.append(qpos_data) + + def get_action(self): + stacked_obs_image = torch.stack( + list(self.obs_image), dim=1 + ) # 1 n+1 1 3 H W raw + stacked_obs_qpos = torch.stack(list(self.obs_qpos), dim=1) # 1 n+1 14 + a_hat = ( + self(stacked_obs_qpos, stacked_obs_image).detach().cpu().numpy() + ) # 1 chunksize 14 + if self.norm_type == "minmax": + a_hat = (a_hat + 1) / 2 * ( + self.stats["action_max"] - self.stats["action_min"] + ) + self.stats["action_min"] + elif self.norm_type == "gaussian": + a_hat = a_hat * self.stats["action_std"] + self.stats["action_mean"] + return a_hat[0] # chunksize 14 + + +class ACTPolicyDiffusion_with_Pixel_Prediction(nn.Module): + def __init__(self, args_override): + super().__init__() + model = build_diffusion_pp_model_and_optimizer(args_override) # TODO + self.model = model # CVAE decoder + self.kl_weight = args_override["kl_weight"] + print(f"KL Weight {self.kl_weight}") + + # memory buffer + self.history_steps = 
args_override["history_step"] + self.obs_image = deque(maxlen=self.history_steps + 1) + self.obs_qpos = deque(maxlen=self.history_steps + 1) + # self.obs_depth = deque(maxlen=self.history_steps+1) + # visual tokenization + if "sim" in args_override["task_name"]: + self.aug = RandomShiftsAug(15, 20) # + self.aug = RandomShiftsAug(8, 10) + self.predict_only_last = args_override["predict_only_last"] + self.prediction_weight = args_override["prediction_weight"] + self.imitate_weight = args_override["imitate_weight"] + self.predict_frame = args_override["predict_frame"] + self.temporal_downsample_rate = args_override["temporal_downsample_rate"] + self.resize_rate = args_override["resize_rate"] + self.image_height = args_override["image_height"] + self.image_width = args_override["image_width"] + # T N C H W + self.future_images_shape = ( + args_override["predict_frame"] // args_override["temporal_downsample_rate"], + len(args_override["camera_names"]), + 3, + self.image_height // self.resize_rate, + self.image_width // self.resize_rate, + ) + print("predict_frame", self.predict_frame) + print("prediction_weight", self.prediction_weight) + print("imitate_weight", self.imitate_weight) + # diffusion step + self.num_inference_steps = args_override["num_inference_steps"] + self.num_queries = args_override["num_queries"] + num_train_timesteps = args_override["num_train_timesteps"] + prediction_type = args_override["prediction_type"] + beta_schedule = args_override["beta_schedule"] + noise_scheduler = ( + DDIMScheduler if args_override["schedule_type"] == "DDIM" else DDPMScheduler + ) + noise_scheduler = noise_scheduler( + num_train_timesteps=num_train_timesteps, + beta_schedule=beta_schedule, + prediction_type=prediction_type, + ) + self.noise_scheduler = noise_scheduler + pred_type = self.noise_scheduler.config.prediction_type + self.loss_type = args_override["loss_type"] + self.diffusion_loss_name = pred_type + "_diffusion_loss_" + self.loss_type + + 
print("num_train_timesteps", {args_override["num_train_timesteps"]}) + print("schedule_type", {args_override["schedule_type"]}) + print("beta_schedule", {args_override["beta_schedule"]}) + print("prediction_type", {args_override["prediction_type"]}) + print(f"Loss Type {self.loss_type}") + # discard resnet + self.normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + self.vis_idx = 0 + + def train_model(self, qpos, image, env_state, actions, is_pad, is_image_pad=None): + # qpos: B T' D; T' = history_steps+1 + # image: B T+1 N C H W ,T = history_steps+1+predict_frame 0~1s + # actions: B H D H = chunk_size + # is_pad: B H + # is_image_pad: B predict_frame + env_state = None + bsz = actions.shape[0] + + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (bsz,), + device=actions.device, + ) + current_image_norm = self.normalize(image[:, 0:1]) + future_images = ( + image[:, 1:, :, :, :: self.resize_rate, :: self.resize_rate] * 2 - 1 + ) # B T N C H W scale to [-1,1] + # diffusion process + actions_noise = torch.randn_like(actions).to(actions.device) + pixel_noise = torch.randn_like(future_images).to(future_images.device) + noisy_actions = self.noise_scheduler.add_noise( + actions, actions_noise, timesteps + ) + noise_pixel = self.noise_scheduler.add_noise( + future_images, pixel_noise, timesteps + ) # future image token + # predict clean data + a_hat, is_pad_hat, pred_images, (mu, logvar) = self.model( + qpos, + (current_image_norm, None), + env_state, + actions, + is_pad, + noisy_actions, + noise_pixel, + is_image_pad, + denoise_steps=timesteps, + ) + + # prediction type + pred_type = self.noise_scheduler.config.prediction_type + + if pred_type == "epsilon": + target_action = actions_noise + target_images = pixel_noise + elif pred_type == "sample": + target_action = actions + target_images = future_images + else: + raise ValueError(f"Unsupported prediction type {pred_type}") + + # calculate 
diffusion loss + loss_dict = {} + if self.loss_type == "l2": + loss = F.mse_loss(a_hat, target_action, reduction="none") + elif self.loss_type == "l1": + loss = F.l1_loss(a_hat, target_action, reduction="none") + diffusion_loss = (loss * ~is_pad.unsqueeze(-1)).mean() + loss_dict[self.diffusion_loss_name] = diffusion_loss + diffusion_l2 = F.mse_loss(a_hat, target_action, reduction="none") + diffusion_l2 = (diffusion_l2 * ~is_pad.unsqueeze(-1)).mean().detach() + loss_dict["diffusion_l2"] = diffusion_l2 + + pixel_loss = F.mse_loss( + pred_images, target_images, reduction="none" + ) # B T N C H W + is_image_pad = ( + is_image_pad.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + ) # B T N C H W + pixel_loss = (pixel_loss * ~is_image_pad).mean() + loss_dict["loss_prediction_pixel"] = pixel_loss + + if mu is not None and logvar is not None: + total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) + loss_dict["kl"] = total_kld[0] + loss_dict["loss"] = ( + loss_dict[self.diffusion_loss_name] * self.imitate_weight + + loss_dict["kl"] * self.kl_weight + + loss_dict["loss_prediction_pixel"] * self.prediction_weight + ) + else: + loss_dict["loss"] = ( + loss_dict[self.diffusion_loss_name] * self.imitate_weight + + loss_dict["loss_prediction_pixel"] * self.prediction_weight + ) + + return loss_dict, (image[:, 0:1], target_images, pred_images) + + def conditional_sample(self, qpos, image): + # qpos: B 1 D + # image: B 1 N C H W + if len(image.shape) == 5: + qpos = qpos.unsqueeze(1) + image = image.unsqueeze(1) + env_state = None + model = self.model + scheduler = self.noise_scheduler + scheduler.set_timesteps(self.num_inference_steps) + # process image observation + current_image_norm = self.normalize(image[:, 0:1]) # B 1 N C H W + batch = image.shape[0] + action_shape = (batch, self.num_queries, 14) + future_images_shape = (batch, *self.future_images_shape) # B T N 3 H' W' + actions = torch.randn(action_shape, device=qpos.device, dtype=qpos.dtype) + pixels = 
torch.randn(future_images_shape, device=qpos.device, dtype=qpos.dtype) + # denoise + for t in scheduler.timesteps: + timesteps = torch.full((batch,), t, device=qpos.device, dtype=torch.long) + model_action_output, is_pad_hat, model_pixel_output, (mu, logvar) = model( + qpos, + (current_image_norm, None), + env_state, + None, + None, + actions, + pixels, + None, + denoise_steps=timesteps, + ) + actions = scheduler.step(model_action_output, t, actions).prev_sample + pixels = scheduler.step(model_pixel_output, t, pixels).prev_sample + return actions, pixels, mu, logvar + + def __call__( + self, + qpos, + image, + actions=None, + is_pad=None, + future_imgs=None, + is_pad_img=None, + is_training=True, + ): + env_state = None + if actions is not None: # training time + all_image = torch.cat( + [image, future_imgs], dim=1 + ) # B T N C H W same resize maybe just use image_sample_size + all_image = self.aug(all_image) if is_training else all_image + loss_dict, (current_image, target_images, pred_images) = self.train_model( + qpos, all_image, env_state, actions, is_pad, is_pad_img + ) # B H + return loss_dict + else: + qpos = qpos # B 1 D + image = image # B 1 N C H W 0~1 + a_hat, pred_images, _, _ = self.conditional_sample(qpos, image) + return a_hat # B H D + + +# discard +class ACTPolicy_NextFrame(nn.Module): + def __init__(self, args_override): + super().__init__() + model, optimizer = build_ACT_NF_model_and_optimizer(args_override) + self.model = model # CVAE decoder + self.optimizer = optimizer + self.kl_weight = args_override["kl_weight"] + self.nextframe_weight = 1 + print(f"KL Weight {self.kl_weight}") + + def __call__(self, qpos, image, actions=None, is_pad=None): + env_state = None + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + if actions is not None: # training time + curr_image = image[:, 0:1] + next_image = image[:, 1:] + curr_image = normalize(curr_image) + image = torch.cat( + [curr_image, next_image], dim=1 
+ ) # B T C H W normalize currernt image not funture + + actions = actions[:, : self.model.num_queries] + is_pad = is_pad[:, : self.model.num_queries] + + a_hat, is_pad_hat, (mu, logvar), (obs_preds, obs_targets) = self.model( + qpos, image, env_state, actions, is_pad + ) + total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) + loss_dict = dict() + all_l1 = F.l1_loss(actions, a_hat, reduction="none") + l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean() + obs_loss = ((obs_preds.sigmoid() - obs_targets) ** 2).mean() + loss_dict["l1"] = l1 + loss_dict["kl"] = total_kld[0] + loss_dict["next_frame"] = obs_loss + loss_dict["loss"] = ( + loss_dict["l1"] + + loss_dict["kl"] * self.kl_weight + + loss_dict["next_frame"] * self.nextframe_weight + ) + return loss_dict + else: # inference time + image = normalize(image) + a_hat, _, (_, _), (obs_preds, obs_targets) = self.model( + qpos, image, env_state + ) # no action, sample from prior + + # next frame prediction + bs = a_hat.shape[0] + patch_size = 16 + # image_size = 224 + img_h, img_w = self.model.img_h, self.model.img_w + ph, pw = img_h // patch_size, img_w // patch_size + nf_pred = obs_preds.sigmoid().reshape( + shape=(bs, ph, pw, patch_size, patch_size, 3) + ) + nf_pred = nf_pred.permute(0, 5, 1, 3, 2, 4) + nf_pred = nf_pred.reshape(shape=(bs, 3, img_h, img_w)) + + # import matplotlib.pyplot as plt + # plt.imshow(nf_pred[0].cpu().numpy().transpose(1,2,0)) + # plt.savefig('tmp.png') + # plt.clf() + # plt.close() + # import pdb; pdb.set_trace() + # self.next_frames.append({'next_frame':nf_pred[0].cpu().numpy().transpose(1,2,0)}) + self.next_frame = nf_pred[0].cpu().numpy().transpose(1, 2, 0) + return a_hat + + def configure_optimizers(self): + return self.optimizer + + +class CNNMLPPolicy(nn.Module): + def __init__(self, args_override): + super().__init__() + model, optimizer = build_CNNMLP_model_and_optimizer(args_override) + self.model = model # decoder + self.optimizer = optimizer + + def __call__(self, qpos, image, 
actions=None, is_pad=None): + env_state = None # TODO + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + image = normalize(image) + if actions is not None: # training time + actions = actions[:, 0] + a_hat = self.model(qpos, image, env_state, actions) + mse = F.mse_loss(actions, a_hat) + loss_dict = dict() + loss_dict["mse"] = mse + loss_dict["loss"] = loss_dict["mse"] + return loss_dict + else: # inference time + a_hat = self.model(qpos, image, env_state) # no action, sample from prior + return a_hat + + def configure_optimizers(self): + return self.optimizer + + +if __name__ == "__main__": + + + import json + config_path = '/attached/remote-home2/xhl/8_kaust_pj/RoboTwin/policy/ACT-DP-TP-Policy/checkpoints/put_bottles_dustbin/single_20_2_300_600/act_dp/policy_config.json' + config = json.load(open(config_path, 'r')) + + # policy = ACT_Flow_Matching(config) + + # qpos = torch.randn(1,3, 14).to('cuda') + # image = torch.rand(1, 3,1, 3, 480, 640).to('cuda') + # actions = torch.randn(1, 20, 14).to('cuda') + # is_pad = torch.zeros(1, 20).bool().to('cuda') + # loss_dict = policy(qpos, image) + # print(loss_dict) + + config['condition_type'] = "adaLN" + policy = ACTDiffusionPolicy(config) + + qpos = torch.randn(1,3, 14).to('cuda') + image = torch.rand(1, 3,1, 3, 480, 640).to('cuda') + actions = torch.randn(1, 20, 14).to('cuda') + is_pad = torch.zeros(1, 20).bool().to('cuda') + loss_dict = policy(qpos, image) + print(loss_dict) + + + # from cosmos_tokenizer.networks import TokenizerConfigs + # from cosmos_tokenizer.utils import ( + # get_filepaths, + # get_output_filepath, + # read_video, + # resize_video, + # write_video, + # ) + # # from cosmos_tokenizer.video_lib import CausalVideoTokenizer + # from cosmos_tokenizer.image_lib import ImageTokenizer + + # tokenizer_type = 'DV' + # spatial_compression = 16 + # temporal_compression = 8 + # mode = 'torch' + # temporal_window = 17 + # dtype = 'bfloat16' + # device = 'cuda' + # 
model_name = "Cosmos-Tokenizer-DV8x16x16" + # checkpoint_enc = f'Cosmos-Tokenizer/pretrained_ckpts/{model_name}/encoder.jit' + # checkpoint_dec = f'Cosmos-Tokenizer/pretrained_ckpts/{model_name}/decoder.jit' + # # load model + # tokenizer_config = TokenizerConfigs[tokenizer_type].value + # tokenizer_config.update(dict(spatial_compression=spatial_compression)) + # tokenizer_config.update(dict(temporal_compression=temporal_compression)) + # autoencoder = CausalVideoTokenizer( + # checkpoint= None, + # checkpoint_enc=checkpoint_enc, + # checkpoint_dec=checkpoint_dec, + # tokenizer_config=tokenizer_config, + # device=device, + # dtype=dtype, + # ) + # # T C HW + # input_tensor = read_video('Cosmos-Tokenizer/test_robot_data/sim_transfer_cube_scripted/episode_0.mp4') + # print(input_tensor.dtype) # uint8 255 + + # batch_video = np.array(input_tensor[49:98:3,::2,::2])[np.newaxis, ...] # B T H W C + # output_video = autoencoder(batch_video, temporal_window=temporal_window)[0] # T H W C np.unit8 255 + + # autoencoder.get_latent_codes(batch_video)# B C T' H' W' # B 6 T' H' W' + + # encoder = CausalVideoTokenizer(checkpoint_enc=checkpoint_enc) + # decoder = CausalVideoTokenizer(checkpoint_dec=checkpoint_dec) + # batch_video_tensor = torch.from_numpy(batch_video).cuda().permute(0,4,1,2,3) / 127.5 - 1 # B C T H W -1,1 + # print(batch_video_tensor.shape) + # output_index, output_code = encoder._enc_model(batch_video_tensor)[:-1] # Input B C T H W -1,1 + # print(output_index.shape, output_code.shape) + # h = decoder._dec_model.post_quant_conv(output_code) # input shape B 6 T' H' W', output shape B 16 T' H' W' + # reconstructed_tensor = decoder._dec_model.decoder(h).detach() # B C T H W -1,1 + # rec_video = tensor2numpy(reconstructed_tensor)[0] # T H W C + + # print('diff', (rec_video - output_video).mean()) + # print('rec diff', (rec_video - batch_video).mean()) + # print('out diff', (output_video - batch_video).mean()) + + # vis_path = 'vis_0.mp4' + # vis_video = 
np.concatenate([batch_video[0], output_video, rec_video], axis=2) # T H W*2 C + # media.write_video(vis_path, vis_video, fps=3) + + # model_name = "Cosmos-Tokenizer-DI16x16" + # encoder = ImageTokenizer(checkpoint_enc=f'Cosmos-Tokenizer/pretrained_ckpts/{model_name}/encoder.jit') + # decoder = ImageTokenizer(checkpoint_dec=f'Cosmos-Tokenizer/pretrained_ckpts/{model_name}/decoder.jit') + # input_tensor = torch.randn(32, 3, 480, 640).to('cuda').to(torch.bfloat16) # [B, C, T, H, W] + # (indices, codes) = encoder.encode(input_tensor) + # print(codes.shape) + # print(codes.max(), codes.min()) + # reconstructed_tensor = decoder.decode(indices) # input index + # print(reconstructed_tensor.shape) + # model_name = "Cosmos-Tokenizer-CV4x8x8" + # encoder = CausalVideoTokenizer( + # checkpoint_enc=f"Cosmos-Tokenizer/pretrained_ckpts/{model_name}/encoder.jit" + # ) + # decoder = CausalVideoTokenizer( + # checkpoint_dec=f"Cosmos-Tokenizer/pretrained_ckpts/{model_name}/decoder.jit" + # ) + # # B N C T H W + # input_tensor = ( + # -torch.ones(16, 5, 3, 21, 480, 640).to("cuda").to(torch.bfloat16) + # ) # [B, C, T, H, W] + # for view_idx in range(5): + # (codes,) = encoder._enc_model(input_tensor[:, view_idx])[:-1] # B 16 T' H' W' + # codes = codes.detach() + # print(codes.shape) + # print(codes.max(), codes.min()) diff --git a/ACT_DP_multitask/requirements.txt b/ACT_DP_multitask/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..bd53ea17411770a66ab8affd00ece502cde36d3c --- /dev/null +++ b/ACT_DP_multitask/requirements.txt @@ -0,0 +1,160 @@ +absl-py==2.3.0 +ackermann-msgs==2.0.2 +action-msgs==1.2.1 +action-tutorials-interfaces==0.20.5 +actionlib-msgs==4.8.0 +aiohappyeyeballs==2.6.1 +ament-cmake-test==1.3.11 +ament-package==0.14.0 +angles==1.15.0 +annotated-types==0.7.0 +argcomplete==3.6.2 +asttokens==3.0.0 +async-timeout==5.0.1 +attrs==25.3.0 +boto==2.49.0 +builtin-interfaces==1.2.1 +cachetools==5.5.2 +catkin-pkg +certifi==2022.12.7 
+charset-normalizer==2.1.1 +click==8.2.1 +composition-interfaces==1.2.1 +control-msgs==4.8.0 +controller-manager==2.50.0 +controller-manager-msgs==2.50.0 +crcmod==1.7 +cv-bridge==3.2.1 +cycler==0.12.1 +decorator==5.2.1 +diagnostic-msgs==4.8.0 +docstring_parser==0.16 +einops==0.8.1 +etils==1.12.2 +example-interfaces==0.9.3 +executing==2.2.0 +fasteners==0.19 +filelock==3.13.1 +flatbuffers==25.2.10 +fonttools==4.58.1 +frozenlist==1.6.0 +fsspec==2024.6.1 +gast==0.6.0 +geometry-msgs==4.8.0 +grpcio==1.72.1 +hf-xet==1.1.2 +hjson==3.1.0 +idna==3.4 +image-geometry==3.2.1 +immutabledict==4.2.1 +importlib_resources==6.5.2 +interactive-markers==2.3.2 +keras==2.15.0 +kiwisolver==1.4.8 +laser-geometry==2.4.0 +libclang==18.1.1 +lifecycle-msgs==1.2.1 +logging-demo==0.20.5 +map-msgs==2.1.0 +Markdown==3.8 +MarkupSafe==3.0.2 +message-filters==4.3.7 +monotonic==1.6 +mpmath==1.3.0 +msgpack==1.1.0 +nav-msgs==4.8.0 +networkx==3.3 +ninja==1.11.1.4 +numpy==1.26.4 +nvidia-ml-py==12.575.51 +oauthlib==3.2.2 +opt_einsum==3.4.0 +osrf-pycommon==2.1.6 +packaging==25.0 +parso==0.8.4 +pcl-msgs==1.0.0 +pendulum-msgs==0.20.5 +pillow==10.2.0 +piper_msgs==0.0.0 +platformdirs==4.3.8 +propcache==0.3.1 +protobuf==4.21.12 +psutil==7.0.0 +ptyprocess==0.7.0 +pure_eval==0.2.3 +py-cpuinfo==9.0.0 +pyarrow==20.0.0 +pyasn1==0.6.1 +pycparser==2.22 +Pygments==2.19.1 +pyparsing==3.2.3 +pyrealsense2==2.55.1.6486 +python-qt-binding==1.1.2 +pytz==2025.2 +PyYAML==6.0.1 +qt-dotgraph==2.2.4 +qt-gui==2.2.4 +qt-gui-cpp==2.2.4 +qt-gui-py-common==2.2.4 +rcl-interfaces==1.2.1 +rclpy==3.3.16 +rcutils==5.1.6 +regex==2024.11.6 +resource-retriever==3.1.3 +retry-decorator==1.1.1 +rmw-dds-common==1.6.0 +ros2cli==0.18.12 +rosbag2-interfaces==0.15.14 +rosbag2-py==0.15.14 +rosgraph-msgs==1.2.1 +rosidl-adapter==3.1.6 +rosidl-cli==3.1.6 +rosidl-cmake==3.1.6 +rosidl-generator-c==3.1.6 +rosidl-generator-cpp==3.1.6 +rosidl-generator-py==0.14.4 +rosidl-parser==3.1.6 +rosidl-runtime-py==0.9.3 +rosidl-typesupport-c==2.0.2 
class T5Embedder:
    """Thin wrapper around a HuggingFace T5 text encoder.

    Loads an ``AutoTokenizer`` and a ``T5EncoderModel`` from
    ``from_pretrained`` and exposes :meth:`get_text_embeddings` to turn a
    batch of strings into last-hidden-state embeddings.

    The hard-coded 24-block device map assumes a T5 encoder with 24 blocks
    (e.g. ``google/t5-v1_1-xxl``) — TODO confirm for other checkpoints.
    """
    # available_models = ["google/t5-v1_1-xxl"]

    def __init__(
        self,
        device,
        from_pretrained=None,
        *,
        cache_dir=None,
        hf_token=None,
        use_text_preprocessing=True,
        t5_model_kwargs=None,
        torch_dtype=None,
        use_offload_folder=None,
        model_max_length=120,
        local_files_only=False,
    ):
        """Load tokenizer and encoder onto ``device``.

        When ``use_offload_folder`` is given, only the embeddings and the
        first 12 encoder blocks are kept on ``device``; blocks 12-23, the
        final layer norm and dropout are offloaded to disk.
        """
        self.device = torch.device(device)
        self.torch_dtype = torch_dtype or torch.bfloat16
        self.cache_dir = cache_dir

        if t5_model_kwargs is None:
            t5_model_kwargs = {
                "low_cpu_mem_usage": True,
                "torch_dtype": self.torch_dtype,
            }

        if use_offload_folder is not None:
            t5_model_kwargs["offload_folder"] = use_offload_folder
            # Same mapping as an explicit literal, built programmatically:
            # blocks 0-11 stay on `device`, blocks 12-23 go to disk.
            resident = {f"encoder.block.{i}": self.device for i in range(12)}
            offloaded = {f"encoder.block.{i}": "disk" for i in range(12, 24)}
            t5_model_kwargs["device_map"] = {
                "shared": self.device,
                "encoder.embed_tokens": self.device,
                **resident,
                **offloaded,
                "encoder.final_layer_norm": "disk",
                "encoder.dropout": "disk",
            }
        else:
            t5_model_kwargs["device_map"] = {
                "shared": self.device,
                "encoder": self.device,
            }

        self.use_text_preprocessing = use_text_preprocessing
        self.hf_token = hf_token

        # assert from_pretrained in self.available_models
        self.tokenizer = AutoTokenizer.from_pretrained(
            from_pretrained,
            model_max_length=model_max_length,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
        self.model = T5EncoderModel.from_pretrained(
            from_pretrained,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            **t5_model_kwargs,
        ).eval()
        self.model_max_length = model_max_length

    def get_text_embeddings(self, texts):
        """Tokenize ``texts`` and return ``(last_hidden_state, attention_mask)``.

        Sequences are padded to the longest item and truncated to
        ``model_max_length``; the forward pass runs under ``torch.no_grad``.
        """
        batch = self.tokenizer(
            texts,
            max_length=self.model_max_length,
            padding="longest",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        with torch.no_grad():
            embeddings = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )["last_hidden_state"].detach()
        return embeddings, attention_mask
def kl_divergence(mu, logvar):
    """KL divergence of N(mu, exp(logvar)) from the standard normal.

    Returns a tuple ``(total_kld, dimension_wise_kld, mean_kld)``:
    sum over dims averaged over the batch, per-dimension batch mean,
    and per-sample mean averaged over the batch.
    """
    n = mu.size(0)
    assert n != 0
    # Flatten trailing singleton spatial dims if inputs come in as B x D x 1 x 1.
    if mu.data.ndimension() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.data.ndimension() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    return (
        klds.sum(1).mean(0, True),
        klds.mean(0),
        klds.mean(1).mean(0, True),
    )


class RandomShiftsAug(nn.Module):
    """Random-shift image augmentation via replicate padding + grid_sample.

    Each sample in the batch is translated by an integer offset drawn
    uniformly from [0, 2*pad] per axis; all leading dims beyond the batch
    (e.g. time, views, channels) share the same shift.
    """

    def __init__(self, pad_h, pad_w):
        super().__init__()
        self.pad_h = pad_h
        self.pad_w = pad_w
        print(f"RandomShiftsAug: pad_h {pad_h}, pad_w {pad_w}")

    def forward(self, x):
        in_shape = x.shape
        n, h, w = x.shape[0], x.shape[-2], x.shape[-1]  # n,T,M,C,H,W
        flat = x.view(n, -1, h, w)  # collapse T*M*C into one channel axis
        # left, right, top, bottom replicate padding
        flat = F.pad(flat, (self.pad_w, self.pad_w, self.pad_h, self.pad_h), mode="replicate")

        h_pad = h + 2 * self.pad_h
        w_pad = w + 2 * self.pad_w
        eps_h = 1.0 / h_pad
        eps_w = 1.0 / w_pad

        # Normalized sampling coordinates of the un-shifted crop.
        ys = torch.linspace(
            -1.0 + eps_h, 1.0 - eps_h, h_pad, device=x.device, dtype=x.dtype
        )[:h]
        xs = torch.linspace(
            -1.0 + eps_w, 1.0 - eps_w, w_pad, device=x.device, dtype=x.dtype
        )[:w]
        ys = ys.unsqueeze(1).repeat(1, w).unsqueeze(2)  # (h, w, 1)
        xs = xs.unsqueeze(1).repeat(1, h).unsqueeze(2)  # (w, h, 1)

        base_grid = torch.cat([xs.transpose(1, 0), ys], dim=2)  # (H, W, 2)
        base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)  # (B, H, W, 2)

        # One integer pixel shift per sample, converted to grid units.
        dy = torch.randint(
            0, 2 * self.pad_h + 1, size=(n, 1, 1, 1), device=x.device, dtype=x.dtype
        ).float()
        dx = torch.randint(
            0, 2 * self.pad_w + 1, size=(n, 1, 1, 1), device=x.device, dtype=x.dtype
        ).float()
        dy *= 2.0 / h_pad
        dx *= 2.0 / w_pad

        grid = base_grid + torch.cat([dx, dy], dim=3)
        shifted = F.grid_sample(flat, grid, padding_mode="zeros", align_corners=False)
        return shifted.view(in_shape)


def get_norm_stats(state, action):
    """Per-dimension normalization statistics over all steps.

    Returns a dict of squeezed numpy arrays: action mean/std/max/min and
    qpos mean/std, with stds clipped to at least 1e-2 to avoid division
    blow-ups downstream.
    """
    qpos_t = torch.from_numpy(np.array(state))
    act_t = torch.from_numpy(np.array(action))

    stats = {
        "action_mean": act_t.mean(dim=[0], keepdim=True),
        "action_std": torch.clip(act_t.std(dim=[0], keepdim=True), 1e-2, np.inf),
        "action_max": torch.amax(act_t, dim=[0], keepdim=True),
        "action_min": torch.amin(act_t, dim=[0], keepdim=True),
        "qpos_mean": qpos_t.mean(dim=[0], keepdim=True),
        "qpos_std": torch.clip(qpos_t.std(dim=[0], keepdim=True), 1e-2, np.inf),
    }
    return {k: v.numpy().squeeze() for k, v in stats.items()}
path_index): + # qpos = np.concatenate((root['/arm/jointStatePosition/masterLeft'][()], root['/arm/jointStatePosition/masterRight'][()]), axis=-1) + # actions = np.concatenate((root['/arm/jointStatePosition/puppetLeft'] + path_index = path_index % len(self.data_path_list) # ensure index is within bounds + example_path = self.data_path_list[path_index] + with h5py.File(example_path, 'r') as f: + action = f['observations']['qpos'][()] # jointStatePosition/master + qpos = f['action'][()] # jointStatePosition/puppet + + parent_path = os.path.dirname(example_path) + Instruction_path = os.path.join(parent_path, 'instructions') + # randomly sample instruction file + instruction_files = [f for f in os.listdir(Instruction_path) if fnmatch.fnmatch(f, '*.pt')] + instruction_file = os.path.join(Instruction_path, np.random.choice(instruction_files)) + instruction = torch.load(instruction_file, weights_only=False) # num_token * 4096 tensor + # randomly sample an episode inex + episode_len = action.shape[0] + index = np.random.randint(0, episode_len) + obs_qpos = qpos[index:index + 1] + + # stack images + with h5py.File(example_path, 'r') as f: + camera_list = [] + for camera_name in self.camera_names: + cam_jpeg_code = f['observations']['images'][camera_name][index] + cam_image = cv2.imdecode(np.frombuffer(cam_jpeg_code, np.uint8), cv2.IMREAD_COLOR) # rgb + camera_list.append(cam_image) + obs_img = np.stack(camera_list, axis=0) # shape: (N_views, H, W, C) + original_action_shape = (self.chunk_size, *action.shape[1:]) + gt_action = np.zeros(original_action_shape) + action_len = min(self.chunk_size, episode_len - index) + gt_action[:action_len] = action[ + index : index + action_len + ] + is_pad = np.zeros(self.chunk_size) + is_pad[action_len:] = 1 + + # construct observations tensor type + image_data = torch.from_numpy(obs_img).unsqueeze(0).float() # (history_steps+1, 1, H, W, 3) add num_view + image_data = image_data.permute(0, 1, 4, 2, 3) # (1, N_views, 3, H, W) + qpos_data = 
def load_data_unified(
    data_dir='/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/RDT/training_data/rdt_real_multitask',
    camera_names=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
    batch_size_train=32,
    chunk_size=100,
    img_aug=False,
    fintune=False,
):
    """Build a training DataLoader over every .hdf5 episode under ``data_dir``.

    Collects normalization statistics from all episodes (or loads the
    pretraining stats when ``fintune`` is set) and returns
    ``(train_loader, None, None, stats)``.
    """
    HDF5_file_path = [
        os.path.join(root, filename)
        for root, _, files in os.walk(data_dir, followlinks=True)
        for filename in files
        if filename.endswith('.hdf5')
    ]
    print(f"Loading data from {data_dir} with {len(HDF5_file_path)} episodes and batch size {batch_size_train}")

    # Gather every step of every episode for dataset-wide statistics.
    # NOTE: qpos/action sources mirror the master/puppet swap in the dataset.
    state_list, action_list = [], []
    for p in tqdm(HDF5_file_path, desc="Data statics collection"):
        with h5py.File(p, 'r') as f:
            action_list.append(f['observations']['qpos'][()])
            state_list.append(f['action'][()])
    states = np.concatenate(state_list, axis=0)
    actions = np.concatenate(action_list, axis=0)

    if fintune:
        # Reuse the statistics of the pretraining run (1590 episodes).
        pretrain_stats_path = '/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/ACT_DP_multitask/checkpoints/real_pretrain_50_2000/act_dp/dataset_stats.pkl'
        with open(pretrain_stats_path, 'rb') as f:
            stats = pickle.load(f)
        print(f"Loaded stats from {pretrain_stats_path}")
    else:
        stats = get_norm_stats(states, actions)

    for key, value in stats.items():
        print(f"{key}: {value}")

    train_dataset = EpisodicDataset_Unified_Multiview(
        data_path_list=HDF5_file_path,
        camera_names=camera_names,
        chunk_size=chunk_size,
        stats=stats,
        img_aug=img_aug,
    )
    traind_data_loader = DataLoader(
        train_dataset,
        batch_size=batch_size_train,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )

    return traind_data_loader, None, None, stats


def compute_dict_mean(epoch_dicts):
    """Average each key across a list of dicts (values must support + and /)."""
    count = len(epoch_dicts)
    return {key: sum(d[key] for d in epoch_dicts) / count for key in epoch_dicts[0]}


def detach_dict(d):
    """Return a copy of ``d`` with every value detached from the autograd graph."""
    return {key: value.detach() for key, value in d.items()}


import random


def set_seed(seed):
    """Seed python, numpy and torch (CPU and all CUDA devices) RNGs."""
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)


def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
):
    """Linear warmup from 0 to the optimizer's lr, then cosine decay to 0.

    Args:
        optimizer: the torch optimizer whose lr is scheduled.
        num_warmup_steps (int): steps of linear warmup.
        num_training_steps (int): total training steps.
        num_cycles (float, optional): number of cosine waves (0.5 = a single
            half-cosine from max down to 0).
        last_epoch (int, optional): last step index when resuming.

    Returns:
        torch.optim.lr_scheduler.LambdaLR implementing the schedule.
    """

    def lr_lambda(current_step):
        # Warmup phase: scale linearly with the step count.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Decay phase: cosine over the remaining fraction of training.
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        cosine = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
        return max(0.0, cosine)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
+ """ + return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) + + +def normalize_data(action_data, stats, norm_type, data_type="action"): + + if norm_type == "minmax": + action_max = torch.from_numpy(stats[data_type + "_max"]).float().to(action_data.device) + action_min = torch.from_numpy(stats[data_type + "_min"]).float().to(action_data.device) + action_data = (action_data - action_min) / (action_max - action_min) * 2 - 1 + elif norm_type == "gaussian": + action_mean = torch.from_numpy(stats[data_type + "_mean"]).float().to(action_data.device) + action_std = torch.from_numpy(stats[data_type + "_std"]).float().to(action_data.device) + action_data = (action_data - action_mean) / action_std + return action_data + + +def convert_weight(obj): + newmodel = OrderedDict() + for k, v in obj.items(): + if k.startswith("module."): + newmodel[k[7:]] = v + else: + newmodel[k] = v + return newmodel + + +if __name__ == "__main__": + train_dataloader,_,_,stats = load_data_unified( + data_dir='/home/algo/anyrobot/Anyrobot_RoboTwin_Challenge/policy/RDT/training_data/rdt_real_multitask', + camera_names=['cam_high', 'cam_left_wrist', 'cam_right_wrist'], + batch_size_train=32, + chunk_size=100, + img_aug=True, + ) + + for i, (image_data, qpos_data, action_data, is_pad, instruction_data) in enumerate( + tqdm(train_dataloader, desc="Data loading") + ): + if i == 0: + print(f"Batch {i}:") + print(f"Image data shape: {image_data.shape} {image_data.max()}") + print(f"Qpos data shape: {qpos_data.shape} {qpos_data.max()}" ) + print(f"Action data shape: {action_data.shape} {action_data.max()}") + print(f"Is pad shape: {is_pad.shape}") + print(f"Instruction data shape: {instruction_data.shape}") + + continue + \ No newline at end of file