haiphamcse committed on
Commit
9855f47
·
verified ·
1 Parent(s): 3aee1e1

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +14 -0
  2. perception_models/.gitignore +10 -0
  3. perception_models/CODE_OF_CONDUCT.md +80 -0
  4. perception_models/CONTRIBUTING.md +31 -0
  5. perception_models/LEGRAD_PE_USAGE.md +72 -0
  6. perception_models/LICENSE.PE +201 -0
  7. perception_models/LICENSE.PLM +124 -0
  8. perception_models/README.md +408 -0
  9. perception_models/__pycache__/legrad_pe_audio.cpython-310.pyc +0 -0
  10. perception_models/__pycache__/legrad_pe_audio.cpython-313.pyc +0 -0
  11. perception_models/__pycache__/legrad_pe_image.cpython-312.pyc +0 -0
  12. perception_models/__pycache__/legrad_pe_image.cpython-313.pyc +0 -0
  13. perception_models/apps/detection/DETA_pe/README.md +53 -0
  14. perception_models/apps/detection/DETA_pe/datasets/__init__.py +37 -0
  15. perception_models/apps/detection/DETA_pe/datasets/coco.py +345 -0
  16. perception_models/apps/detection/DETA_pe/datasets/coco_eval.py +265 -0
  17. perception_models/apps/detection/DETA_pe/datasets/coco_panoptic.py +107 -0
  18. perception_models/apps/detection/DETA_pe/datasets/data_prefetcher.py +70 -0
  19. perception_models/apps/detection/DETA_pe/datasets/objects365.py +54 -0
  20. perception_models/apps/detection/DETA_pe/datasets/panoptic_eval.py +52 -0
  21. perception_models/apps/detection/DETA_pe/datasets/samplers.py +348 -0
  22. perception_models/apps/detection/DETA_pe/datasets/torchvision_datasets/__init__.py +7 -0
  23. perception_models/apps/detection/DETA_pe/datasets/torchvision_datasets/coco.py +84 -0
  24. perception_models/apps/detection/DETA_pe/datasets/transforms.py +327 -0
  25. perception_models/apps/detection/DETA_pe/engine.py +303 -0
  26. perception_models/apps/detection/DETA_pe/engine_tta.py +239 -0
  27. perception_models/apps/detection/DETA_pe/main.py +754 -0
  28. perception_models/apps/detection/DETA_pe/models/__init__.py +15 -0
  29. perception_models/apps/detection/DETA_pe/models/assigner.py +378 -0
  30. perception_models/apps/detection/DETA_pe/models/backbone.py +235 -0
  31. perception_models/apps/detection/DETA_pe/models/deformable_detr.py +776 -0
  32. perception_models/apps/detection/DETA_pe/models/deformable_transformer.py +451 -0
  33. perception_models/apps/detection/DETA_pe/models/matcher.py +102 -0
  34. perception_models/apps/detection/DETA_pe/models/ops/functions/__init__.py +9 -0
  35. perception_models/apps/detection/DETA_pe/models/ops/functions/ms_deform_attn_func.py +106 -0
  36. perception_models/apps/detection/DETA_pe/models/ops/make.sh +10 -0
  37. perception_models/apps/detection/DETA_pe/models/ops/modules/__init__.py +9 -0
  38. perception_models/apps/detection/DETA_pe/models/ops/modules/ms_deform_attn.py +161 -0
  39. perception_models/apps/detection/DETA_pe/models/ops/setup.py +71 -0
  40. perception_models/apps/detection/DETA_pe/models/ops/src/cpu/ms_deform_attn_cpu.cpp +41 -0
  41. perception_models/apps/detection/DETA_pe/models/ops/src/cpu/ms_deform_attn_cpu.h +33 -0
  42. perception_models/apps/detection/DETA_pe/models/ops/src/cuda/ms_deform_attn_cuda.cu +153 -0
  43. perception_models/apps/detection/DETA_pe/models/ops/src/cuda/ms_deform_attn_cuda.h +30 -0
  44. perception_models/apps/detection/DETA_pe/models/ops/src/cuda/ms_deform_im2col_cuda.cuh +1327 -0
  45. perception_models/apps/detection/DETA_pe/models/ops/src/ms_deform_attn.h +62 -0
  46. perception_models/apps/detection/DETA_pe/models/ops/src/vision.cpp +16 -0
  47. perception_models/apps/detection/DETA_pe/models/ops/test.py +89 -0
  48. perception_models/apps/detection/DETA_pe/models/pev1.py +686 -0
  49. perception_models/apps/detection/DETA_pe/models/position_encoding.py +97 -0
  50. perception_models/apps/detection/DETA_pe/models/segmentation.py +369 -0
.gitattributes CHANGED
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ perception_models/apps/pe/docs/assets/dog.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ perception_models/apps/pe/docs/assets/dog.png filter=lfs diff=lfs merge=lfs -text
38
+ perception_models/apps/pe/docs/assets/office.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ perception_models/apps/pe/docs/assets/office.wav filter=lfs diff=lfs merge=lfs -text
40
+ perception_models/apps/pe/docs/assets/pikachu.webp filter=lfs diff=lfs merge=lfs -text
41
+ perception_models/apps/pe/docs/assets/shark.png filter=lfs diff=lfs merge=lfs -text
42
+ perception_models/apps/pe/docs/assets/spatial_correspondence.png filter=lfs diff=lfs merge=lfs -text
43
+ perception_models/apps/pe/docs/assets/spatial_features.png filter=lfs diff=lfs merge=lfs -text
44
+ perception_models/apps/pe/docs/assets/teaser.png filter=lfs diff=lfs merge=lfs -text
45
+ perception_models/apps/pe/docs/assets/train.mp4 filter=lfs diff=lfs merge=lfs -text
46
+ perception_models/apps/pe/docs/assets/train.wav filter=lfs diff=lfs merge=lfs -text
47
+ perception_models/apps/plm/docs/plm_main_fig.png filter=lfs diff=lfs merge=lfs -text
48
+ perception_models/core/tests/Rock-climbing-Canada-1920x1147.jpg filter=lfs diff=lfs merge=lfs -text
49
+ perception_models/core/tests/selfie_cathedral_peak.jpg filter=lfs diff=lfs merge=lfs -text
perception_models/.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ *.pyc
2
+ .vscode
3
+ *.ipynb
4
+ slurm-*.out
5
+ wandb
6
+ data/*
7
+ data-gym-cache/*
8
+ torchinductor_*/*
9
+ tmp*/*
10
+ apps/plm/dummy_datasets
perception_models/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <opensource-conduct@fb.com>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
perception_models/CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to Perception Models
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Facebook's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26
+ disclosure of security bugs. In those cases, please go through the process
27
+ outlined on that page and do not file a public issue.
28
+
29
+ ## License
30
+ By contributing to Perception Models, you agree that your contributions will be licensed
31
+ under the LICENSE file in the root directory of this source tree.
perception_models/LEGRAD_PE_USAGE.md ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LeGrad + PE Perception Encoder Notebook Usage
2
+
3
+ This repository includes a notebook `legrad_perception_encoder.ipynb` that demonstrates how to run **LeGrad** explanations on the PE CoCa-style vision encoder.
4
+
5
+ ## 1. Environment and installation
6
+
7
+ - **Install this repo** (from the repo root):
8
+
9
+ ```bash
10
+ pip install -e .
11
+ ```
12
+
13
+ - **Install LeGrad** (if not already installed):
14
+
15
+ ```bash
16
+ pip install legrad
17
+ ```
18
+
19
+ Make sure you have a working CUDA‑enabled PyTorch environment.
20
+
21
+ ## 2. Open the notebook
22
+
23
+ From the repo root:
24
+
25
+ ```bash
26
+ cd xai/perception_models
27
+ jupyter lab legrad_perception_encoder.ipynb
28
+ ```
29
+
30
+ ## 3. What the notebook does
31
+
32
+ The notebook shows how to:
33
+
34
+ 1. Load a PE CoCa‑style vision encoder:
35
+ - Uses `pe.CLIP.from_config("PE-Core-B16-224", pretrained=True)` and moves the model to CUDA.
36
+ 2. Wrap the model with LeGrad:
37
+ - `LeWrapper` lives in `core/legrad_pe.py`.
38
+ - It hooks PE residual blocks and attention pooling so gradients can be used to build visual explanations.
39
+ 3. Prepare inputs:
40
+ - Build an image transform with `transforms.get_image_transform(model.image_size)`.
41
+ - Tokenize text prompts with `transforms.get_text_tokenizer(model.context_length)`.
42
+ 4. Run LeGrad:
43
+ - **Multi‑layer explanation**:
44
+ - `heatmap = wrapped_model.compute_legrad_coca(text_emb, image=image_tensor)`
45
+ - **Single‑layer explanation**:
46
+ - `heatmap = wrapped_model.compute_legrad_coca_one_layer(text_emb, image=image_tensor, layer_idx=-1)`
47
+ 5. Visualize:
48
+ - Convert the `heatmap` to a NumPy array and use `legrad.visualize` (or standard plotting) to overlay it on the image.
49
+
50
+ ## 4. Minimal code sketch (inside the notebook)
51
+
52
+ The core usage pattern is:
53
+
54
+ ```python
55
+ import core.vision_encoder.pe as pe
56
+ import core.vision_encoder.transforms as transforms
57
+ from core.legrad_pe import LeWrapper
58
+
59
+ model = pe.CLIP.from_config("PE-Core-B16-224", pretrained=True).cuda()
60
+ preprocess = transforms.get_image_transform(model.image_size)
61
+ tokenizer = transforms.get_text_tokenizer(model.context_length)
62
+
63
+ wrapped_model = LeWrapper(model, layer_index=-2)
64
+ ```
65
+
66
+ You can then:
67
+
68
+ - Preprocess an input image with `preprocess`,
69
+ - Tokenize prompts with `tokenizer`,
70
+ - Encode text/image, and
71
+ - Call one of the `compute_legrad_*` methods to obtain a heatmap for visualization.
72
+
perception_models/LICENSE.PE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
perception_models/LICENSE.PLM ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FAIR Noncommercial Research License
2
+ Last Updated: 17 April 2025
3
+
4
+ “Acceptable Use Policy” means the FAIR Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement.
5
+
6
+ “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein.
7
+
8
+
9
+ “Documentation” means the specifications, manuals and documentation accompanying
10
+ Research Materials distributed by Meta.
11
+
12
+
13
+ “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
14
+
15
+
16
+ “Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
17
+
18
+ “Noncommercial Research Uses” means noncommercial research use cases related to research, development, education, processing, or analysis and in each case, is not primarily intended for commercial advantage or monetary compensation to you or others.
19
+
20
+ “Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement.
21
+
22
+ By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement.
23
+
24
+
25
+ 1. License Rights and Redistribution.
26
+
27
+
28
+ a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials.
29
+
30
+ b. Redistribution and Use.
31
+ i. You will not use the Research Materials or any outputs or results of the Research Materials in connection with any commercial uses or for any uses other than Noncommercial Research Uses;
32
+
33
+
34
+ ii. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.
35
+
36
+
37
+ iii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication.
38
+
39
+
40
+ iv. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the FAIR Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.
41
+ 2. User Support. Your Noncommercial Research Use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
42
+
43
+
44
+ 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS.
45
+
46
+ 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
47
+
48
+ 5. Intellectual Property.
49
+
50
+
51
+ a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
52
+
53
+ b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials.
54
+
55
+ 6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 5, 6 and 9 shall survive the termination of this Agreement.
56
+
57
+ 7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
58
+
59
+
60
+ 8. Modifications and Amendments. Meta may modify this Agreement from time to time by posting a revised version at [INSERT URL]; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
61
+
62
+
63
+ FAIR Acceptable Use Policy
64
+
65
+ The Fundamental AI Research (FAIR) team at Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all.
66
+
67
+ As part of this mission, Meta makes certain research materials available for noncommercial research use. Meta is committed to promoting the safe and responsible use of such research materials.
68
+
69
+ Prohibited Uses
70
+
71
+ You agree you will not use, or allow others to use, Research Materials to:
72
+
73
+ Violate the law or others’ rights, including to:
74
+ Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
75
+ Violence or terrorism
76
+ Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
77
+ Human trafficking, exploitation, and sexual violence
78
+ The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
79
+ Sexual solicitation
80
+ Any other criminal activity
81
+
82
+ Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
83
+
84
+ Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
85
+
86
+ Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
87
+
88
+ Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
89
+
90
+ Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using FAIR research materials
91
+
92
+ Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
93
+
94
+ 2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following:
95
+
96
+ Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
97
+
98
+ Guns and illegal weapons (including weapon development)
99
+
100
+ Illegal drugs and regulated/controlled substances
101
+
102
+ Operation of critical infrastructure, transportation technologies, or heavy machinery
103
+
104
+ Self-harm or harm to others, including suicide, cutting, and eating disorders
105
+
106
+ Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
107
+
108
+ 3. Intentionally deceive or mislead others, including use of FAIR Research Materials related to the following:
109
+
110
+ Generating, promoting, or furthering fraud or the creation or promotion of disinformation
111
+
112
+ Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
113
+
114
+ Generating, promoting, or further distributing spam
115
+
116
+ Impersonating another individual without consent, authorization, or legal right
117
+
118
+ Representing that outputs of FAIR research materials or outputs from technology using FAIR research materials are human-generated
119
+
120
+ Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
121
+
122
+ 4. Fail to appropriately disclose to end users any known dangers of your Research Materials.
123
+
124
+ Please report any violation of this Policy or other problems that could lead to a violation of this Policy by submitting a report here [https://docs.google.com/forms/d/e/1FAIpQLSeb11cryAopJ7LNrC4nxEUXrHY26hfkXQMf_uH-oFgA3WlYZQ/viewform].
perception_models/README.md ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Perception Models: Powerful Models for Image, Video, and Audio Perception
2
+ [![Code License](https://img.shields.io/badge/Code_License-Apache_2.0-olive)](https://opensource.org/licenses/Apache-2.0)
3
+
4
+ This repo is the home to the state-of-the-art for image and video _perception_: [**Perception Encoder (PE)**](https://arxiv.org/abs/2504.13181) for image, video, and [audio](https://ai.meta.com/research/publications/pushing-the-frontier-of-audiovisual-perception-with-large-scale-multimodal-correspondence-learning/) encoding, and [**Perception Language Model (PLM)**](https://arxiv.org/abs/2504.13180) for decoding.
5
+
6
+ > [!TIP]
7
+ > Click to Navigate!
8
+ >
9
+ > [Perception Encoder and Perception Encoder Audio-Visual](#perception-encoder-pe)
10
+ >
11
+ > [Perception Language Model](#perception-language-model-plm)
12
+ >
13
+ > [Dataset Releases](#dataset-releases)
14
+
15
+ ## Updates
16
+ * **[Dec-16-25]:** We have released the Perception Encoder Audio-Visual (PE-AV) and Perception Encoder Audio-Frame (PE-A-Frame) models: [[`Blog`](https://ai.meta.com/blog/sam-audio/)][[`paper`](https://ai.meta.com/research/publications/pushing-the-frontier-of-audiovisual-perception-with-large-scale-multimodal-correspondence-learning/)] :fire::fire:
17
+ * **[Jul-14-25]:** PerceptionLM is now available in [Hugging Face transformers](https://huggingface.co/docs/transformers/main/en/model_doc/perception_lm). :fire::fire:
18
+ * **[Jul-11-25]:** We have released 8 new checkpoints for [Perception Encoder](apps/pe/README.md): 2x small core models (T and S), 2x tiling-tuned lang models (G and L), and 4x smaller spatial models (L, B, S, T). Give them a try! :fire::fire::fire:
19
+ * **[May-28-25]:** Perception Encoder has been integrated into [timm](https://github.com/huggingface/pytorch-image-models)! :fire::fire:
20
+ * **[Apr-18-25]:** Perception Language Model (PLM) and PLM-VideoBench are added to lmms-eval. This makes it easy to reproduce PLM results and allows you to evaluate on the PLM-VideoBench. [[`lmms-eval`](https://github.com/EvolvingLMMs-Lab/lmms-eval/pull/638)] :fire::fire:
21
+ * **[Apr-17-25]:** Perception Encoder (PE) and Perception Language Model (PLM) are released. [[`Blog`](https://ai.meta.com/blog/meta-fair-updates-perception-localization-reasoning)] :fire::fire:
22
+
23
+
24
+ ## Perception Encoder (PE)
25
+ [![Data](https://img.shields.io/badge/Download-PE%20Data-ffcc00.svg)](https://huggingface.co/datasets/facebook/PE-Video)
26
+ [![Hugging Face Collection](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Collection-blue)](https://huggingface.co/collections/facebook/perception-encoder-67f977c9a65ca5895a7f6ba1)
27
+ [![Paper](https://img.shields.io/badge/Technical%20Report-Perception%20Encoder-b31b1b.svg)](https://ai.meta.com/research/publications/perception-encoder-the-best-visual-embeddings-are-not-at-the-output-of-the-network)
28
+ [![Paper](https://img.shields.io/badge/Technical%20Report-Perception%20Encoder%20AV-b31b1b.svg)](https://ai.meta.com/research/publications/pushing-the-frontier-of-audiovisual-perception-with-large-scale-multimodal-correspondence-learning/)
29
+ [![Paper](https://img.shields.io/badge/arXiv-2504.13181-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2504.13181)
30
+ [![Colab Demo](https://img.shields.io/static/v1?label=Demo&message=Google%20Colab&logo=google&color=orange)](https://colab.research.google.com/github/facebookresearch/perception_models/blob/main/apps/pe/docs/pe_demo.ipynb)
31
+ [![Model License](https://img.shields.io/badge/Model_License-Apache_2.0-olive)](https://opensource.org/licenses/Apache-2.0)
32
+
33
+ [Perception Encoder (PE)](https://arxiv.org/abs/2504.13181) is a family of state-of-the-art vision and audio encoders for encoding images, video, and audio: PE core outperforms SigLIP2 on image and InternVideo2 on video benchmarks; PE lang can be used to outperform QwenVL2.5 and InternVL3 on vision language modeling; and PE spatial outperforms DINOv2 on dense prediction tasks. And all of this follows the same, easily scalable contrastive pretraining. Please see [README](apps/pe/README.md) for more details.
34
+
35
+ <img src="apps/pe/docs/assets/teaser.png" style="width: 100%; margin: 0 auto; display: block;" />
36
+
37
+ ### Models
38
+ PE has 4 types of checkpoints, each excelling in a different area of computer vision and audio understanding:
39
+ - [PE core](#vision-language-benchmarks): a CLIP model that excels in vision-language tasks such as zero-shot image and video classification and video retrieval.
40
+ - [PE lang](#multimodal-llm-benchmarks): an LLM-aligned PE that powers [PLM](https://arxiv.org/abs/2504.13180) to compete at the forefront of multimodal LLM benchmarks.
41
+ - [PE spatial](#vision-centric-benchmarks): a spatially tuned PE that outperforms the best spatial models for vision-centric tasks such as detection, depth estimation, and tracking.
42
+ - [PE audio-visual](#audio-visual-benchmarks): a CLIP Model that embeds audio, video, audio-video, and text into a joint embedding space.
43
+
44
+ #### Vision-Language Benchmarks
45
+ | | Model | Checkpoint | IN-1k | IN-v2 | IN-A | ObjectNet | COCO-T2I | Kinetics-400 | VTT-T2V
46
+ |:--:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
47
+ | | **T/16** 384px | [PE-Core-T16-384](https://huggingface.co/facebook/PE-Core-T16-384) | 62.1 | 54.7 | 21.1 | 43.9 | 33.0 | 41.5 | 28.8 |
48
+ | | **S/16** 384px | [PE-Core-S16-384](https://huggingface.co/facebook/PE-Core-S16-384) | 72.7 | 65.0 | 49.5 | 60.0 | 42.6 | 55.0 | 39.3 |
49
+ | | **B/16** 224px | [PE-Core-B16-224](https://huggingface.co/facebook/PE-Core-B16-224) | 78.4 | 71.7 | 62.4 | 71.9 | 50.9 | 65.6 | 47.6 |
50
+ | | **L/14** 336px | [PE-Core-L14-336](https://huggingface.co/facebook/PE-Core-L14-336) | 83.5 | 77.9 | 89.0 | 84.7 | 57.1 | 73.4 | 50.3 |
51
+ | | **G/14** 448px | [PE-Core-G14-448](https://huggingface.co/facebook/PE-Core-G14-448) | 85.4 | 80.2 | 92.6 | 88.2 | 58.1 | 76.9 | 51.2 |
52
+
53
+ #### Multimodal LLM Benchmarks
54
+
55
+ 🔬 Controlled Setting:
56
+ | | Encoder | Checkpoint | Doc VQA (val) | InfoQA (val) | TextVQA | MVBench | PerceptionTest (val) | EgoSchema (val) |
57
+ |:--:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
58
+ | | **L/14** 448px | [PE-Lang-L14-448](https://huggingface.co/facebook/PE-Lang-L14-448) | 81.9 | 46.4 | 73.0 | 52.3 | 54.7 | 59.8 |
59
+ | | **G/14** 448px | [PE-Lang-G14-448](https://huggingface.co/facebook/PE-Lang-G14-448) | 84.4 | 48.3 | 75.2 | 52.4 | 56.0 | 62.0 |
60
+
61
+
62
+ 🔥 SotA Setting:
63
+ | | Model | Encoder | Doc VQA (test) | InfoQA (test) | TextVQA | MVBench | PerceptionTest (test) | EgoSchema (test) |
64
+ |:--:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
65
+ | | PLM-3B | [PE-Lang-L14-448-Tiling](https://huggingface.co/facebook/PE-Lang-L14-448-Tiling)* | 93.8 | 74.6 | 84.3 | 74.7 | 79.3 | 66.9 |
66
+ | | PLM-8B | [PE-Lang-G14-448-Tiling](https://huggingface.co/facebook/PE-Lang-G14-448-Tiling)* | 94.6 | 80.9 | 86.5 | 77.1 | 82.7 | 68.8 |
67
+
68
+ \* These checkpoints were aligned with tiling. Use them if you use higher than 448 resolution with tiling in the LLM decoder.
69
+
70
+ #### Vision-centric Benchmarks
71
+ 🦾 Main model:
72
+ | | Encoder | Checkpoint | ADE20k <br/> [Segmentation](https://github.com/open-mmlab/mmsegmentation)<br />Linear Probe mIoU | DAVIS<br /> [Tracking](https://github.com/facebookresearch/dino/blob/main/eval_video_segmentation.py) <br />Zero-Shot J&F | LVIS <br /> [Mask R-CNN](apps/detection/detectron2_pe/) 1024px <br /> Box / Mask mAP | COCO <br/> [DETA](apps/detection/DETA_pe/) 1824px <br /> Box mAP |
73
+ |:--:|:---:|:---:|:---:|:---:|:---:|:---:|
74
+ | | **G/14** 448px | [PE-Spatial-G14-448](https://huggingface.co/facebook/PE-Spatial-G14-448) | 49.3 | 61.5 | 54.2 / 49.3 | 66.0 |
75
+
76
+
77
+ <div align="center">
78
+ <img src="apps/pe/docs/assets/spatial_correspondence.png" style="width: 80%; margin: 0 auto; padding-top: 20px; padding-bottom: 20px; display: block;" />
79
+
80
+ Visualization of PCA of non-masked visual tokens, mapped to RGB values.
81
+ </div>
82
+
83
+ ⚗️ Distilled Models:
84
+ | | Encoder<br />(Distilled from G) | Checkpoint | ADE20k <br/> [Segmentation](https://github.com/open-mmlab/mmsegmentation)<br />Linear Probe mIoU | DAVIS<br /> [Tracking](https://github.com/facebookresearch/dino/blob/main/eval_video_segmentation.py) <br />Zero-Shot J&F |
85
+ |:--:|:---:|:---:|:---:|:---:|
86
+ | | **T/16** 512px | [PE-Spatial-T16-512](https://huggingface.co/facebook/PE-Spatial-T16-512) | 27.6 | 55.0 |
87
+ | | **S/16** 512px | [PE-Spatial-S16-512](https://huggingface.co/facebook/PE-Spatial-S16-512) | 37.5 | 57.5 |
88
+ | | **B/16** 512px | [PE-Spatial-B16-512](https://huggingface.co/facebook/PE-Spatial-B16-512) | 44.4 | 58.9 |
89
+ | | **L/14** 448px | [PE-Spatial-L14-448](https://huggingface.co/facebook/PE-Spatial-L14-448) | 48.1 | 60.6 |
90
+
91
+ See paper for comparison to other models.
92
+
93
+ #### Audio-Visual Benchmarks
94
+
95
+ | | Model | Checkpoint | Avg Retrieval | AudioCaps T→A | AudioCaps T→V | AudioCaps V→A | Clotho T→A | Valor T→A | Valor T→V | VCTK A→T | VGGSound V→A | Internal V→A |
96
+ |:--:|:-----:|--------------|---------------|---------------|---------------|---------------|------------|-----------|-----------|----------|---------------|---------------|
97
+ | 🆕 | **AV S** 16 frames | [`pe-av-small-16-frame`](https://huggingface.co/facebook/pe-av-small-16-frame) | 45.2 | 41.2 | 18.6 | 75.4 | 24.0 | 29.8 | 70.1 | 96.1 | 34.1 | 17.9 |
98
+ | 🆕 | **AV B** 16 frames | [`pe-av-base-16-frame`](https://huggingface.co/facebook/pe-av-base-16-frame) | 47.0 | 43.1 | 19.8 | 80.6 | 23.4 | 31.9 | 70.0 | 94.8 | 39.0 | 20.4 |
99
+ | 🆕 | **AV L** 16 frames | [`pe-av-large-16-frame`](https://huggingface.co/facebook/pe-av-large-16-frame) | 48.2 | 44.7 | 19.5 | 86.1 | 22.8 | 35.0 | 70.9 | 85.6 | 45.2 | 23.9 |
100
+ | 🆕 | **AV S** all frames | [`pe-av-small`](https://huggingface.co/facebook/pe-av-small) | 48.1 | 41.8 | 18.8 | 77.4 | 23.9 | 29.3 | 70.9 | 94.9 | 35.4 | 40.5 |
101
+ | 🆕 | **AV B** all frames | [`pe-av-base`](https://huggingface.co/facebook/pe-av-base) | 50.2 | 42.7 | 19.6 | 83.7 | 23.8 | 30.8 | 71.2 | 94.9 | 40.7 | 44.6 |
102
+ | 🆕 | **AV L** all frames | [`pe-av-large`](https://huggingface.co/facebook/pe-av-large) | 51.6 | 45.8 | 20.8 | 88.3 | 23.0 | 35.1 | 70.9 | 85.6 | 48.3 | 46.5 |
103
+
104
+ #### Audio Event Localization Benchmarks
105
+
106
+ | | Model | Checkpoint | Internal Bench (AUROC) | ASFX-SED (AUROC) | AudioSet-Strong (AUROC) | DESED (AUROC) | UrbanSED (AUROC) |
107
+ |:--:|:-----:|------------------|---------------------|------------------|-----------------------|-------------|-------------|
108
+ | 🆕 | **A-Frame S** | [`pe-a-frame-small`](https://huggingface.co/facebook/pe-a-frame-small)| 0.91 | 0.83 | 0.96 | 0.96 | 0.88 |
109
+ | 🆕 | **A-Frame B** | [`pe-a-frame-base`](https://huggingface.co/facebook/pe-a-frame-base)| 0.92 | 0.83 | 0.96 | 0.98 | 0.89 |
110
+ | 🆕 | **A-Frame L** | [`pe-a-frame-large`](https://huggingface.co/facebook/pe-a-frame-large)| 0.91 | 0.83 | 0.96 | 0.97 | 0.89 |
111
+
112
+ ### Getting Started with PE
113
+ You can get started with the following example for image and text feature extraction or use our [Colab Demo](https://colab.research.google.com/github/facebookresearch/perception_models/blob/main/apps/pe/docs/pe_demo.ipynb).
114
+
115
+ ```python
116
+ import torch
117
+ from PIL import Image
118
+ import core.vision_encoder.pe as pe
119
+ import core.vision_encoder.transforms as transforms
120
+
121
+ print("CLIP configs:", pe.CLIP.available_configs())
122
+ # CLIP configs: ['PE-Core-G14-448', 'PE-Core-L14-336', 'PE-Core-B16-224', 'PE-Core-S16-384', 'PE-Core-T16-384']
123
+
124
+ model = pe.CLIP.from_config("PE-Core-L14-336", pretrained=True) # Downloads from HF
125
+ model = model.cuda()
126
+
127
+ preprocess = transforms.get_image_transform(model.image_size)
128
+ tokenizer = transforms.get_text_tokenizer(model.context_length)
129
+
130
+ image = preprocess(Image.open("docs/assets/cat.png")).unsqueeze(0).cuda()
131
+ text = tokenizer(["a diagram", "a dog", "a cat"]).cuda()
132
+
133
+ with torch.no_grad(), torch.autocast("cuda"):
134
+ image_features, text_features, logit_scale = model(image, text)
135
+ text_probs = (logit_scale * image_features @ text_features.T).softmax(dim=-1)
136
+
137
+ print("Label probs:", text_probs) # prints: [[0.0, 0.0, 1.0]]
138
+ ```
139
+
140
+ > [!TIP]
141
+ > See [`apps/pe/README.md`](apps/pe/README.md) for details and how to get started!
142
+
143
+ ### Getting Started with PE-AV
144
+
145
+ ```python
146
+ import os
147
+ from core.audio_visual_encoder import PEAudioVisual, PEAudioVisualTransform
148
+ import torch
149
+
150
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
151
+ model = PEAudioVisual.from_config("pe-av-large", pretrained=True).to(device)
152
+ transform = PEAudioVisualTransform.from_config("pe-av-large")
153
+
154
+ video_files = ["assets/train.mp4", "assets/office.mp4"]
155
+ descriptions = [
156
+ "A person talking with sirens and a train in the background",
157
+ "Two people talking in an office, with sounds of workers typing on a keyboard"
158
+ ]
159
+
160
+ def embed(videos=None, audio=None, text=None):
161
+ inputs = transform(videos=videos, audio=audio, text=text)
162
+ inputs = inputs.to(device)
163
+ with torch.inference_mode(), torch.autocast(device.type, dtype=torch.bfloat16):
164
+ return model(**inputs)
165
+
166
+ vt_outputs = embed(videos=video_files, text=descriptions)
167
+ avt_outputs = embed(videos=video_files, audio=video_files, text=descriptions)
168
+ at_outputs = embed(audio=video_files, text=descriptions)
169
+
170
+ # Compute dot product between visual and text
171
+ vt_dot_products = torch.einsum("ij,ij->i", vt_outputs.visual_embeds, vt_outputs.visual_text_embeds)
172
+ # Compute dot product between audio_visual and text
173
+ avt_dot_products = torch.einsum("ij,ij->i", avt_outputs.audio_visual_embeds, avt_outputs.audio_visual_text_embeds)
174
+ # Compute dot product between audio and text
175
+ at_dot_products = torch.einsum("ij,ij->i", at_outputs.audio_embeds, at_outputs.audio_text_embeds)
176
+ # Compute dot product between audio and video
177
+ av_dot_products = torch.einsum("ij,ij->i", avt_outputs.audio_embeds, avt_outputs.video_embeds)
178
+ ```
179
+
180
+ ### Getting Started with PE-A-Frame
181
+
182
+ ```python
183
+ from core.audio_visual_encoder import (
184
+ PEAudioFrame,
185
+ PEAudioFrameTransform,
186
+ )
187
+ import torch
188
+
189
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
190
+ model = PEAudioFrame.from_config("pe-a-frame-large", pretrained=True).to(device)
191
+ transform = PEAudioFrameTransform.from_config("pe-a-frame-large")
192
+
193
+ descriptions = ["a person talking"]
194
+ inputs = transform(
195
+ audio=["assets/office.mp4"],
196
+ text=descriptions,
197
+ ).to(device)
198
+
199
+ with torch.inference_mode():
200
+ outputs = model(**inputs)
201
+
202
+ # Print the spans for each description (start and end timestamps for when they occur in the audio)
203
+ for description, spans in zip(descriptions, outputs.spans):
204
+ span_str = ", ".join([f"({start:.2f}, {end:.2f})" for start, end in spans])
205
+ print(f'"{description}": [{span_str}]')
206
+
207
+ ```
208
+
209
+ > [!TIP]
210
+ > See [`apps/pe/README.md`](apps/pe/README.md) for additional details!
211
+
212
+ ## Perception Language Model (PLM)
213
+ [![Data](https://img.shields.io/badge/Download-PLM%20Data-ffcc00.svg)](https://huggingface.co/datasets/facebook/PLM-Video-Human)
214
+ [![Hugging Face Collection](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Collection-blue)](https://huggingface.co/collections/facebook/perception-lm-67f9783f171948c383ee7498)
215
+ [![Paper](https://img.shields.io/badge/Technical%20Report-PerceptionLM-b31b1b.svg)](https://ai.meta.com/research/publications/perceptionlm-open-access-data-and-models-for-detailed-visual-understanding)
216
+ [![Paper](https://img.shields.io/badge/arXiv-2504.13180-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2504.13180)
217
+ [![Colab](https://img.shields.io/badge/Google%20Colab-Tutorials-red)](apps/plm/notebook_demos)
218
+ [![ModelLicense](https://img.shields.io/badge/Model_License-FAIR_Research_License-lightgrey)](LICENSE.PLM)
219
+
220
+ PerceptionLM (PLM) is a family of open and fully reproducible models to facilitate research in vision-language modeling (VLM). In conjunction with PE, it is powerful enough to compete with the latest state-of-the-art VLMs such as InternVL3 and QwenVL2.5, while using _fully open data_. We also release the largest spatiotemporally annotated video dense captioning and fine-grained human activity recognition datasets to ever exist.
221
+
222
+ ![Description of the image](apps/plm/docs/plm_main_fig.png)
223
+
224
+ ### Models
225
+ PLM releases models in three different sizes (1B, 3B and 8B).
226
+ * [Perception-LM-1B](https://huggingface.co/facebook/Perception-LM-1B): A PLM model trained using Llama-3.2-1B-Instruct base LLM.
227
+ * [Perception-LM-3B](https://huggingface.co/facebook/Perception-LM-3B): A PLM model trained using Llama-3.2-3B-Instruct base LLM.
228
+ * [Perception-LM-8B](https://huggingface.co/facebook/Perception-LM-8B): A PLM model trained using Llama-3.1-8B-Instruct base LLM.
229
+
230
+ #### PLM Image Benchmark Results
231
+
232
+ | Model | DocVQA | ChartQA | TextVQA | InfoQA | AI2D | OCRBench | COCO | Nocap | Flickr | MMMU | VQAv2 | OKVQA | VizWiz | MME | SEED | BLINK | CVBench | RealWorldQA | VSR | POPE |
233
+ |:---------:|:--------:|:---------:|:---------:|:--------:|:------:|:----------:|:------------:|:-------------:|:--------------:|:------:|:-------:|:--------:|:--------:|:-----:|:------:|:-------:|:----------:|:-------------:|:-----:|:------:|
234
+ | PLM1B | 90.7 | 78.6 | 82.1 | 63.0 | 84.9 | 807 | 138.6 | 124.2 | 100.5 | 34.8 | 81.7 | 61.0 | 59.7 | 1603| 76.3 | 46.8 | 73.8 | 67.1 | 68.8| 88.4 |
235
+ | PLM3B | 93.8 | 84.3 | 84.3 | 74.6 | 90.9 | 830 | 144.9 | 126.5 | 98.0 | 41.2 | 84.3 | 66.8 | 64.0 | 1879| 78.5 | 55.4 | 81.4 | 72.4 | 80.4| 88.7 |
236
+ | PLM8B | 94.6 | 85.5 | 86.5 | 80.9 | 92.7 | 870 | 146.7 | 129.9 | 105.6 | 46.1 | 85.6 | 69.6 | 67.0 | 1989| 79.3 | 56.0 | 81.3 | 75.0 | 82.8| 89.9 |
237
+
238
+ #### PLM Video Benchmark Results
239
+
240
+ | Model | VATEX | DREAM&nbsp;1K | How2QA | MVBench | NExTQA | PerceptionTest&nbsp;(test) | STAR | TVQA | VideoMME | TVBench | ActivityNetQA | EgoSchema&nbsp;(test) | TemporalBench | TOMATO | MotionBench&nbsp;(dev) | TempCompass&nbsp;(MCQ) | CGBench&nbsp;(clue) | Charades&nbsp;STA | VideoHallucer | Halluc.&nbsp;EventHallusion |
241
+ |:-------------:|:---------------------------:|:-----------------------:|:---------------------:|:-------------:|:-------------:|:--------------------------:|:----------:|:----------:|:----------------:|:-------------:|:--------------------:|:----------------------:|:---------------------:|:------------:|:------------------------:|:-----------------------:|:---------------------:|:-------------------:|:-------------------------------:|:--------------------------------:|
242
+ | PLM1B | 92.5 | 34.3 | 86.4 | 70.1 | 80.3 | 72.7 | 83.7 | 50.3 | 49.2 | 50.4 | 62.5 | 60.4 | 18.2 | 25.5 | 52.2 | 64.6 | 43.6 | 55.2 | 49.2 | 79.5 |
243
+ | PLM3B | 96.1 | 37.4 | 89.4 | 74.7 | 83.4 | 79.3 | 84.8 | 55.3 | 54.9 | 58.9 | 66.2 | 66.9 | 23.4 | 30.9 | 60.4 | 69.3 | 47.2 | 57.7 | 55.5 | 76.5 |
244
+ | PLM8B | 99.7 | 35.9 | 90.7 | 77.1 | 84.1 | 82.7 | 84.9 | 59.3 | 58.3 | 63.5 | 67.3 | 68.8 | 28.3 | 33.2 | 61.4 | 72.7 | 46.4 | 58.6 | 57.7 | 77.3 |
245
+
246
+ ### PLM Resources
247
+
248
+ | Resource | Description | Documentation |
249
+ | --- | --- |--------------------------------------------------------|
250
+ | **Evaluation** | Evaluation of PLM using lmms-eval | [`docs/evaluation.md`](apps/plm/docs/evaluation.md) |
251
+ | **Training / Finetuning** | Training and finetuning instructions for PLM | [`docs/training.md`](apps/plm/docs/training.md) |
252
+ | **PLM-VideoBench** | Evaluation on PLM-VideoBench using lmms-eval | [`docs/plm_videobench.md`](apps/plm/docs/plm_videobench.md) |
253
+ | **End-to-End Finetuning Example** | End-to-end finetuning example on radiology images | [`docs/finetune_example.md`](apps/plm/docs/finetune_example.md) |
254
+ | **Generating Response** | Generate responses using a trained model with `generate.py` | [`generate.py`](apps/plm/generate.py) |
255
+
256
+
257
+ > [!TIP]
258
+ > See [`apps/plm/README.md`](apps/plm/README.md) for details and how to get started!
259
+
260
+ ## Dataset Releases
261
+
262
+
263
+ ### 🎥 [PE-Video-Dataset (PVD)](https://huggingface.co/datasets/facebook/PE-Video)
264
+
265
+
266
+ PVD comprises 1M high-quality and diverse videos. Among them, 120K videos are accompanied by automated and human-verified annotations, and all videos are accompanied by video descriptions and keywords. The videos are motion-centered, covering both first-person and third-person views with a wide coverage of scenes.
267
+
268
+ 🔹 [**PVD**](https://huggingface.co/datasets/facebook/PE-Video) - 1M High-Quality Human Annotated Video Dataset
269
+
270
+ <table>
271
+ <tr>
272
+ <td colspan="2" align="center"><strong>PVD</strong></td>
273
+ </tr>
274
+ <tr>
275
+ <td align="center">
276
+ <img src="https://github.com/user-attachments/assets/ead8a7ed-4d5b-465a-a396-68948683dfcf" alt="output_2" width="300"/><br>
277
+ A person's hands pruning a plant with green leaves.
278
+ </td>
279
+ <td align="center">
280
+ <img src="https://github.com/user-attachments/assets/9e509e49-f550-4c5c-9571-ed57c5118227" alt="output" width="300"/><br>
281
+ A detailed diorama of a rural landscape featuring a horse-drawn carriage moving along a dirt path
282
+ </td>
283
+ </tr>
284
+ </table>
285
+
286
+ ---
287
+
288
+
289
+ ### 🎥 [PLM-Video-Human](https://huggingface.co/datasets/facebook/PLM-Video-Human)
290
+
291
+ PLM-Video-Human is a collection of human-annotated resources for training Vision Language Models, focused on detailed video understanding. Training tasks include:
292
+
293
+ 🔹 [**FGQA**](https://huggingface.co/datasets/facebook/PLM-Video-Human#fine-grained-question-answering-fgqa) — Fine-Grained Question Answering
294
+ 🔹 [**RTLoc**](https://huggingface.co/datasets/facebook/PLM-Video-Human#region-temporal-localization-rtloc) — Region-Temporal Localization
295
+ 🔹 [**RCap**](https://huggingface.co/datasets/facebook/PLM-Video-Human#region-video-captioning-rcap) — Region Video Captioning
296
+ 🔹 [**RDCap**](https://huggingface.co/datasets/facebook/PLM-Video-Human#region-dense-temporal-captioning-rdcap) — Region Dense Temporal Captioning
297
+
298
+ <table>
299
+ <tr>
300
+ <td colspan="2" align="center"><strong>FGQA</strong></td>
301
+ </tr>
302
+ <tr>
303
+ <td colspan="2" align="center">
304
+ <img src="https://github.com/user-attachments/assets/4f5c6c5e-687d-49df-9bf8-db9ec7f1f281" alt="fgqa" width="500"/>
305
+ </td>
306
+ </tr>
307
+ <tr>
308
+ <th>Question</th>
309
+ <th>Answer</th>
310
+ </tr>
311
+ <tr>
312
+ <td>In what direction do you move the tool while removing the shell?</td>
313
+ <td>Both clockwise and anticlockwise.</td>
314
+ </tr>
315
+ </table>
316
+
317
+ <table>
318
+ <tr>
319
+ <td colspan="2" align="center"><strong>STC</strong></td>
320
+ </tr>
321
+ <tr>
322
+ <td colspan="2" align="center">
323
+ <img src="https://github.com/user-attachments/assets/a2a129c7-c1e9-47b5-a3b4-fc96a237a9fb" alt="stc" width="500"/>
324
+ </td>
325
+ </tr>
326
+ <tr>
327
+ <th>Time (s) &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</th>
328
+ <th>Description</th>
329
+ </tr>
330
+ <tr>
331
+ <td>[0, 4]</td>
332
+ <td>The masked subject is a young boy wearing a red jacket and gray pants. He is grasping a monkey bar–like activity in a playground.</td>
333
+ </tr>
334
+ <tr>
335
+ <td>[5, 14]</td>
336
+ <td>He lets go of his hands and runs to the right side of the frame.</td>
337
+ </tr>
338
+ <tr>
339
+ <td>[15, 30]</td>
340
+ <td>The subject is out of frame.</td>
341
+ </tr>
342
+ <tr>
343
+ <td>[31, 45]</td>
344
+ <td>The subject runs back into the frame toward the higher monkey bar in the playground.</td>
345
+ </tr>
346
+ <tr>
347
+ <td>[46, 74]</td>
348
+ <td>He jumps underneath the metal bar and looks up at it. A man wearing a white polo runs toward the subject.</td>
349
+ </tr>
350
+ <tr>
351
+ <td>[75, 116]</td>
352
+ <td>The man in the white polo lifts the subject upward so he can grasp the higher metal bar. The subject holds onto the bar and hangs from it.</td>
353
+ </tr>
354
+ </table>
355
+
356
+ ---
357
+
358
+ ### 🤖 Auto-Generated Datasets
359
+
360
+ Synthetic image/video captions and QAs used in PLM; please refer to the paper, Section 3 (PLM), for more details. The synthetic annotations cover: SA1B, OpenImages, Objects365, ArxivQA, UCSF, PDFAcc, YT-1B, Ego4d with captions, YT-1B with MCQAs and Ego4d with QAs.
361
+
362
+ 🖼️ [**PLM-Image-Auto**](https://huggingface.co/datasets/facebook/PLM-Image-Auto) — Automatically generated image datasets
363
+
364
+ 📹 [**PLM-Video-Auto**](https://huggingface.co/datasets/facebook/PLM-Video-Auto) — Automatically generated video datasets
365
+
366
+
367
+ ---
368
+
369
+ ## Installation :wrench:
370
+ ```shell
371
+ git clone https://github.com/facebookresearch/perception_models.git
372
+ cd perception_models
373
+
374
+ conda create --name perception_models python=3.12
375
+ conda activate perception_models
376
+
377
+ # Install PyTorch
378
+ pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 xformers --index-url https://download.pytorch.org/whl/cu124
379
+
380
+ # We use torchcodec for decoding videos into PyTorch tensors
381
+ conda install ffmpeg -c conda-forge
382
+ pip install torchcodec==0.1 --index-url=https://download.pytorch.org/whl/cu124
383
+
384
+ pip install -e .
385
+ ```
386
+ This will install an editable version of the repo, allowing you to make changes to the code without needing to reinstall the package every time.
387
+
388
+
389
+ ## 🙏 Acknowledgement
390
+ We are thankful to [Meta Lingua](https://github.com/facebookresearch/lingua) for releasing their code as open-source contributions. The code structure and code implementation of the LLM is directly forked from [Meta Lingua](https://github.com/facebookresearch/lingua). We are also thankful to [Open_CLIP](https://github.com/mlfoundations/open_clip) for open-source contributions in CLIP training, and [CLIP_benchmark](https://github.com/LAION-AI/CLIP_benchmark) for CLIP model evaluation.
391
+
392
+
393
+ ## 📜 Citation
394
+ ```BibTeX
395
+ @article{bolya2025PerceptionEncoder,
396
+ title={Perception Encoder: The best visual embeddings are not at the output of the network},
397
+ author={Daniel Bolya and Po-Yao Huang and Peize Sun and Jang Hyun Cho and Andrea Madotto and Chen Wei and Tengyu Ma and Jiale Zhi and Jathushan Rajasegaran and Hanoona Rasheed and Junke Wang and Marco Monteiro and Hu Xu and Shiyu Dong and Nikhila Ravi and Daniel Li and Piotr Doll{\'a}r and Christoph Feichtenhofer},
398
+ journal={arXiv:2504.13181},
399
+ year={2025}
400
+ }
401
+
402
+ @article{cho2025PerceptionLM,
403
+ title={PerceptionLM: Open-Access Data and Models for Detailed Visual Understanding},
404
+ author={Jang Hyun Cho and Andrea Madotto and Effrosyni Mavroudi and Triantafyllos Afouras and Tushar Nagarajan and Muhammad Maaz and Yale Song and Tengyu Ma and Shuming Hu and Hanoona Rasheed and Peize Sun and Po-Yao Huang and Daniel Bolya and Suyog Jain and Miguel Martin and Huiyu Wang and Nikhila Ravi and Shashank Jain and Temmy Stark and Shane Moon and Babak Damavandi and Vivian Lee and Andrew Westbury and Salman Khan and Philipp Kr\"{a}henb\"{u}hl and Piotr Doll{\'a}r and Lorenzo Torresani and Kristen Grauman and Christoph Feichtenhofer},
405
+ journal={arXiv:2504.13180},
406
+ year={2025}
407
+ }
408
+ ```
perception_models/__pycache__/legrad_pe_audio.cpython-310.pyc ADDED
Binary file (6.49 kB). View file
 
perception_models/__pycache__/legrad_pe_audio.cpython-313.pyc ADDED
Binary file (10.2 kB). View file
 
perception_models/__pycache__/legrad_pe_image.cpython-312.pyc ADDED
Binary file (12.8 kB). View file
 
perception_models/__pycache__/legrad_pe_image.cpython-313.pyc ADDED
Binary file (11.5 kB). View file
 
perception_models/apps/detection/DETA_pe/README.md ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SOTA COCO Object Detection with PE
2
+
3
+ ## Getting started
4
+
5
+ Please refer to [INSTALL.md](../INSTALL.md) for installation and dataset preparation instructions.
6
+
7
+ Also install [Deformable Attention](models/ops/make.sh) ops.
8
+
9
+ ## Results and Fine-tuned Models
10
+
11
+ <table><tbody>
12
+ <!-- START TABLE -->
13
+ <!-- TABLE HEADER -->
14
+ <th valign="bottom">detector</th>
15
+ <th valign="bottom">vision encoder</th>
16
+ <th valign="bottom">box<br/>AP</th>
17
+ <th valign="bottom">box(TTA)<br/>AP</th>
18
+ <th valign="bottom">download</th>
19
+ <!-- TABLE BODY -->
20
+ <!-- ROW: DETA -->
21
+ <tr><td align="left">DETA</td>
22
+ <td align="center">PE spatial G</td>
23
+ <td align="center"> 65.2 </td>
24
+ <td align="center"> 66.0 </td>
25
+ <td align="center"><a href="https://huggingface.co/facebook/PE-Detection/resolve/main/deta_coco_1824pix.pth">model</a></td>
26
+ </tr>
27
+ </tbody></table>
28
+
29
+
30
+ ## Training
31
+ We train in four stages: Objects365 (12 epochs, 1024 px), Objects365 (6 epochs, 1536 px), COCO (12 epochs, 1728 px), and COCO (3 epochs, 1824 px).
32
+
33
+ ```
34
+ sbatch scripts/pretrain_spatial_Gwin384_o365ep12_1024pix_16node.sh
35
+
36
+ sbatch scripts/pretrain_continue_spatial_Gwin384_o365ep6_1536pix_16node.sh
37
+
38
+ sbatch scripts/finetune_spatial_Gwin384_cocoep12_1728pix_8node.sh
39
+
40
+ sbatch scripts/finetune_further_spatial_Gwin384_cocoep3_1824pix_8node.sh
41
+
42
+ ```
43
+
44
+ ## Evaluation
45
+ ```
46
+ bash scripts/eval_1824pix.sh --resume deta_coco_1824pix.pth
47
+ ```
48
+
49
+ ## Evaluation with TTA (Test-Time Augmentation)
50
+ ```
51
+ sbatch scripts/eval_tta_slurm_1824pix.sh --resume deta_coco_1824pix.pth
52
+ ```
53
+ Note: If you get 65.9 AP instead of 66.0 AP, the gap is probably caused by different package versions; trying different hyperparameters, such as `--quad_scale 0.4`, will give 66.0 AP.
perception_models/apps/detection/DETA_pe/datasets/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ import torch.utils.data
11
+
12
+ from .coco import build as build_coco
13
+ from .objects365 import build as build_objects365
14
+ from .torchvision_datasets import CocoDetection
15
+
16
+
17
def get_coco_api_from_dataset(dataset):
    """Return the ``coco`` attribute of the ``CocoDetection`` behind ``dataset``.

    ``torch.utils.data.Subset`` wrappers are peeled off first (at most ten
    levels deep). If the object that remains is not a ``CocoDetection`` the
    function returns ``None``.
    """
    max_unwrap = 10  # upper bound on nested Subset wrappers that get removed
    depth = 0
    while depth < max_unwrap and isinstance(dataset, torch.utils.data.Subset):
        dataset = dataset.dataset
        depth += 1
    if not isinstance(dataset, CocoDetection):
        return None
    return dataset.coco
25
+
26
+
27
def build_dataset(image_set, args):
    """Build the dataset selected by ``args.dataset_file``.

    Args:
        image_set: split name, forwarded unchanged to the selected builder.
        args: namespace; ``dataset_file`` must be ``"objects365"``, ``"coco"``
            or ``"coco_panoptic"``.

    Raises:
        ValueError: if ``args.dataset_file`` names no supported dataset.
    """
    requested = args.dataset_file
    if requested == "objects365":
        return build_objects365(image_set, args)
    if requested == "coco":
        return build_coco(image_set, args)
    if requested == "coco_panoptic":
        # Imported lazily so that panopticapi is not required for plain COCO.
        from .coco_panoptic import build as build_coco_panoptic

        return build_coco_panoptic(image_set, args)
    raise ValueError(f"dataset {args.dataset_file} not supported")
perception_models/apps/detection/DETA_pe/datasets/coco.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ COCO dataset which returns image_id for evaluation.
12
+
13
+ Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
14
+ """
15
+ import random
16
+ from pathlib import Path
17
+
18
+ import datasets.transforms as T
19
+ import torch
20
+ import torch.utils.data
21
+ import torchvision.transforms.functional as F
22
+ from pycocotools import mask as coco_mask
23
+ from util.misc import get_local_rank, get_local_size
24
+
25
+ from .torchvision_datasets import CocoDetection as TvCocoDetection
26
+
27
+
28
+
29
class CocoDetection(TvCocoDetection):
    """COCO-format detection dataset with optional test-time augmentation.

    Each item is prepared by ``ConvertCocoPolysToMask`` and then passed through
    ``transforms`` (when not ``None``). What ``__getitem__`` returns depends on
    the flags:

    * ``test_hflip_aug``: the image concatenated with its horizontal flip
      along dim 0;
    * ``tta`` (only consulted when ``test_hflip_aug`` is false): a list holding
      the image, its flip, and a resized copy plus flip for every entry of
      ``tta_image_size``;
    * otherwise: the transformed image itself.
    """

    def __init__(
        self,
        img_folder,
        ann_file,
        transforms,
        return_masks,
        cache_mode=False,
        local_rank=0,
        local_size=1,
        test_hflip_aug=False,
        tta=False,
        is_train=False,
        lsj_img_size=1824,
    ):
        super().__init__(
            img_folder,
            ann_file,
            cache_mode=cache_mode,
            local_rank=local_rank,
            local_size=local_size,
        )
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)
        self.test_hflip_aug = test_hflip_aug
        self.tta = tta
        # Longest-side targets for the extra TTA scales. The 1728 case is kept
        # for backward compatibility.
        self.tta_image_size = (
            [1536, 1152] if lsj_img_size == 1728 else [1728, 1536, 1344]
        )
        self.is_train = is_train

    def __getitem__(self, idx):
        raw_img, raw_annotations = super().__getitem__(idx)
        sample = {"image_id": self.ids[idx], "annotations": raw_annotations}
        img, target = self.prepare(raw_img, sample)
        if self._transforms is not None:
            img, target = self._transforms(img, target)

        if self.test_hflip_aug:
            # Original and mirrored image stacked along the first dimension.
            return torch.cat([img, torch.flip(img, dims=[-1])], dim=0), target

        if not self.tta:
            return img, target

        views = [img, torch.flip(img, dims=[-1])]
        _, height, width = img.shape  # img is unpacked as a 3-D tensor here
        longest_side = max(height, width)
        for target_longest in self.tta_image_size:
            ratio = target_longest / longest_side
            resized = F.resize(img, size=(int(ratio * height), int(ratio * width)))
            views.append(resized)
            views.append(torch.flip(resized, dims=[-1]))
        return views, target
91
+
92
+
93
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterise per-object COCO segmentations into a stack of masks.

    Every entry of ``segmentations`` is encoded with
    ``coco_mask.frPyObjects``, decoded, and collapsed over its last axis with
    ``any`` so that each object yields one ``(height, width)`` mask.

    Returns:
        The per-object masks stacked along dim 0, or an empty
        ``(0, height, width)`` uint8 tensor when ``segmentations`` is empty.
    """
    per_object = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        decoded = coco_mask.decode(rles)
        if decoded.ndim < 3:
            # A single decoded mask gets a trailing axis so the reduction
            # below always runs over axis 2.
            decoded = decoded[..., None]
        decoded = torch.as_tensor(decoded, dtype=torch.uint8)
        per_object.append(decoded.any(dim=2))
    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
108
+
109
+
110
class ConvertCocoPolysToMask(object):
    """Turn raw COCO annotations into a tensor-valued target dict.

    The callable takes ``(image, target)`` where ``target`` has the keys
    ``"image_id"`` and ``"annotations"``, and returns ``(image, new_target)``.
    Crowd annotations are dropped, ``[x, y, w, h]`` boxes become clamped
    ``[x0, y0, x1, y1]`` boxes, and objects whose clamped box has no positive
    width or height are removed from every per-object field.
    """

    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image, target):
        width, height = image.size

        image_id = torch.tensor([target["image_id"]])

        # Keep annotations that are not flagged as crowd (missing flag == keep).
        objs = [o for o in target["annotations"] if o.get("iscrowd", 0) == 0]

        # reshape guards against the zero-object case producing a 1-D tensor
        xyxy = torch.as_tensor(
            [o["bbox"] for o in objs], dtype=torch.float32
        ).reshape(-1, 4)
        xyxy[:, 2:] += xyxy[:, :2]  # (w, h) -> (x1, y1)
        xyxy[:, 0::2].clamp_(min=0, max=width)
        xyxy[:, 1::2].clamp_(min=0, max=height)

        labels = torch.tensor([o["category_id"] for o in objs], dtype=torch.int64)

        if self.return_masks:
            masks = convert_coco_poly_to_mask(
                [o["segmentation"] for o in objs], height, width
            )

        keypoints = None
        if objs and "keypoints" in objs[0]:
            keypoints = torch.as_tensor(
                [o["keypoints"] for o in objs], dtype=torch.float32
            )
            count = keypoints.shape[0]
            if count:
                keypoints = keypoints.view(count, -1, 3)

        # Only boxes with strictly positive extent survive.
        valid = (xyxy[:, 3] > xyxy[:, 1]) & (xyxy[:, 2] > xyxy[:, 0])

        result = {"boxes": xyxy[valid], "labels": labels[valid]}
        if self.return_masks:
            result["masks"] = masks[valid]
        result["image_id"] = image_id
        if keypoints is not None:
            result["keypoints"] = keypoints[valid]

        # for conversion to coco api
        area = torch.tensor([o["area"] for o in objs])
        iscrowd = torch.tensor([o.get("iscrowd", 0) for o in objs])
        result["area"] = area[valid]
        result["iscrowd"] = iscrowd[valid]

        result["orig_size"] = torch.as_tensor([int(height), int(width)])
        result["size"] = torch.as_tensor([int(height), int(width)])

        return image, result
175
+
176
+
177
def make_coco_transforms(image_set, bigger=False):
    """Build the multi-scale COCO transform pipeline for a split.

    Args:
        image_set: ``"train"`` or ``"val"``.
        bigger: when true, every scale is multiplied by 1.5 and the longest
            side cap rises from 1333 to 2000. Defaults to ``False`` so that
            callers passing only ``image_set`` keep working.

    Returns:
        A ``T.Compose``: for ``"train"`` a random flip, a random choice between
        a plain multi-scale resize and a resize/crop/resize chain, then
        normalisation; for ``"val"`` a single resize then normalisation.

    Raises:
        ValueError: if ``image_set`` is neither ``"train"`` nor ``"val"``.
    """
    # Validate up front: only these two splits ever produced a pipeline, and
    # checking first avoids touching an unset ``scales`` for other names.
    if image_set not in ("train", "val"):
        raise ValueError(f"unknown {image_set}")

    normalize = T.Compose(
        [T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
    )

    if image_set == "train":
        scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
    else:
        scales = [800]

    max_size = 1333
    if bigger:
        scales = [int(1.5 * s) for s in scales]
        max_size = 2000

    if image_set == "val":
        return T.Compose(
            [
                T.RandomResize(scales, max_size=max_size),
                normalize,
            ]
        )

    return T.Compose(
        [
            T.RandomHorizontalFlip(),
            T.RandomSelect(
                T.RandomResize(scales, max_size=max_size),
                T.Compose(
                    [
                        T.RandomResize([400, 500, 600]),
                        T.RandomSizeCrop(384, 600),
                        T.RandomResize(scales, max_size=max_size),
                    ]
                ),
            ),
            normalize,
        ]
    )
220
+
221
+
222
def make_coco_transforms_lsj(
    image_set, image_size, lsj_img_train_min=480, lsj_strong_aug=False
):
    """Build the transform pipeline that emulates large-scale jitter (LSJ).

    Reference: https://github.com/facebookresearch/detectron2/blob/main/projects/ViTDet/configs/common/coco_loader_lsj.py
    The reference loader applies, for training, ``RandomFlip`` followed by
    ``ResizeScale(min_scale=0.1, max_scale=2.0)`` to ``image_size`` and a
    ``FixedSizeCrop`` of ``(image_size, image_size)``; for testing it applies
    ``ResizeShortestEdge`` with ``image_size``.

    This implementation simulates that behaviour by (1) applying the
    augmentations built here and (2) padding to ``(image_size, image_size)``
    in the collator, see util/misc/collate_fn_lsj.py.

    Args:
        image_set: any name containing ``"train"`` selects the training
            pipeline; exactly ``"val"`` selects the evaluation pipeline.
        image_size: target square size; resizes are capped at
            ``image_size - 32``.
        lsj_img_train_min: smallest training scale; scales step by 32 up to,
            but excluding, ``image_size``.
        lsj_strong_aug: prepend colour jitter and random grayscale for
            training.

    Raises:
        ValueError: if ``image_set`` selects neither pipeline.
    """
    is_train = "train" in image_set
    if not is_train and image_set != "val":
        raise ValueError(f"unknown {image_set}")

    to_normalized_tensor = T.Compose(
        [T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
    )

    # Capped 32 px below image_size to work around some weird bugs.
    max_size = image_size - 32

    if not is_train:
        return T.Compose(
            [
                T.RandomResize([image_size - 32], max_size=max_size),
                to_normalized_tensor,
            ]
        )

    scales = list(range(lsj_img_train_min, image_size, 32))
    if "val" in image_set or "test" in image_set or "unlabel" in image_set:
        scales = [image_size - 32]

    pipeline = []
    if lsj_strong_aug:
        pipeline += [
            T.ColorJitter((0.4, 0.4, 0.4, 0.1), p=0.5),
            T.RandomGrayscale(p=0.2),
        ]
    pipeline += [
        T.RandomHorizontalFlip(),
        T.RandomSelect(
            # similar to ResizeScale(min_scale=0.1, max_scale=1.0) and pad
            T.RandomResize(scales, max_size=max_size),
            # similar to ResizeScale(min_scale=1.0, max_scale=2.0) and crop
            T.Compose(
                [
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.RandomResize([max_size], max_size=max_size),
                ]
            ),
        ),
        to_normalized_tensor,
    ]
    return T.Compose(pipeline)
311
+
312
+
313
def build(image_set, args):
    """Create the COCO ``CocoDetection`` dataset for ``image_set``.

    Args:
        image_set: ``"train"`` or ``"val"``; any other key raises ``KeyError``.
        args: namespace providing ``coco_path``, ``lsj``, ``lsj_img_size``,
            ``lsj_img_train_min``, ``lsj_strong_aug``, ``bigger``, ``masks``,
            ``cache_mode``, ``test_hflip_aug`` and ``tta``.

    The LSJ pipeline is used when ``args.lsj`` is truthy, otherwise the
    multi-scale pipeline.
    """
    root = Path(args.coco_path)
    assert root.exists(), f"provided COCO path {root} does not exist"
    mode = "instances"
    annotation_dir = root / "annotations"
    locations = {
        split: (root / f"{split}2017", annotation_dir / f"{mode}_{split}2017.json")
        for split in ("train", "val")
    }
    img_folder, ann_file = locations[image_set]

    if args.lsj:
        coco_transform = make_coco_transforms_lsj(
            image_set,
            args.lsj_img_size,
            args.lsj_img_train_min,
            args.lsj_strong_aug,
        )
    else:
        coco_transform = make_coco_transforms(image_set, args.bigger)

    return CocoDetection(
        img_folder,
        ann_file,
        transforms=coco_transform,
        return_masks=args.masks,
        cache_mode=args.cache_mode,
        local_rank=get_local_rank(),
        local_size=get_local_size(),
        test_hflip_aug=args.test_hflip_aug,
        tta=args.tta,
        is_train=("train" in image_set),
        lsj_img_size=args.lsj_img_size,
    )
perception_models/apps/detection/DETA_pe/datasets/coco_eval.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ COCO evaluator that works in distributed mode.
12
+
13
+ Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
14
+ The difference is that there is less copy-pasting from pycocotools
15
+ in the end of the file, as python3 can suppress prints with contextlib
16
+ """
17
+ import os
18
+ import contextlib
19
+ import copy
20
+ import numpy as np
21
+ import torch
22
+
23
+ from pycocotools.cocoeval import COCOeval
24
+ from pycocotools.coco import COCO
25
+ import pycocotools.mask as mask_util
26
+
27
+ from util.misc import all_gather
28
+
29
+
30
class CocoEvaluator(object):
    """Accumulate per-image COCO evaluation results, one ``COCOeval`` per IoU type.

    Typical flow: ``update`` for each batch of predictions, then
    ``synchronize_between_processes``, ``accumulate`` and ``summarize``.
    """

    def __init__(self, coco_gt, iou_types):
        assert isinstance(iou_types, (list, tuple))
        # Work on a private copy of the ground truth.
        coco_gt = copy.deepcopy(coco_gt)
        self.coco_gt = coco_gt

        self.iou_types = iou_types
        self.coco_eval = {
            iou_type: COCOeval(coco_gt, iouType=iou_type) for iou_type in iou_types
        }

        self.img_ids = []
        self.eval_imgs = {iou_type: [] for iou_type in iou_types}

    def update(self, predictions):
        """Evaluate ``predictions`` (image id -> prediction dict) and store the results."""
        img_ids = list(np.unique(list(predictions.keys())))
        self.img_ids.extend(img_ids)

        for iou_type in self.iou_types:
            results = self.prepare(predictions, iou_type)

            # suppress pycocotools prints
            with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
                coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()

            evaluator = self.coco_eval[iou_type]
            evaluator.cocoDt = coco_dt
            evaluator.params.imgIds = list(img_ids)
            # evaluate() hands back the de-duplicated ids it was given, which
            # equal img_ids, so only the evaluation array is kept.
            _, eval_imgs = evaluate(evaluator)

            self.eval_imgs[iou_type].append(eval_imgs)

    def synchronize_between_processes(self):
        """Concatenate stored results along axis 2 and install them on each ``COCOeval``."""
        for iou_type in self.iou_types:
            stacked = np.concatenate(self.eval_imgs[iou_type], 2)
            self.eval_imgs[iou_type] = stacked
            create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, stacked)

    def accumulate(self):
        for evaluator in self.coco_eval.values():
            evaluator.accumulate()

    def summarize(self):
        for iou_type, evaluator in self.coco_eval.items():
            print("IoU metric: {}".format(iou_type))
            evaluator.summarize()

    def prepare(self, predictions, iou_type):
        """Convert ``predictions`` to COCO result dicts for ``iou_type``.

        Raises:
            ValueError: for an IoU type other than bbox, segm or keypoints.
        """
        if iou_type == "bbox":
            return self.prepare_for_coco_detection(predictions)
        if iou_type == "segm":
            return self.prepare_for_coco_segmentation(predictions)
        if iou_type == "keypoints":
            return self.prepare_for_coco_keypoint(predictions)
        raise ValueError("Unknown iou type {}".format(iou_type))

    def prepare_for_coco_detection(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = convert_to_xywh(prediction["boxes"]).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            for k, box in enumerate(boxes):
                coco_results.append(
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "bbox": box,
                        "score": scores[k],
                    }
                )
        return coco_results

    def prepare_for_coco_segmentation(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            # Binarise the predicted masks at 0.5.
            masks = prediction["masks"] > 0.5

            rles = []
            for mask in masks:
                plane = np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F")
                rle = mask_util.encode(plane)[0]
                rle["counts"] = rle["counts"].decode("utf-8")
                rles.append(rle)

            for k, rle in enumerate(rles):
                coco_results.append(
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "segmentation": rle,
                        "score": scores[k],
                    }
                )
        return coco_results

    def prepare_for_coco_keypoint(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            # Boxes are converted as well, although they do not appear in the
            # emitted result dicts.
            boxes = convert_to_xywh(prediction["boxes"]).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            keypoints = prediction["keypoints"].flatten(start_dim=1).tolist()

            for k, keypoint in enumerate(keypoints):
                coco_results.append(
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        'keypoints': keypoint,
                        "score": scores[k],
                    }
                )
        return coco_results
171
+
172
+
173
def convert_to_xywh(boxes):
    """Convert ``(N, 4)`` boxes from ``[xmin, ymin, xmax, ymax]`` to ``[x, y, w, h]``."""
    left, top, right, bottom = boxes.unbind(1)
    widths = right - left
    heights = bottom - top
    return torch.stack((left, top, widths, heights), dim=1)
176
+
177
+
178
def merge(img_ids, eval_imgs):
    """Gather image ids and evaluation arrays via ``all_gather`` and de-duplicate them.

    The gathered id lists are flattened, the gathered arrays are concatenated
    along axis 2, and only the first occurrence of each image id is kept.

    Returns:
        ``(unique_ids, eval_imgs)`` where ``unique_ids`` is sorted and the last
        axis of ``eval_imgs`` is indexed in the same order.
    """
    gathered_ids = all_gather(img_ids)
    gathered_evals = all_gather(eval_imgs)

    flat_ids = np.array([img_id for part in gathered_ids for img_id in part])
    stacked = np.concatenate(list(gathered_evals), 2)

    # keep only unique (and in sorted order) images
    unique_ids, first_index = np.unique(flat_ids, return_index=True)
    return unique_ids, stacked[..., first_index]
198
+
199
+
200
def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    """Install the merged evaluation results on ``coco_eval``.

    The ids and arrays are merged with ``merge``. The flattened array then
    becomes ``coco_eval.evalImgs``, the merged ids become
    ``coco_eval.params.imgIds``, and ``_paramsEval`` is refreshed with a deep
    copy of the parameters.
    """
    merged_ids, merged_evals = merge(img_ids, eval_imgs)

    coco_eval.evalImgs = list(merged_evals.flatten())
    coco_eval.params.imgIds = list(merged_ids)
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
208
+
209
+
210
+ #################################################################
211
+ # From pycocotools, just removed the prints and fixed
212
+ # a Python3 bug about unicode not defined
213
+ #################################################################
214
+
215
+
216
def evaluate(self):
    '''
    Run per-image evaluation on ``self`` (a COCOeval-like object).

    Image ids, category ids (when ``useCats`` is set) and ``maxDets`` are
    normalised on ``self.params``, IoUs are stored in ``self.ious``, and
    ``self._paramsEval`` receives a deep copy of the parameters.

    :return: ``(imgIds, evalImgs)`` where ``evalImgs`` is an array reshaped to
             ``(len(catIds), len(areaRng), len(imgIds))``
    '''
    params = self.params
    # add backward compatibility if useSegm is specified in params
    if params.useSegm is not None:
        params.iouType = 'segm' if params.useSegm == 1 else 'bbox'
        print('useSegm (deprecated) is not None. Running {} evaluation'.format(params.iouType))
    params.imgIds = list(np.unique(params.imgIds))
    if params.useCats:
        params.catIds = list(np.unique(params.catIds))
    params.maxDets = sorted(params.maxDets)
    self.params = params

    self._prepare()
    # loop through images, area range, max detection number
    cat_ids = params.catIds if params.useCats else [-1]

    if params.iouType in ('segm', 'bbox'):
        iou_fn = self.computeIoU
    elif params.iouType == 'keypoints':
        iou_fn = self.computeOks

    ious = {}
    for img_id in params.imgIds:
        for cat_id in cat_ids:
            ious[(img_id, cat_id)] = iou_fn(img_id, cat_id)
    self.ious = ious

    evaluate_img = self.evaluateImg
    max_det = params.maxDets[-1]
    per_image = []
    for cat_id in cat_ids:
        for area_rng in params.areaRng:
            for img_id in params.imgIds:
                per_image.append(evaluate_img(img_id, cat_id, area_rng, max_det))

    # this is NOT in the pycocotools code, but could be done outside
    per_image = np.asarray(per_image).reshape(
        len(cat_ids), len(params.areaRng), len(params.imgIds)
    )
    self._paramsEval = copy.deepcopy(self.params)
    return params.imgIds, per_image
262
+
263
+ #################################################################
264
+ # end of straight copy from pycocotools, just removing the prints
265
+ #################################################################
perception_models/apps/detection/DETA_pe/datasets/coco_panoptic.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ import json
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import torch
15
+ from PIL import Image
16
+
17
+ from panopticapi.utils import rgb2id
18
+ from util.box_ops import masks_to_boxes
19
+
20
+ from .coco import make_coco_transforms
21
+
22
+
23
class CocoPanoptic:
    """COCO panoptic dataset backed by a panoptic JSON file and PNG label maps.

    Args:
        img_folder: directory holding the ``.jpg`` images.
        ann_folder: directory holding the panoptic ``.png`` annotation maps.
        ann_file: path of the panoptic JSON file.
        transforms: optional callable applied as ``transforms(img, target)``.
        return_masks: whether ``target`` carries the per-segment ``masks``.
    """

    def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks=True):
        with open(ann_file, 'r') as f:
            self.coco = json.load(f)

        # Sort 'images' by id so that they line up with 'annotations'.
        self.coco['images'] = sorted(self.coco['images'], key=lambda x: x['id'])
        # sanity check: file names (extension stripped) must match pairwise
        if "annotations" in self.coco:
            for img, ann in zip(self.coco['images'], self.coco['annotations']):
                assert img['file_name'][:-4] == ann['file_name'][:-4]

        self.img_folder = img_folder
        self.ann_folder = ann_folder
        self.ann_file = ann_file
        self.transforms = transforms
        self.return_masks = return_masks

    def __getitem__(self, idx):
        """Return ``(img, target)`` for entry ``idx``.

        When the entry has no ``segments_info`` (e.g. a split without
        annotations) the target only carries ``image_id``, ``size`` and
        ``orig_size``. Previously this path failed with a ``NameError``
        because ``masks`` and ``labels`` were never defined.
        """
        ann_info = self.coco['annotations'][idx] if "annotations" in self.coco else self.coco['images'][idx]
        img_path = Path(self.img_folder) / ann_info['file_name'].replace('.png', '.jpg')
        ann_path = Path(self.ann_folder) / ann_info['file_name']

        img = Image.open(img_path).convert('RGB')
        w, h = img.size
        has_segments = "segments_info" in ann_info

        target = {}
        target['image_id'] = torch.tensor([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]])

        if has_segments:
            segments = ann_info['segments_info']
            masks = np.asarray(Image.open(ann_path), dtype=np.uint32)
            masks = rgb2id(masks)

            # One boolean plane per segment id.
            ids = np.array([ann['id'] for ann in segments])
            masks = masks == ids[:, None, None]

            masks = torch.as_tensor(masks, dtype=torch.uint8)
            labels = torch.tensor([ann['category_id'] for ann in segments], dtype=torch.int64)

            if self.return_masks:
                target['masks'] = masks
            target['labels'] = labels
            target["boxes"] = masks_to_boxes(masks)

        target['size'] = torch.as_tensor([int(h), int(w)])
        target['orig_size'] = torch.as_tensor([int(h), int(w)])
        if has_segments:
            for name in ['iscrowd', 'area']:
                target[name] = torch.tensor([ann[name] for ann in ann_info['segments_info']])

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.coco['images'])

    def get_height_and_width(self, idx):
        """Return ``(height, width)`` of image ``idx`` as recorded in the JSON file."""
        img_info = self.coco['images'][idx]
        height = img_info['height']
        width = img_info['width']
        return height, width
86
+
87
+
88
def build(image_set, args):
    """Create the ``CocoPanoptic`` dataset for ``image_set``.

    Args:
        image_set: ``"train"`` or ``"val"``; any other key raises ``KeyError``.
        args: namespace providing ``coco_path``, ``coco_panoptic_path`` and
            ``masks``; ``bigger`` is honoured when present.
    """
    img_folder_root = Path(args.coco_path)
    ann_folder_root = Path(args.coco_panoptic_path)
    assert img_folder_root.exists(), f'provided COCO path {img_folder_root} does not exist'
    assert ann_folder_root.exists(), f'provided COCO path {ann_folder_root} does not exist'
    mode = 'panoptic'
    PATHS = {
        "train": ("train2017", Path("annotations") / f'{mode}_train2017.json'),
        "val": ("val2017", Path("annotations") / f'{mode}_val2017.json'),
    }

    img_folder, ann_file = PATHS[image_set]
    img_folder_path = img_folder_root / img_folder
    ann_folder = ann_folder_root / f'{mode}_{img_folder}'
    ann_file = ann_folder_root / ann_file

    # make_coco_transforms takes (image_set, bigger); the previous one-argument
    # call raised TypeError. Fall back to False when args has no ``bigger``.
    transforms = make_coco_transforms(image_set, getattr(args, "bigger", False))

    dataset = CocoPanoptic(img_folder_path, ann_folder, ann_file,
                           transforms=transforms, return_masks=args.masks)

    return dataset
perception_models/apps/detection/DETA_pe/datasets/data_prefetcher.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+
7
+ import torch
8
+
9
def to_cuda(samples, targets, device):
    """Move ``samples`` and every value of every target dict to ``device``.

    Both transfers use ``non_blocking=True``. ``targets`` is rebuilt as a new
    list of new dicts with the same keys.
    """
    moved_targets = []
    for target in targets:
        moved_targets.append(
            {key: value.to(device, non_blocking=True) for key, value in target.items()}
        )
    return samples.to(device, non_blocking=True), moved_targets
13
+
14
class data_prefetcher():
    """Iterate a data loader, optionally staging the next batch on a side CUDA stream.

    With ``prefetch=True`` the upcoming batch is moved to ``device`` inside a
    dedicated ``torch.cuda.Stream`` while the caller works on the current one.
    With ``prefetch=False`` batches are fetched and moved on demand.
    ``next()`` returns ``(None, None)`` once the loader is exhausted.
    """

    def __init__(self, loader, device, prefetch=True):
        self.loader = iter(loader)
        self.prefetch = prefetch
        self.device = device
        if prefetch:
            self.stream = torch.cuda.Stream()
            self.preload()

    def preload(self):
        """Fetch the next batch and start its transfer on the side stream."""
        try:
            batch = next(self.loader)
        except StopIteration:
            self.next_samples = None
            self.next_targets = None
            return
        self.next_samples, self.next_targets = batch
        with torch.cuda.stream(self.stream):
            self.next_samples, self.next_targets = to_cuda(
                self.next_samples, self.next_targets, self.device
            )

    def next(self):
        """Return the current ``(samples, targets)`` pair, or ``(None, None)`` at the end."""
        if not self.prefetch:
            try:
                samples, targets = next(self.loader)
                samples, targets = to_cuda(samples, targets, self.device)
            except StopIteration:
                samples = None
                targets = None
            return samples, targets

        current = torch.cuda.current_stream()
        # Make the consumer stream wait for the staged copy to finish.
        current.wait_stream(self.stream)
        samples = self.next_samples
        targets = self.next_targets
        # The tensors were produced on the side stream; record their use on
        # the current stream before handing them out.
        if samples is not None:
            samples.record_stream(current)
        if targets is not None:
            for entry in targets:
                for tensor in entry.values():
                    tensor.record_stream(current)
        self.preload()
        return samples, targets
perception_models/apps/detection/DETA_pe/datasets/objects365.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ COCO dataset which returns image_id for evaluation.
12
+
13
+ Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
14
+ """
15
+ from pathlib import Path
16
+
17
+ import datasets.transforms as T
18
+
19
+ import torch
20
+ import torch.utils.data
21
+ from pycocotools import mask as coco_mask
22
+ from util.misc import get_local_rank, get_local_size
23
+
24
+ from .coco import CocoDetection, make_coco_transforms, make_coco_transforms_lsj
25
+ from .torchvision_datasets import CocoDetection as TvCocoDetection
26
+
27
+
28
def build(image_set, args):
    """Build the Objects365 detection dataset for one split.

    Args:
        image_set: ``"train"`` or ``"val"``; any other key raises ``KeyError``.
        args: namespace providing ``coco_path`` (dataset root), ``lsj``,
            ``lsj_img_size``, ``bigger``, ``masks`` and ``cache_mode``.

    Returns:
        A ``CocoDetection`` dataset over the split's image folder and
        annotation file.
    """
    root = Path(args.coco_path)
    assert root.exists(), f"provided Objects365 path {root} does not exist"

    annotation_dir = root / "annotations"
    split_paths = {
        "train": (root / "train", annotation_dir / "zhiyuan_objv2_train_fixmiss.json"),
        "val": (root / "val", annotation_dir / "zhiyuan_objv2_val.json"),
    }
    img_folder, ann_file = split_paths[image_set]

    # The augmentation pipeline is chosen by the ``lsj`` flag.
    coco_transform = (
        make_coco_transforms_lsj(image_set, args.lsj_img_size)
        if args.lsj
        else make_coco_transforms(image_set, args.bigger)
    )

    return CocoDetection(
        img_folder,
        ann_file,
        transforms=coco_transform,
        return_masks=args.masks,
        cache_mode=args.cache_mode,
        local_rank=get_local_rank(),
        local_size=get_local_size(),
    )
perception_models/apps/detection/DETA_pe/datasets/panoptic_eval.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ import json
11
+ import os
12
+
13
+ import util.misc as utils
14
+
15
+ try:
16
+ from panopticapi.evaluation import pq_compute
17
+ except ImportError:
18
+ pass
19
+
20
+
21
class PanopticEvaluator(object):
    """Collect panoptic predictions on disk and score them with ``pq_compute``.

    Predicted PNG payloads are written into ``output_dir`` as they arrive;
    the remaining per-image metadata is kept in ``self.predictions`` and
    dumped to ``predictions.json`` when :meth:`summarize` runs.
    """

    def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"):
        """
        Args:
            ann_file: path to the ground-truth panoptic json.
            ann_folder: folder holding the ground-truth PNG annotations.
            output_dir: folder that receives predicted PNGs and the
                predictions json.
        """
        self.gt_json = ann_file
        self.gt_folder = ann_folder
        if utils.is_main_process():
            # makedirs(exist_ok=True) avoids the check-then-create race of
            # exists()+mkdir() and also works for nested output paths.
            os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir
        self.predictions = []

    def update(self, predictions):
        """Write each prediction's PNG to disk and record its metadata.

        The ``"png_string"`` entry is popped from every prediction dict
        (the caller's dicts are modified in place) and written to
        ``output_dir / file_name``.
        """
        for p in predictions:
            with open(os.path.join(self.output_dir, p["file_name"]), "wb") as f:
                f.write(p.pop("png_string"))

        self.predictions += predictions

    def synchronize_between_processes(self):
        """Replace ``self.predictions`` with the concatenation of what ``utils.all_gather`` returns."""
        all_predictions = utils.all_gather(self.predictions)
        merged_predictions = []
        for p in all_predictions:
            merged_predictions += p
        self.predictions = merged_predictions

    def summarize(self):
        """Dump predictions to json and compute PQ on the main process.

        Returns:
            The result of ``pq_compute`` on the main process, ``None`` on
            every other process.
        """
        if utils.is_main_process():
            json_data = {"annotations": self.predictions}
            predictions_json = os.path.join(self.output_dir, "predictions.json")
            with open(predictions_json, "w") as f:
                f.write(json.dumps(json_data))
            return pq_compute(self.gt_json, predictions_json, gt_folder=self.gt_folder, pred_folder=self.output_dir)
        return None
perception_models/apps/detection/DETA_pe/datasets/samplers.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from codes in torch.utils.data.distributed
7
+ # ------------------------------------------------------------------------
8
+
9
+ import json
10
+ import math
11
+ import os
12
+ from collections import defaultdict
13
+
14
+ import torch
15
+ import torch.distributed as dist
16
+
17
+ from fvcore.common.timer import Timer
18
+ from lvis import LVIS
19
+ from torch.utils.data.sampler import Sampler
20
+
21
+
22
def load_dataset_dicts(json_file):
    """Load an LVIS annotation file into a list of per-image records.

    Args:
        json_file: path to the LVIS-format annotation json.

    Returns:
        list[dict]: one record per image (sorted by image id) of the form
        ``{"image_id": ..., "annotations": [{"category_id": ...}, ...]}``,
        where each category id is the annotation's id minus one.
    """
    timer = Timer()
    lvis_api = LVIS(json_file)
    if timer.seconds() > 1:
        print("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))

    img_ids = sorted(lvis_api.imgs.keys())
    imgs = lvis_api.load_imgs(img_ids)
    anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]

    imgs_anns = list(zip(imgs, anns))
    print(
        "Loaded {} images in the LVIS format from {}".format(len(imgs_anns), json_file)
    )

    dataset_dicts = []
    for img_dict, anno_dict_list in imgs_anns:
        image_id = img_dict["id"]
        objs = []
        for anno in anno_dict_list:
            # A mismatch here means the parsing logic or the annotation
            # file itself is broken.
            assert anno["image_id"] == image_id
            # Convert 1-indexed to 0-indexed
            objs.append({"category_id": anno["category_id"] - 1})
        dataset_dicts.append({"image_id": image_id, "annotations": objs})

    return dataset_dicts
56
+
57
+
58
def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh, sqrt=True):
    """Compute one repeat factor per image from category frequencies.

    Args:
        dataset_dicts: list of records, each with an ``"annotations"`` list
            whose items carry a ``"category_id"``.
        repeat_thresh: frequency threshold ``t``; categories whose image
            frequency is below it get a repeat factor above 1.
        sqrt: use ``sqrt(t / f(c))`` when true, the plain ratio ``t / f(c)``
            otherwise.

    Returns:
        torch.Tensor: float32 tensor with one repeat factor per record.
    """
    # Distinct categories present in each image.
    cats_per_image = [
        {ann["category_id"] for ann in record["annotations"]}
        for record in dataset_dicts
    ]

    # 1. f(c): fraction of images that contain category c.
    image_counts = defaultdict(int)
    for cats in cats_per_image:
        for cat_id in cats:
            image_counts[cat_id] += 1
    num_images = len(dataset_dicts)
    category_freq = {
        cat_id: count / num_images for cat_id, count in image_counts.items()
    }

    # 2. r(c) = max(1, sqrt(t / f(c)))  (or max(1, t / f(c)) without sqrt).
    category_rep = {}
    for cat_id, cat_freq in category_freq.items():
        ratio = repeat_thresh / cat_freq
        category_rep[cat_id] = max(1.0, math.sqrt(ratio) if sqrt else ratio)

    for cat_id in sorted(category_rep):
        print(
            f"Cat ID {cat_id}: freq={category_freq[cat_id]:.2f}, rep={category_rep[cat_id]:.2f}"
        )

    # 3. r(I) = max over categories in I of r(c); 1.0 for images without annotations.
    rep_factors = [
        max((category_rep[cat_id] for cat_id in cats), default=1.0)
        for cats in cats_per_image
    ]
    return torch.tensor(rep_factors, dtype=torch.float32)
96
+
97
+
98
class RepeatFactorTrainingSampler(Sampler):
    """Distributed sampler that repeats images according to category rarity.

    Per-image repeat factors are computed once in ``__init__`` from an LVIS
    annotation file. Each call to ``__iter__`` stochastically rounds them,
    builds an index list with repeats, pads it so it divides evenly across
    replicas, and returns this rank's contiguous slice.

    Note:
        ``__iter__`` may update ``num_samples`` / ``total_size`` to the size
        of the repeated index list, so ``len(sampler)`` can change after the
        first iteration.
    """

    def __init__(
        self,
        dataset,
        num_replicas=None,
        rank=None,
        local_rank=None,
        local_size=None,
        shuffle=True,
        json_file="/checkpoint/onevision/peizesun/public_data/d2_data/lvis/lvis_v1_train.json",
        repeat_thresh=0.001,
    ):
        """
        Args:
            dataset: dataset being sampled; only its length is used here.
            num_replicas: number of participating processes; defaults to
                ``dist.get_world_size()``.
            rank: rank of this process; defaults to ``dist.get_rank()``.
            local_rank: accepted for signature parity with the other
                samplers in this file; unused.
            local_size: accepted for signature parity; unused.
            shuffle: shuffle the repeated indices with an epoch-based seed.
            json_file: LVIS annotation file used to derive repeat factors.
                The default is the previously hard-coded path.
            repeat_thresh: frequency threshold passed to
                ``repeat_factors_from_category_frequency``.

        Raises:
            RuntimeError: if a default is needed but ``torch.distributed``
                is unavailable.
        """
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

        dataset_dicts = load_dataset_dicts(json_file)
        repeat_factors = repeat_factors_from_category_frequency(
            dataset_dicts, repeat_thresh=repeat_thresh
        )
        # Split into whole number (_int_part) and fractional (_frac_part) parts.
        self._int_part = torch.trunc(repeat_factors)
        self._frac_part = repeat_factors - self._int_part

    @staticmethod
    def _pad_to(indices, total_size):
        """Return ``indices`` extended by cycling itself until it has ``total_size`` items."""
        padding = total_size - len(indices)
        if padding > 0 and indices:
            # A plain ``indices[:padding]`` under-fills when padding exceeds
            # len(indices); tile the list first so the slice is always full.
            repeats = math.ceil(padding / len(indices))
            indices = indices + (indices * repeats)[:padding]
        return indices

    def _get_epoch_indices(self, generator):
        """
        Create a list of dataset indices (with repeats) to use for one epoch.

        Args:
            generator (torch.Generator): pseudo random number generator used for
                stochastic rounding.

        Returns:
            torch.Tensor: list of dataset indices to use in one epoch. Each index
                is repeated based on its calculated repeat factor.
        """
        # Since repeat factors are fractional, we use stochastic rounding so
        # that the target repeat factor is achieved in expectation over the
        # course of training
        rands = torch.rand(len(self._frac_part), generator=generator)
        rep_factors = self._int_part + (rands < self._frac_part).float()
        # Construct a list of indices in which we repeat images as specified
        indices = []
        for dataset_index, rep_factor in enumerate(rep_factors):
            indices.extend([dataset_index] * int(rep_factor.item()))
        return torch.tensor(indices, dtype=torch.int64)

    def __iter__(self):
        """Return an iterator over this rank's share of the repeated indices."""
        g = torch.Generator()
        # Shuffled runs are seeded by epoch; unshuffled runs always use seed 0.
        g.manual_seed(self.epoch if self.shuffle else 0)
        # Sample indices with repeats determined by stochastic rounding; each
        # "epoch" may have a slightly different size due to the rounding.
        rfs_indices = self._get_epoch_indices(g)
        if self.shuffle:
            # deterministically shuffle based on epoch
            randperm = torch.randperm(len(rfs_indices), generator=g)
            indices = rfs_indices[randperm].tolist()
        else:
            indices = rfs_indices.tolist()

        if self.total_size <= len(indices):
            # The repeated list is at least as large as the current budget:
            # grow the per-rank sample count to cover it.
            self.num_samples = int(math.ceil(len(indices) * 1.0 / self.num_replicas))
            self.total_size = self.num_samples * self.num_replicas

        # add extra samples to make it evenly divisible
        indices = self._pad_to(indices, self.total_size)
        assert len(indices) == self.total_size

        # subsample
        offset = self.num_samples * self.rank
        indices = indices[offset : offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        """Number of samples this rank draws per epoch (see the class note)."""
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed shuffling and stochastic rounding."""
        self.epoch = epoch
203
+
204
+
205
class DistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Each process builds the same (optionally epoch-seeded, shuffled) index
    list, pads it so it divides evenly across ``num_replicas``, and keeps the
    contiguous slice belonging to its ``rank``.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        local_rank (optional): accepted for signature parity; unused.
        local_size (optional): accepted for signature parity; unused.
        shuffle (optional): shuffle indices with an epoch-based seed.
    """

    def __init__(
        self,
        dataset,
        num_replicas=None,
        rank=None,
        local_rank=None,
        local_size=None,
        shuffle=True,
    ):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        """Return an iterator over this rank's ``num_samples`` indices."""
        if self.shuffle:
            # deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        padding = self.total_size - len(indices)
        if padding > 0 and indices:
            # ``indices[:padding]`` alone under-fills when padding exceeds
            # len(indices) (tiny dataset, many replicas); tile first.
            repeats = math.ceil(padding / len(indices))
            indices += (indices * repeats)[:padding]
        assert len(indices) == self.total_size

        # subsample
        offset = self.num_samples * self.rank
        indices = indices[offset : offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        """Number of samples drawn by this rank per epoch."""
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed shuffling."""
        self.epoch = epoch
270
+
271
+
272
class NodeDistributedSampler(Sampler):
    """Distributed sampler that first partitions indices by local rank.

    The (optionally shuffled) index list is filtered to the indices whose
    value modulo ``local_size`` equals ``local_rank``, padded to
    ``total_size_parts``, and then strided so that each global rank receives
    ``num_samples`` indices.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        local_rank (optional): rank within the node; defaults to the
            ``LOCAL_RANK`` environment variable (0 if unset).
        local_size (optional): processes per node; defaults to the
            ``LOCAL_SIZE`` environment variable (1 if unset).
        shuffle (optional): shuffle indices with an epoch-based seed.
    """

    def __init__(
        self,
        dataset,
        num_replicas=None,
        rank=None,
        local_rank=None,
        local_size=None,
        shuffle=True,
    ):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if local_rank is None:
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if local_size is None:
            local_size = int(os.environ.get("LOCAL_SIZE", 1))
        self.dataset = dataset
        self.shuffle = shuffle
        self.num_replicas = num_replicas
        self.num_parts = local_size
        self.rank = rank
        self.local_rank = local_rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

        # Size of each local-rank partition after padding.
        self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts

    def __iter__(self):
        """Return an iterator over this rank's ``num_samples`` indices."""
        if self.shuffle:
            # deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()
        # Keep only the indices assigned to this local rank.
        indices = [i for i in indices if i % self.num_parts == self.local_rank]

        # add extra samples to make it evenly divisible
        padding = self.total_size_parts - len(indices)
        if padding > 0 and indices:
            # ``indices[:padding]`` alone under-fills when padding exceeds
            # len(indices); tile the list so the slice is always full.
            repeats = math.ceil(padding / len(indices))
            indices += (indices * repeats)[:padding]
        assert len(indices) == self.total_size_parts

        # subsample: start at this rank's node index, stride by node count
        start = self.rank // self.num_parts
        step = self.num_replicas // self.num_parts
        indices = indices[start : self.total_size_parts : step]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        """Number of samples drawn by this rank per epoch."""
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed shuffling."""
        self.epoch = epoch
perception_models/apps/detection/DETA_pe/datasets/torchvision_datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+
7
+ from .coco import CocoDetection
perception_models/apps/detection/DETA_pe/datasets/torchvision_datasets/coco.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from torchvision
7
+ # ------------------------------------------------------------------------
8
+
9
+ """
10
+ Copy-Paste from torchvision, but add utility of caching images on memory
11
+ """
12
+ from torchvision.datasets.vision import VisionDataset
13
+ from PIL import Image
14
+ import os
15
+ import os.path
16
+ import tqdm
17
+ from io import BytesIO
18
+
19
+
20
class CocoDetection(VisionDataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset
    with optional in-memory caching of the raw image bytes.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
        cache_mode (bool, optional): when true, image bytes are read into memory up front.
        local_rank (int, optional): only images whose position modulo ``local_size``
            equals this value are pre-cached.
        local_size (int, optional): modulus used together with ``local_rank``.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None,
                 cache_mode=False, local_rank=0, local_size=1):
        super().__init__(root, transforms, transform, target_transform)
        # Imported lazily so the module can be loaded without pycocotools.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        self.ids = list(sorted(self.coco.imgs.keys()))
        self.cache_mode = cache_mode
        self.local_rank = local_rank
        self.local_size = local_size
        if cache_mode:
            # cache_images() creates self.cache itself.
            self.cache_images()

    def cache_images(self):
        """Read this local rank's share of the images into ``self.cache`` as raw bytes."""
        self.cache = {}
        for position, img_id in zip(tqdm.trange(len(self.ids)), self.ids):
            if position % self.local_size != self.local_rank:
                continue
            file_name = self.coco.loadImgs(img_id)[0]['file_name']
            with open(os.path.join(self.root, file_name), 'rb') as f:
                self.cache[file_name] = f.read()

    def get_image(self, path):
        """Return the RGB image at ``path`` (relative to ``root``), served from the cache when enabled."""
        full_path = os.path.join(self.root, path)
        if not self.cache_mode:
            return Image.open(full_path).convert('RGB')
        if path not in self.cache:
            # Not pre-cached: read it now and keep the bytes.
            with open(full_path, 'rb') as f:
                self.cache[path] = f.read()
        return Image.open(BytesIO(self.cache[path])).convert('RGB')

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
        """
        img_id = self.ids[index]
        target = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
        file_name = self.coco.loadImgs(img_id)[0]['file_name']

        img = self.get_image(file_name)
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        """Number of images listed in the annotation file."""
        return len(self.ids)
perception_models/apps/detection/DETA_pe/datasets/transforms.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ Transforms and data augmentation for both image + bbox.
12
+ """
13
+ import random
14
+
15
+ import PIL
16
+ import torch
17
+ import torchvision.transforms as T
18
+ import torchvision.transforms.functional as F
19
+
20
+ from util.box_ops import box_xyxy_to_cxcywh
21
+ from util.misc import interpolate
22
+
23
+
24
def crop(image, target, region):
    """Crop ``image`` to ``region`` and update ``target`` to match.

    Args:
        image: image accepted by ``F.crop``.
        target: dict of annotations; a shallow copy is modified and returned.
        region: ``(top, left, height, width)`` as passed to ``F.crop``.

    Returns:
        tuple: ``(cropped_image, target)``. Boxes are shifted and clipped to
        the crop, areas recomputed from the clipped boxes, masks sliced, and
        instances left with an empty box (or empty mask when there are no
        boxes) are removed.
    """
    cropped_image = F.crop(image, *region)

    target = target.copy()
    top, left, height, width = region

    # should we do something wrt the original size?
    target["size"] = torch.tensor([height, width])

    # Per-instance fields that must be filtered together.
    fields = ["labels", "area", "iscrowd"]

    if "boxes" in target:
        limit = torch.as_tensor([width, height], dtype=torch.float32)
        shifted = target["boxes"] - torch.as_tensor([left, top, left, top])
        corners = torch.min(shifted.reshape(-1, 2, 2), limit).clamp(min=0)
        target["area"] = (corners[:, 1, :] - corners[:, 0, :]).prod(dim=1)
        target["boxes"] = corners.reshape(-1, 4)
        fields.append("boxes")

    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target["masks"] = target["masks"][:, top : top + height, left : left + width]
        fields.append("masks")

    # remove elements for which the boxes or masks that have zero area;
    # boxes take precedence over masks when deciding what to keep
    if "boxes" in target:
        corners = target["boxes"].reshape(-1, 2, 2)
        keep = torch.all(corners[:, 1, :] > corners[:, 0, :], dim=1)
    elif "masks" in target:
        keep = target["masks"].flatten(1).any(1)
    else:
        return cropped_image, target

    for field in fields:
        target[field] = target[field][keep]

    return cropped_image, target
65
+
66
+
67
def hflip(image, target):
    """Flip ``image`` horizontally and mirror the annotations in ``target``.

    Args:
        image: image with a ``.size`` of ``(width, height)``.
        target: dict of annotations; a shallow copy is modified and returned.

    Returns:
        tuple: ``(flipped_image, target)``.
    """
    flipped_image = F.hflip(image)
    width, _ = image.size

    target = target.copy()
    if "boxes" in target:
        # Swap x0/x1, negate them, then shift by the image width.
        swapped = target["boxes"][:, [2, 1, 0, 3]]
        sign = torch.as_tensor([-1, 1, -1, 1])
        shift = torch.as_tensor([width, 0, width, 0])
        target["boxes"] = swapped * sign + shift

    if "masks" in target:
        target["masks"] = target["masks"].flip(-1)

    return flipped_image, target
84
+
85
+
86
def resize(image, target, size, max_size=None):
    """Resize ``image`` and rescale the annotations in ``target``.

    Args:
        image: image with a ``.size`` of ``(width, height)``.
        target: dict of annotations or ``None``; a shallow copy is modified.
        size: either a scalar target for the shorter side, or a ``(w, h)``
            list/tuple used as the exact output size.
        max_size: optional cap on the longer side when ``size`` is a scalar.

    Returns:
        tuple: ``(rescaled_image, target)``; ``target`` is ``None`` when the
        input target is ``None``.
    """

    def _aspect_preserving_size(image_size, short_target, longest=None):
        # Output (h, w) with the shorter side equal to short_target, shrunk
        # if the longer side would exceed ``longest``.
        w, h = image_size
        if longest is not None:
            short_side = float(min((w, h)))
            long_side = float(max((w, h)))
            if long_side / short_side * short_target > longest:
                short_target = int(round(longest * short_side / long_side))

        if (w <= h and w == short_target) or (h <= w and h == short_target):
            return (h, w)
        if w < h:
            return (int(short_target * h / w), short_target)
        return (short_target, int(short_target * w / h))

    def _output_size(image_size, requested, longest=None):
        if isinstance(requested, (list, tuple)):
            # (w, h) -> (h, w)
            return requested[::-1]
        return _aspect_preserving_size(image_size, requested, longest)

    size = _output_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size)

    if target is None:
        return rescaled_image, None

    ratio_width, ratio_height = tuple(
        float(new) / float(old) for new, old in zip(rescaled_image.size, image.size)
    )

    target = target.copy()
    if "boxes" in target:
        scale = torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
        target["boxes"] = target["boxes"] * scale

    if "area" in target:
        target["area"] = target["area"] * (ratio_width * ratio_height)

    h, w = size
    target["size"] = torch.tensor([h, w])

    if "masks" in target:
        resized_masks = interpolate(
            target["masks"][:, None].float(), size, mode="nearest"
        )[:, 0]
        target["masks"] = resized_masks > 0.5

    return rescaled_image, target
147
+
148
+
149
def pad(image, target, padding):
    """Pad ``image`` on the right and bottom and update ``target`` to match.

    Args:
        image: image accepted by ``F.pad``; assumed to expose ``.size`` as
            ``(width, height)`` like the other transforms in this file
            (i.e. a PIL image) -- TODO confirm for tensor inputs.
        target: dict of annotations or ``None``; a shallow copy is modified.
        padding: ``(pad_x, pad_y)`` pixels added on the right and bottom.

    Returns:
        tuple: ``(padded_image, target)``; ``target`` is ``None`` when the
        input target is ``None``.
    """
    # assumes that we only pad on the bottom right corners
    padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
    if target is None:
        return padded_image, None
    target = target.copy()
    # should we do something wrt the original size?
    # ``.size`` is (w, h); reversed it matches the [h, w] layout of
    # target["size"] used by crop/resize. The image object itself is not
    # subscriptable, so ``padded_image[::-1]`` could never work.
    target["size"] = torch.tensor(padded_image.size[::-1])
    if "masks" in target:
        target["masks"] = torch.nn.functional.pad(
            target["masks"], (0, padding[0], 0, padding[1])
        )
    return padded_image, target
162
+
163
+
164
class RandomCrop:
    """Crop a randomly placed window of fixed ``size`` from image and target."""

    def __init__(self, size):
        # Output size as expected by ``T.RandomCrop.get_params``.
        self.size = size

    def __call__(self, img, target):
        return crop(img, target, T.RandomCrop.get_params(img, self.size))
171
+
172
+
173
class RandomSizeCrop:
    """Crop a randomly placed window whose sides are drawn from ``[min_size, max_size]``."""

    def __init__(self, min_size: int, max_size: int):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, img: PIL.Image.Image, target: dict):
        # Width is drawn before height; each upper bound is capped by the image side.
        crop_w = random.randint(self.min_size, min(img.width, self.max_size))
        crop_h = random.randint(self.min_size, min(img.height, self.max_size))
        return crop(img, target, T.RandomCrop.get_params(img, [crop_h, crop_w]))
183
+
184
+
185
class CenterCrop:
    """Crop a centred window of ``size = (height, width)`` from image and target."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img, target):
        img_w, img_h = img.size
        out_h, out_w = self.size
        # Top-left corner that centres the window.
        top = int(round((img_h - out_h) / 2.0))
        left = int(round((img_w - out_w) / 2.0))
        return crop(img, target, (top, left, out_h, out_w))
195
+
196
+
197
class RandomHorizontalFlip:
    """Apply ``hflip`` to image and target with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target):
        if random.random() >= self.p:
            # No flip drawn: pass the inputs through untouched.
            return img, target
        return hflip(img, target)
205
+
206
+
207
class RandomResize:
    """Resize to a size picked uniformly at random from ``sizes``."""

    def __init__(self, sizes, max_size=None):
        assert isinstance(sizes, (list, tuple))
        self.sizes = sizes
        self.max_size = max_size

    def __call__(self, img, target=None):
        chosen = random.choice(self.sizes)
        return resize(img, target, chosen, self.max_size)
216
+
217
+
218
class RandomPad:
    """Pad right and bottom by independent random amounts in ``[0, max_pad]``."""

    def __init__(self, max_pad):
        self.max_pad = max_pad

    def __call__(self, img, target):
        # x is drawn before y.
        pad_x = random.randint(0, self.max_pad)
        pad_y = random.randint(0, self.max_pad)
        return pad(img, target, (pad_x, pad_y))
226
+
227
+
228
class RandomSelect:
    """
    Apply one of two transforms at random: ``transforms1`` with
    probability ``p``, otherwise ``transforms2``.
    """

    def __init__(self, transforms1, transforms2, p=0.5):
        self.transforms1 = transforms1
        self.transforms2 = transforms2
        self.p = p

    def __call__(self, img, target):
        chosen = self.transforms1 if random.random() < self.p else self.transforms2
        return chosen(img, target)
243
+
244
+
245
class ToTensor:
    """Convert the image with ``F.to_tensor``; the target passes through unchanged."""

    def __call__(self, img, target):
        tensor_img = F.to_tensor(img)
        return tensor_img, target
248
+
249
+
250
class RandomErasingP05:
    """Random erasing (p=0.5) wrapped in tensor/PIL round-trip conversions; the target is untouched."""

    def __init__(self):
        erase = T.RandomErasing(
            p=0.5, scale=(0.02, 0.2), ratio=(0.1, 6), value="random"
        )
        # Erasing runs between ToTensor and ToPILImage.
        self.eraser = T.Compose([T.ToTensor(), erase, T.ToPILImage()])

    def __call__(self, img, target):
        erased = self.eraser(img)
        return erased, target
264
+
265
+
266
class RandomErasing:
    """Thin wrapper around ``T.RandomErasing`` that passes the target through unchanged."""

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded verbatim to ``T.RandomErasing``.
        self.eraser = T.RandomErasing(*args, **kwargs)

    def __call__(self, img, target):
        erased = self.eraser(img)
        return erased, target
272
+
273
+
274
class ColorJitter:
    """Apply ``T.ColorJitter(*jitter)`` to the image with probability ``p``; the target is untouched."""

    def __init__(self, jitter=(0.2, 0.2, 0.2, 0.1), p=0.5):
        self.color_jitter = T.ColorJitter(*jitter)
        self.p = p

    def __call__(self, img, target):
        if random.random() >= self.p:
            # Jitter not drawn: return the image as-is.
            return img, target
        return self.color_jitter(img), target
283
+
284
+
285
class RandomGrayscale:
    """Apply ``T.RandomGrayscale(p=p)`` to the image; the target is untouched."""

    def __init__(self, p=0.5):
        self.random_gray = T.RandomGrayscale(p=p)

    def __call__(self, img, target):
        gray_or_original = self.random_gray(img)
        return gray_or_original, target
291
+
292
+
293
class Normalize:
    """Normalize the image and convert boxes to normalized ``cxcywh``.

    The image goes through ``F.normalize`` with the stored ``mean`` / ``std``.
    When the target has ``"boxes"``, they are passed through
    ``box_xyxy_to_cxcywh`` and divided by the image width and height.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None

        target = target.copy()
        # Last two dimensions of the image are taken as (height, width).
        h, w = image.shape[-2:]
        if "boxes" in target:
            converted = box_xyxy_to_cxcywh(target["boxes"])
            target["boxes"] = converted / torch.tensor([w, h, w, h], dtype=torch.float32)
        return image, target
310
+
311
+
312
class Compose:
    """Chain ``(image, target)`` transforms, feeding each one's output to the next."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target

    def __repr__(self):
        # One line per transform between the class name and the closing paren.
        body = "".join("\n" + " {0}".format(t) for t in self.transforms)
        return self.__class__.__name__ + "(" + body + "\n)"
perception_models/apps/detection/DETA_pe/engine.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ Train and eval functions used in main.py
12
+ """
13
+ import math
14
+ import os
15
+ import sys
16
+ from typing import Iterable
17
+
18
+ import torch
19
+ import util.misc as utils
20
+ from datasets.coco_eval import CocoEvaluator, convert_to_xywh
21
+ from datasets.data_prefetcher import data_prefetcher
22
+ from datasets.panoptic_eval import PanopticEvaluator
23
+ from util.ema import requires_grad, update_ema
24
+ from util.misc import NestedTensor
25
+
26
+
27
def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    max_norm: float = 0,
    ema: torch.nn.Module = None,
    ema_decay: float = 0.999,
):
    """Run one training epoch and return the averaged logged metrics.

    Args:
        model: network being trained; may or may not be wrapped in a
            container exposing the real network as ``.module``.
        criterion: loss module returning a dict of losses and exposing
            ``weight_dict``.
        data_loader: batches are pulled through ``data_prefetcher``.
        optimizer: optimizer stepped once per batch.
        device: device handed to the prefetcher.
        epoch: epoch index, used only for the log header.
        max_norm: gradient-clipping threshold; clipping is applied only
            when this is positive.
        ema: optional EMA copy of the model, updated after every step.
        ema_decay: decay passed to ``update_ema``.

    Returns:
        Dict mapping each meter name to its ``global_avg``.

    Exits the process with status 1 when the reduced loss is not finite.
    """
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    metric_logger.add_meter(
        "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
    )
    metric_logger.add_meter(
        "grad_norm", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
    )
    header = "Epoch: [{}]".format(epoch)
    print_freq = 10

    # The EMA update needs the bare network. Previously ``model.module`` was
    # accessed unconditionally, which raised AttributeError whenever the model
    # was not wrapped (e.g. single-process runs with EMA enabled).
    ema_source = getattr(model, "module", model)

    prefetcher = data_prefetcher(data_loader, device, prefetch=True)
    samples, targets = prefetcher.next()

    # Iterate by count; the batches themselves come from the prefetcher.
    for _ in metric_logger.log_every(range(len(data_loader)), print_freq, header):
        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(
            loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict
        )

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items()
            if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            grad_total_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm
            )
        else:
            # No clipping requested: only measure the norm for logging.
            grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm)
        optimizer.step()

        if ema is not None:
            update_ema(ema, ema_source, ema_decay)

        metric_logger.update(
            loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled
        )
        metric_logger.update(class_error=loss_dict_reduced["class_error"])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(grad_norm=grad_total_norm)

        samples, targets = prefetcher.next()
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
108
+
109
+
110
@torch.no_grad()
def evaluate(
    model_no_ema,
    criterion,
    postprocessors,
    data_loader,
    base_ds,
    device,
    output_dir,
    test_hflip_aug,
    tta,
    soft_nms,
    ema=None,
    save_result=False,
    save_result_dir="",
    soft_nms_method="quad",
    nms_thresh=0.7,
    quad_scale=0.5,
    lsj_img_size=1824,
):
    """Evaluate the model (or its EMA copy) on ``data_loader``.

    Args:
        model_no_ema: model evaluated when ``ema`` is ``None``.
        criterion: loss module; used for logging losses only.
        postprocessors: dict of postprocessors; ``"bbox"`` is always used,
            ``"segm"`` and ``"panoptic"`` are used when present.
        data_loader: yields ``(samples, targets)`` batches.
        base_ds: ground-truth dataset handed to ``CocoEvaluator``.
        device: device the batch is moved to.
        output_dir: base directory for panoptic evaluation output.
        test_hflip_aug: when true, each batch must hold one sample with six
            channels; channels 0-2 and 3-5 are run through the model
            separately and the second set is treated as the horizontally
            flipped view.
        tta: unused here (kept for signature parity with ``evaluate_tta``).
        soft_nms, soft_nms_method, nms_thresh, quad_scale: forwarded to the
            ``"bbox"`` postprocessor.
        ema: optional EMA model; evaluated instead of ``model_no_ema``.
        save_result: when true, per-image predictions are saved with
            ``torch.save`` to ``val2017_prediction_{rank}.pth``.
        save_result_dir: directory for the saved predictions; an empty
            string means the current working directory.
        lsj_img_size: unused here.

    Returns:
        ``(stats, coco_evaluator)`` where ``stats`` maps meter names to
        their global averages plus the COCO / panoptic summary entries.
    """
    model = model_no_ema if ema is None else ema
    model.eval()
    criterion.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter(
        "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
    )
    header = "Test:"

    iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(base_ds, iou_types)

    panoptic_evaluator = None
    if "panoptic" in postprocessors.keys():
        panoptic_evaluator = PanopticEvaluator(
            data_loader.dataset.ann_file,
            data_loader.dataset.ann_folder,
            output_dir=os.path.join(output_dir, "panoptic_eval"),
        )

    # Keyword arguments shared by every call to the bbox postprocessor.
    nms_kwargs = dict(
        soft_nms=soft_nms,
        method=soft_nms_method,
        nms_thresh=nms_thresh,
        quad_scale=quad_scale,
    )

    prediction_list = []
    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        if test_hflip_aug:
            assert (
                samples.tensors.shape[0] == 1
            ), "test_hflip_aug only supports batch size 1"
            assert (
                samples.tensors.shape[1] == 6
            ), "test_hflip_aug requires two images in a batch"
            first_samples = NestedTensor(samples.tensors[:, :3], samples.mask)
            outputs = model(first_samples)
            flipped_samples = NestedTensor(samples.tensors[:, 3:], samples.mask)
            flipped_outputs = model(flipped_samples)
        else:
            outputs = model(samples)

        # Losses are always computed from the first (un-flipped) outputs.
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items()
            if k in weight_dict
        }
        loss_dict_reduced_unscaled = {
            f"{k}_unscaled": v for k, v in loss_dict_reduced.items()
        }
        metric_logger.update(
            loss=sum(loss_dict_reduced_scaled.values()),
            **loss_dict_reduced_scaled,
            **loss_dict_reduced_unscaled,
        )
        metric_logger.update(class_error=loss_dict_reduced["class_error"])

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        if test_hflip_aug:
            flipped_boxes = flipped_outputs["pred_boxes"]
            # Mirror the first box coordinate back (x -> 1 - x); the other
            # three coordinates are left unchanged.
            sign = torch.as_tensor([-1, 1, 1, 1], device=flipped_boxes.device)
            shift = torch.as_tensor([1, 0, 0, 0], device=flipped_boxes.device)
            reflipped_boxes = flipped_boxes * sign + shift
            bbox_inputs = {
                "pred_logits": torch.cat(
                    [outputs["pred_logits"], flipped_outputs["pred_logits"]], dim=1
                ),
                "pred_boxes": torch.cat(
                    [outputs["pred_boxes"], reflipped_boxes], dim=1
                ),
            }
        else:
            bbox_inputs = outputs
        results = postprocessors["bbox"](bbox_inputs, orig_target_sizes, **nms_kwargs)

        # target_sizes is needed by both the segm and the panoptic
        # postprocessors; compute it whenever either one is active so the
        # panoptic branch cannot hit an undefined name.
        target_sizes = None
        if "segm" in postprocessors.keys() or panoptic_evaluator is not None:
            target_sizes = torch.stack([t["size"] for t in targets], dim=0)
        if "segm" in postprocessors.keys():
            results = postprocessors["segm"](
                results, outputs, orig_target_sizes, target_sizes
            )
        res = {
            target["image_id"].item(): output
            for target, output in zip(targets, results)
        }
        coco_evaluator.update(res)

        if panoptic_evaluator is not None:
            res_pano = postprocessors["panoptic"](
                outputs, target_sizes, orig_target_sizes
            )
            for i, target in enumerate(targets):
                image_id = target["image_id"].item()
                file_name = f"{image_id:012d}.png"
                res_pano[i]["image_id"] = image_id
                res_pano[i]["file_name"] = file_name

            panoptic_evaluator.update(res_pano)

        # CPU copies are only needed for the optional dump below, so skip
        # the device-to-host transfers when nothing will be saved.
        if save_result:
            for target, output in zip(targets, results):
                prediction_list.append(
                    {
                        target["image_id"].item(): {
                            "boxes": output["boxes"].cpu(),
                            "labels": output["labels"].cpu(),
                            "scores": output["scores"].cpu(),
                        }
                    }
                )

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    if save_result:
        from torch import distributed as dist

        # os.makedirs("") raises, so only create a directory when one is named.
        if save_result_dir:
            os.makedirs(save_result_dir, exist_ok=True)
        # dist.get_rank() raises without an initialised process group;
        # fall back to rank 0 for single-process evaluation.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        else:
            rank = 0
        torch.save(
            prediction_list,
            os.path.join(save_result_dir, f"val2017_prediction_{rank}.pth"),
        )

    coco_evaluator.synchronize_between_processes()
    if panoptic_evaluator is not None:
        panoptic_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    panoptic_res = None
    if panoptic_evaluator is not None:
        panoptic_res = panoptic_evaluator.summarize()

    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    if "bbox" in postprocessors.keys():
        stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist()
    if "segm" in postprocessors.keys():
        stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist()
    if panoptic_res is not None:
        stats["PQ_all"] = panoptic_res["All"]
        stats["PQ_th"] = panoptic_res["Things"]
        stats["PQ_st"] = panoptic_res["Stuff"]
    return stats, coco_evaluator
303
+
perception_models/apps/detection/DETA_pe/engine_tta.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ Train and eval functions used in main.py
12
+ """
13
+ import math
14
+ import os
15
+ import sys
16
+ from typing import Iterable
17
+
18
+ import torch
19
+ import util.misc as utils
20
+ from datasets.coco_eval import CocoEvaluator, convert_to_xywh
21
+ from datasets.data_prefetcher import data_prefetcher
22
+ from datasets.panoptic_eval import PanopticEvaluator
23
+ from models.utils_softnms import batched_soft_nms
24
+ from util.misc import NestedTensor
25
+
26
+
27
# Per-view [min_scale, max_scale] box-size limits used by evaluate_tta,
# keyed by the LSJ image size. Entry ``i`` applies to TTA views ``2 * i``
# and ``2 * i + 1``.
# Make sure this is consistent with datasets/coco.py
# TODO: make it configurable
SCALE_RANGES_DICT = {
    1728: [
        [0, 10000],
        [32, 10000],
        [32, 10000],
    ],
    1824: [
        [0, 10000],
        [0, 10000],
        [64, 10000],
        [64, 10000],
    ],
}
33
+
34
+
35
def filter_boxes(boxes, min_scale, max_scale):
    """Return a boolean mask of boxes whose area lies in the given range.

    Args:
        boxes: (N, 4) tensor; width is column 2 minus column 0 and height
            is column 3 minus column 1.
        min_scale: a box is kept only if its area is strictly greater than
            ``min_scale * min_scale``.
        max_scale: a box is kept only if its area is strictly less than
            ``max_scale * max_scale``.
    """
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    areas = widths * heights
    big_enough = areas > min_scale * min_scale
    small_enough = areas < max_scale * max_scale
    return big_enough & small_enough
43
+
44
+
45
@torch.no_grad()
def evaluate_tta(
    model_no_ema,
    criterion,
    postprocessors,
    data_loader,
    base_ds,
    device,
    output_dir,
    test_hflip_aug,
    tta,
    soft_nms,
    ema=None,
    save_result=False,
    save_result_dir="",
    soft_nms_method="quad",
    nms_thresh=0.7,
    quad_scale=0.5,
    lsj_img_size=1824,
):
    """Evaluate with optional multi-view test-time augmentation (TTA).

    With ``tta`` enabled every batch must hold exactly one sample whose
    channel count is a multiple of three; each group of three channels is
    one view. Even-indexed views are evaluated as-is, odd-indexed views
    are treated as horizontally flipped and their predicted boxes are
    mirrored back. Per-view detections are filtered by area with
    ``filter_boxes`` and merged with ``batched_soft_nms``.

    Args:
        model_no_ema: model evaluated when ``ema`` is ``None``.
        criterion: only switched to eval mode; no loss is computed.
        postprocessors: dict providing the ``"bbox"`` postprocessor.
        data_loader: yields ``(samples, targets)`` batches.
        base_ds: ground-truth dataset handed to ``CocoEvaluator``.
        device: device the batch is moved to.
        output_dir: unused here.
        test_hflip_aug: unused here.
        tta: enable the multi-view path described above.
        soft_nms, soft_nms_method, nms_thresh, quad_scale: NMS settings for
            the per-view postprocessor and the final merge.
        ema: optional EMA model; evaluated instead of ``model_no_ema``.
        save_result: when true, per-image predictions are saved with
            ``torch.save`` to ``val2017_prediction_{rank}.pth``.
        save_result_dir: directory for the saved predictions; an empty
            string means the current working directory.
        lsj_img_size: crop size of each view and key into
            ``SCALE_RANGES_DICT``; only consulted when ``tta`` is true.

    Returns:
        ``(stats, coco_evaluator)``.

    Raises:
        KeyError: ``tta`` is true and ``lsj_img_size`` has no entry in
            ``SCALE_RANGES_DICT``.
    """
    model = model_no_ema if ema is None else ema
    model.eval()
    criterion.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter(
        "class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")
    )
    header = "Test:"

    iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(base_ds, iou_types)

    # The scale table is only needed for the TTA path. Looking it up lazily
    # keeps the plain path usable for image sizes that have no table entry.
    scale_ranges = None
    if tta:
        if lsj_img_size not in SCALE_RANGES_DICT:
            raise KeyError(
                f"no TTA scale ranges configured for lsj_img_size={lsj_img_size}; "
                f"known sizes: {sorted(SCALE_RANGES_DICT)}"
            )
        scale_ranges = SCALE_RANGES_DICT[lsj_img_size]

    # Keyword arguments shared by every per-view bbox postprocessor call.
    nms_kwargs = dict(
        soft_nms=soft_nms,
        method=soft_nms_method,
        nms_thresh=nms_thresh,
        quad_scale=quad_scale,
    )

    prediction_list = []
    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        # No loss is computed in this function; log zero placeholders.
        metric_logger.update(loss=0, class_error=0, loss_bbox=0, loss_ce=0)

        if tta:
            assert samples.tensors.shape[0] == 1, "tta only supports batch size 1"
            assert (
                samples.tensors.shape[1] % 3 == 0
            ), "tta requires dimensions of samples.tensors to be divisible by 3"

            all_boxes = []
            all_scores = []
            all_classes = []

            num_scales = samples.tensors.shape[1] // 3
            for scale_ind in range(num_scales):
                # Views come in (plain, flipped) pairs sharing one scale range.
                min_scale, max_scale = scale_ranges[scale_ind // 2]
                view = NestedTensor(
                    samples.tensors[
                        :,
                        scale_ind * 3 : (scale_ind + 1) * 3,
                        :lsj_img_size,
                        :lsj_img_size,
                    ],
                    samples.mask[:, scale_ind, :lsj_img_size, :lsj_img_size],
                )
                view_outputs = model(view)

                if scale_ind % 2 == 1:
                    # Flipped view: mirror the first box coordinate back
                    # (x -> 1 - x) and keep only logits and boxes.
                    flipped_boxes = view_outputs["pred_boxes"]
                    sign = torch.as_tensor(
                        [-1, 1, 1, 1], device=flipped_boxes.device
                    )
                    shift = torch.as_tensor(
                        [1, 0, 0, 0], device=flipped_boxes.device
                    )
                    view_outputs = {
                        "pred_logits": view_outputs["pred_logits"],
                        "pred_boxes": flipped_boxes * sign + shift,
                    }

                view_results = postprocessors["bbox"](
                    view_outputs, orig_target_sizes, **nms_kwargs
                )
                detections = view_results[0]
                keep = filter_boxes(detections["boxes"], min_scale, max_scale)
                all_boxes.append(detections["boxes"][keep])
                all_scores.append(detections["scores"][keep])
                all_classes.append(detections["labels"][keep])

            # Merge the detections of all views with one more soft-NMS pass.
            all_boxes = torch.cat(all_boxes, dim=0)
            all_scores = torch.cat(all_scores, dim=0)
            all_classes = torch.cat(all_classes, dim=0)

            keep_inds, updated_scores = batched_soft_nms(
                all_boxes,
                all_scores,
                all_classes,
                method=soft_nms_method,
                threshold=nms_thresh,
                quad_scale=quad_scale,
            )
            results = [
                {
                    "boxes": all_boxes[keep_inds],
                    "scores": updated_scores,
                    "labels": all_classes[keep_inds],
                }
            ]
        else:
            outputs = model(samples)
            results = postprocessors["bbox"](outputs, orig_target_sizes)

        res = {
            target["image_id"].item(): output
            for target, output in zip(targets, results)
        }
        coco_evaluator.update(res)

        # CPU copies are only needed for the optional dump below, so skip
        # the device-to-host transfers when nothing will be saved.
        if save_result:
            for target, output in zip(targets, results):
                prediction_list.append(
                    {
                        target["image_id"].item(): {
                            "boxes": output["boxes"].cpu(),
                            "labels": output["labels"].cpu(),
                            "scores": output["scores"].cpu(),
                        }
                    }
                )

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    if save_result:
        from torch import distributed as dist

        # os.makedirs("") raises, so only create a directory when one is named.
        if save_result_dir:
            os.makedirs(save_result_dir, exist_ok=True)
        # dist.get_rank() raises without an initialised process group;
        # fall back to rank 0 for single-process evaluation.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        else:
            rank = 0
        torch.save(
            prediction_list,
            os.path.join(save_result_dir, f"val2017_prediction_{rank}.pth"),
        )

    coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    coco_evaluator.accumulate()
    coco_evaluator.summarize()

    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    if "bbox" in postprocessors.keys():
        stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist()
    return stats, coco_evaluator
perception_models/apps/detection/DETA_pe/main.py ADDED
@@ -0,0 +1,754 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from
2
+ # ------------------------------------------------------------------------
3
+ # Deformable DETR
4
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Modified from DETR (https://github.com/facebookresearch/detr)
8
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
9
+ # ------------------------------------------------------------------------
10
+
11
+
12
+ import argparse
13
+ import datetime
14
+ import json
15
+ import os
16
+ import random
17
+ import time
18
+ from copy import deepcopy
19
+ from pathlib import Path
20
+
21
+ import datasets
22
+ import datasets.samplers as samplers
23
+
24
+ import numpy as np
25
+ import torch
26
+ import util.misc as utils
27
+ from datasets import build_dataset, get_coco_api_from_dataset
28
+ from engine import evaluate, train_one_epoch
29
+ from engine_tta import evaluate_tta
30
+ from models import build_model
31
+ from torch.utils.data import DataLoader
32
+ from util.ema import requires_grad, update_ema
33
+
34
+
35
def _str2bool(value):
    """Parse a command-line string into a bool.

    ``type=bool`` treats every non-empty string (including "False") as
    True, so boolean options that take a value need an explicit parser.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Build the command-line parser for training / evaluation.

    The parser is created with ``add_help=False`` so it can be used as a
    parent parser.

    Returns:
        argparse.ArgumentParser with every option registered.
    """
    parser = argparse.ArgumentParser("Deformable DETR Detector", add_help=False)

    # Optimisation
    parser.add_argument("--lr", default=2e-4, type=float)
    parser.add_argument(
        "--lr_backbone_names", default=["backbone.0"], type=str, nargs="+"
    )
    parser.add_argument("--lr_backbone", default=2e-5, type=float)
    parser.add_argument(
        "--lr_linear_proj_names",
        default=["reference_points", "sampling_offsets"],
        type=str,
        nargs="+",
    )
    parser.add_argument("--lr_linear_proj_mult", default=0.1, type=float)
    parser.add_argument("--batch_size", default=2, type=int)
    parser.add_argument("--weight_decay", default=1e-4, type=float)
    parser.add_argument("--epochs", default=50, type=int)
    parser.add_argument("--eval_per_epochs", default=1, type=int)
    parser.add_argument("--save_per_epochs", default=1, type=int)
    parser.add_argument("--lr_drop", default=40, type=int)
    parser.add_argument("--lr_drop_epochs", default=None, type=int, nargs="+")
    parser.add_argument(
        "--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm"
    )

    parser.add_argument("--sgd", action="store_true")
    parser.add_argument("--ema", action="store_true")
    parser.add_argument("--ema_decay", default=0.999, type=float)

    # Variants of Deformable DETR
    parser.add_argument("--with_box_refine", default=False, action="store_true")
    parser.add_argument("--two_stage", default=False, action="store_true")

    # Model parameters
    parser.add_argument(
        "--frozen_weights",
        type=str,
        default=None,
        help="Path to the pretrained model. If set, only the mask head will be trained",
    )

    # * Backbone
    parser.add_argument(
        "--backbone",
        default="resnet50",
        type=str,
        help="Name of the convolutional backbone to use",
    )
    parser.add_argument(
        "--backbone_size",
        default="Gwin384",
        type=str,
        help="backbone size",
    )
    parser.add_argument(
        "--backbone_path",
        default="",
        type=str,
    )
    parser.add_argument(
        "--backbone_lrd",
        default=1.0,
        type=float,
    )
    parser.add_argument(
        "--backbone_layers",
        default=12,
        type=int,
    )
    parser.add_argument(
        "--backbone_init_values",
        default=0.0,
        type=float,
    )
    parser.add_argument(
        "--backbone_tile_posemb",
        default=False,
        # Was ``type=bool``, which parsed "--backbone_tile_posemb False" as True.
        type=_str2bool,
    )
    parser.add_argument(
        "--backbone_use_act_checkpoint",
        action="store_true",
        help="If true, we use act_checkpoint in backbone",
    )
    parser.add_argument(
        "--backbone_act_checkpoint_ratio",
        default=1.0,
        type=float,
    )
    parser.add_argument(
        "--backbone_tta_rope",
        action="store_true",
    )
    parser.add_argument(
        "--backbone_multi_layer",
        action="store_true",
    )

    parser.add_argument(
        "--backbone_win_aug",
        action="store_true",
    )

    parser.add_argument(
        "--backbone_dp",
        default=-1.0,
        type=float,
    )

    parser.add_argument(
        "--bf16",
        action="store_true",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
    )
    parser.add_argument(
        "--dilation",
        action="store_true",
        help="If true, we replace stride with dilation in the last convolutional block (DC5)",
    )
    parser.add_argument(
        "--position_embedding",
        default="sine",
        type=str,
        choices=("sine", "learned"),
        help="Type of positional embedding to use on top of the image features",
    )
    parser.add_argument(
        "--position_embedding_scale",
        default=2 * np.pi,
        type=float,
        help="position / size * scale",
    )
    parser.add_argument(
        "--num_feature_levels", default=4, type=int, help="number of feature levels"
    )

    # * Transformer
    parser.add_argument(
        "--enc_layers",
        default=6,
        type=int,
        help="Number of encoding layers in the transformer",
    )
    parser.add_argument(
        "--dec_layers",
        default=6,
        type=int,
        help="Number of decoding layers in the transformer",
    )
    parser.add_argument(
        "--dim_feedforward",
        default=1024,
        type=int,
        help="Intermediate size of the feedforward layers in the transformer blocks",
    )
    parser.add_argument(
        "--hidden_dim",
        default=256,
        type=int,
        help="Size of the embeddings (dimension of the transformer)",
    )
    parser.add_argument(
        "--dropout", default=0.1, type=float, help="Dropout applied in the transformer"
    )
    parser.add_argument(
        "--nheads",
        default=8,
        type=int,
        help="Number of attention heads inside the transformer's attentions",
    )
    parser.add_argument(
        "--num_queries", default=300, type=int, help="Number of query slots"
    )
    parser.add_argument("--dec_n_points", default=4, type=int)
    parser.add_argument("--enc_n_points", default=4, type=int)

    # * Segmentation
    parser.add_argument(
        "--masks",
        action="store_true",
        help="Train segmentation head if the flag is provided",
    )

    # Loss
    parser.add_argument(
        "--no_aux_loss",
        dest="aux_loss",
        action="store_false",
        help="Disables auxiliary decoding losses (loss at each layer)",
    )
    parser.add_argument("--use_fed_loss", action="store_true")

    # * Matcher
    parser.add_argument("--assign_first_stage", action="store_true")
    parser.add_argument("--assign_second_stage", action="store_true")
    parser.add_argument(
        "--set_cost_class",
        default=2,
        type=float,
        help="Class coefficient in the matching cost",
    )
    parser.add_argument(
        "--set_cost_bbox",
        default=5,
        type=float,
        help="L1 box coefficient in the matching cost",
    )
    parser.add_argument(
        "--set_cost_giou",
        default=2,
        type=float,
        help="giou box coefficient in the matching cost",
    )

    # * Loss coefficients
    parser.add_argument("--mask_loss_coef", default=1, type=float)
    parser.add_argument("--dice_loss_coef", default=1, type=float)
    parser.add_argument("--cls_loss_coef", default=2, type=float)
    parser.add_argument("--bbox_loss_coef", default=5, type=float)
    parser.add_argument("--giou_loss_coef", default=2, type=float)
    parser.add_argument("--focal_alpha", default=0.25, type=float)

    # dataset parameters
    parser.add_argument("--new_mean_std", action="store_true")
    parser.add_argument("--dataset_file", default="coco")
    parser.add_argument("--coco_path", default="./data/coco", type=str)
    parser.add_argument("--coco_panoptic_path", type=str)
    parser.add_argument("--remove_difficult", action="store_true")
    parser.add_argument("--bigger", action="store_true")
    parser.add_argument("--lsj", action="store_true")
    parser.add_argument("--lsj_ms", action="store_true")

    parser.add_argument("--lsj_img_size", default=1024, type=int)
    parser.add_argument("--lsj_img_train_min", default=480, type=int)
    parser.add_argument("--lsj_img_size_max", default=-1, type=int)
    parser.add_argument("--lsj_strong_aug", action="store_true")

    # Evaluation / inference
    parser.add_argument("--save_result", action="store_true")
    parser.add_argument("--save_result_dir", default="", type=str)
    parser.add_argument("--test_hflip_aug", action="store_true")
    parser.add_argument("--tta", action="store_true")
    parser.add_argument("--soft_nms", action="store_true")
    parser.add_argument("--soft_nms_method", default="quad", type=str)
    parser.add_argument("--nms_thresh", default=0.7, type=float)
    parser.add_argument("--quad_scale", default=0.5, type=float)
    parser.add_argument(
        "--output_dir", default="", help="path where to save, empty for no saving"
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--resume", default="", help="resume from checkpoint")
    parser.add_argument("--auto_resume", action="store_true")

    parser.add_argument(
        "--resume_norope",
        action="store_true",
        help="resume from checkpoint without rope params",
    )
    parser.add_argument("--finetune", default="", help="finetune from checkpoint")
    parser.add_argument("--keep_class_embed", action="store_true")
    parser.add_argument(
        "--start_epoch", default=0, type=int, metavar="N", help="start epoch"
    )
    parser.add_argument("--eval", action="store_true")
    parser.add_argument("--num_workers", default=8, type=int)
    parser.add_argument(
        "--cache_mode",
        default=False,
        action="store_true",
        help="whether to cache images on memory",
    )

    return parser
313
+
314
+
315
+ # lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"]
316
def match_name_keywords(n, name_keywords):
    """Return True when any keyword in ``name_keywords`` is a substring of ``n``."""
    return any(keyword in n for keyword in name_keywords)
323
+
324
+
325
def get_vit_lr_decay_rate_vev01(name, lr_decay_rate=1.0, num_layers=12):
    """Return the layer-wise learning-rate multiplier for a parameter name.

    Names containing ".positional_embedding", ".conv1" or ".ln_pre" are
    assigned layer 0; names containing ".resblocks.<i>." are assigned
    layer ``i + 1``; everything else is assigned layer ``num_layers + 1``.
    The multiplier is ``lr_decay_rate ** (num_layers + 1 - layer)``.
    """
    top_layer = num_layers + 1
    stem_markers = (".positional_embedding", ".conv1", ".ln_pre")
    block_marker = ".resblocks."

    if any(marker in name for marker in stem_markers):
        layer = 0
    elif block_marker in name:
        # After the cut the name reads ".resblocks.<i>...", so the block
        # index is the third dot-separated field.
        tail = name[name.find(block_marker) :]
        layer = int(tail.split(".")[2]) + 1
    else:
        layer = top_layer
    return lr_decay_rate ** (top_layer - layer)
332
+
333
+
334
def custom_lr(model_without_ddp, args):
    """Build optimizer parameter groups with per-group learning rates.

    Trainable parameters are grouped by name:
      * names matching neither ``args.lr_backbone_names`` nor
        ``args.lr_linear_proj_names`` use ``args.lr``;
      * names matching ``args.lr_linear_proj_names`` use
        ``args.lr * args.lr_linear_proj_mult``;
      * names matching ``args.lr_backbone_names`` use ``args.lr_backbone``,
        either as one group or, when "vev01" is in ``args.backbone``, as one
        group per parameter scaled by ``get_vit_lr_decay_rate_vev01`` (each
        such parameter and its rate is printed).

    Returns:
        List of ``{"params": [...], "lr": ...}`` dicts.
    """
    default_params = []
    proj_params = []
    backbone_named = []
    for name, param in model_without_ddp.named_parameters():
        if not param.requires_grad:
            continue
        in_backbone = match_name_keywords(name, args.lr_backbone_names)
        in_proj = match_name_keywords(name, args.lr_linear_proj_names)
        # The three memberships are tested independently on purpose.
        if not in_backbone and not in_proj:
            default_params.append(param)
        if in_proj:
            proj_params.append(param)
        if in_backbone:
            backbone_named.append((name, param))

    param_dicts = [
        {"params": default_params, "lr": args.lr},
        {"params": proj_params, "lr": args.lr * args.lr_linear_proj_mult},
    ]

    if "vev01" in args.backbone:
        for p_key, p_value in backbone_named:
            p_lr = args.lr_backbone * get_vit_lr_decay_rate_vev01(
                p_key, args.backbone_lrd, args.backbone_layers
            )
            param_dicts.append({"params": [p_value], "lr": p_lr})
            print(f"param_name: {p_key}, lr: {p_lr}")
    else:
        param_dicts.append(
            {
                "params": [param for _, param in backbone_named],
                "lr": args.lr_backbone,
            }
        )

    return param_dicts
383
+
384
+
385
def main(args):
    """Train and/or evaluate the detector described by ``args``.

    Steps, in order: initialise distributed mode, seed the RNGs, build the
    model (plus an optional EMA copy), datasets, samplers and data loaders,
    build the optimizer and StepLR scheduler, optionally load weights from
    ``--frozen_weights`` / ``--finetune`` / ``--resume`` (``--auto_resume``
    picks up ``<output_dir>/checkpoint.pth`` when it exists), then either run
    a single evaluation (``--eval``) or the training loop with periodic
    checkpointing, evaluation and logging.

    Args:
        args: parsed command-line namespace.
    """
    utils.init_distributed_mode(args)
    print("git:\n {}\n".format(utils.get_sha()))

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("model:", model_without_ddp)
    for n, _ in model_without_ddp.named_parameters():
        print(n)
    print("number of params:", n_parameters)

    # ``ema`` stays None unless --ema is set, so it can be passed around
    # unconditionally below.
    ema = None
    if args.ema:
        ema = deepcopy(model).to(device)
        requires_grad(ema, False)
        print(f"EMA Parameters: {sum(p.numel() for p in ema.parameters()):,}")

    dataset_train = build_dataset(image_set="train", args=args)
    dataset_val = build_dataset(image_set="val", args=args)

    if args.distributed:
        if args.cache_mode:
            sampler_train = samplers.NodeDistributedSampler(dataset_train)
            sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False)
        else:
            if args.dataset_file == "lvis":
                sampler_train = samplers.RepeatFactorTrainingSampler(dataset_train)
            else:
                sampler_train = samplers.DistributedSampler(dataset_train)
            sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, args.batch_size, drop_last=True
    )
    if args.lsj_ms:
        collator = utils.CollatorLSJMultiscale(args.lsj_img_size, args.tta)
    elif args.lsj:
        lsj_img_size_colla = (
            args.lsj_img_size_max if args.lsj_img_size_max > 0 else args.lsj_img_size
        )
        collator = utils.CollatorLSJ(lsj_img_size_colla, args.tta)
    else:
        collator = utils.collate_fn

    data_loader_train = DataLoader(
        dataset_train,
        batch_sampler=batch_sampler_train,
        collate_fn=collator,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    data_loader_val = DataLoader(
        dataset_val,
        args.batch_size,
        sampler=sampler_val,
        drop_last=False,
        collate_fn=collator,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    param_dicts = custom_lr(model_without_ddp, args)

    if args.sgd:
        optimizer = torch.optim.SGD(
            param_dicts, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay
        )
    else:
        optimizer = torch.optim.AdamW(
            param_dicts, lr=args.lr, weight_decay=args.weight_decay
        )
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    else:
        base_ds = get_coco_api_from_dataset(dataset_val)

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location="cpu")
        model_without_ddp.detr.load_state_dict(checkpoint["model"])

    evaluate_fn = evaluate_tta if args.tta else evaluate

    def run_evaluation():
        # Single definition of the evaluation call used after resume, in
        # --eval mode and inside the training loop.
        return evaluate_fn(
            model,
            criterion,
            postprocessors,
            data_loader_val,
            base_ds,
            device,
            args.output_dir,
            args.test_hflip_aug,
            args.tta,
            args.soft_nms,
            ema,
            args.save_result,
            args.save_result_dir,
            soft_nms_method=args.soft_nms_method,
            nms_thresh=args.nms_thresh,
            quad_scale=args.quad_scale,
            lsj_img_size=args.lsj_img_size,
        )

    def report_keys(missing_keys, unexpected_keys):
        # Profiler bookkeeping entries are not real weights; hide them.
        unexpected_keys = [
            k
            for k in unexpected_keys
            if not k.endswith(("total_params", "total_ops"))
        ]
        if len(missing_keys) > 0:
            print("Missing Keys: {}".format(missing_keys))
        if len(unexpected_keys) > 0:
            print("Unexpected Keys: {}".format(unexpected_keys))

    output_dir = Path(args.output_dir)
    if args.auto_resume:
        resumed_ckpt = os.path.join(args.output_dir, "checkpoint.pth")
        if os.path.exists(resumed_ckpt):
            # An existing run checkpoint takes precedence over --finetune.
            args.resume = resumed_ckpt
            args.finetune = None

    if args.finetune:
        checkpoint = torch.load(args.finetune, map_location="cpu")
        state_dict = checkpoint["model"]
        for k in list(state_dict.keys()):
            drop_class_embed = "class_embed" in k and not args.keep_class_embed
            # One combined test so a key matching both rules is deleted once
            # (two independent deletes would raise KeyError).
            if drop_class_embed or "freqs" in k:
                print("removing", k)
                del state_dict[k]

        missing_keys, unexpected_keys = model_without_ddp.load_state_dict(
            state_dict, strict=False
        )
        report_keys(missing_keys, unexpected_keys)

        if "epoch" in checkpoint:
            print("finetuning from epoch", checkpoint["epoch"])

        if args.ema:
            ema.load_state_dict(
                checkpoint["ema"] if "ema" in checkpoint else state_dict, strict=False
            )

    # Defaults so the logging code in the training loop never hits an unbound
    # name when no evaluation has run yet (e.g. --start_epoch without resume).
    test_stats, coco_evaluator = {}, None

    if args.resume:
        print("Resuming training from {}".format(args.resume))
        if args.resume.startswith("https"):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location="cpu", check_hash=True
            )
        else:
            checkpoint = torch.load(args.resume, map_location="cpu")

        # Always bind state_dict from THIS checkpoint: the EMA fallback below
        # previously read an unbound (or stale, from --finetune) state_dict
        # when --resume_norope was not set.
        state_dict = checkpoint["model"]
        if args.resume_norope:
            for k in list(state_dict.keys()):
                if "freqs" in k:
                    print("removing", k)
                    del state_dict[k]

        missing_keys, unexpected_keys = model_without_ddp.load_state_dict(
            state_dict, strict=False
        )
        if args.ema:
            ema.load_state_dict(
                checkpoint["ema"] if "ema" in checkpoint else state_dict,
                strict=False,
            )
        report_keys(missing_keys, unexpected_keys)

        if (
            not args.eval
            and "optimizer" in checkpoint
            and "lr_scheduler" in checkpoint
            and "epoch" in checkpoint
        ):
            # Keep the learning rates configured for this run rather than the
            # ones stored in the checkpoint.
            p_groups = deepcopy(optimizer.param_groups)
            optimizer.load_state_dict(checkpoint["optimizer"])
            for pg, pg_old in zip(optimizer.param_groups, p_groups):
                pg["lr"] = pg_old["lr"]
                pg["initial_lr"] = pg_old["initial_lr"]
            print(optimizer.param_groups)
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            # todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance).
            args.override_resumed_lr_drop = True
            if args.override_resumed_lr_drop:
                print(
                    "Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler."
                )
                lr_scheduler.step_size = args.lr_drop
                lr_scheduler.base_lrs = [
                    group["initial_lr"] for group in optimizer.param_groups
                ]
            lr_scheduler.step(lr_scheduler.last_epoch)
            args.start_epoch = checkpoint["epoch"] + 1
        # check the resumed model
        if not args.eval:
            test_stats, coco_evaluator = run_evaluation()
            torch.cuda.empty_cache()

    if args.eval:
        test_stats, coco_evaluator = run_evaluation()

        if args.output_dir:
            utils.save_on_master(
                coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth"
            )
        return

    print("Start training")
    start_time = time.time()
    if args.ema:
        ema.eval()  # EMA model should always be in eval mode
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            args.clip_max_norm,
            ema,
            ema_decay=args.ema_decay,
        )
        lr_scheduler.step()
        if args.output_dir:
            checkpoint_paths = [output_dir / "checkpoint.pth"]
            # extra checkpoint before LR drop and every save_per_epochs epochs
            if (
                (epoch + 1) % args.lr_drop == 0
                or (epoch + 1) % args.save_per_epochs == 0
                or epoch + 1 == args.epochs
            ):
                checkpoint_paths.append(output_dir / f"checkpoint{epoch:04}.pth")
            # Build the payload once; it is identical for every target path.
            ckpt_dict = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if args.ema:
                ckpt_dict["ema"] = ema.state_dict()
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(ckpt_dict, checkpoint_path)

        torch.cuda.empty_cache()
        if epoch % args.eval_per_epochs == 0 or epoch + 1 == args.epochs:
            test_stats, coco_evaluator = run_evaluation()
        # NOTE: on epochs without a fresh evaluation the most recent test
        # stats are logged again.
        log_stats = {
            **{f"train_{k}": v for k, v in train_stats.items()},
            **{f"test_{k}": v for k, v in test_stats.items()},
            "epoch": epoch,
            "n_parameters": n_parameters,
        }

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

            # for evaluation logs
            if coco_evaluator is not None:
                (output_dir / "eval").mkdir(exist_ok=True)
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ["latest.pth"]
                    if epoch % 50 == 0:
                        filenames.append(f"{epoch:03}.pth")
                    for name in filenames:
                        torch.save(
                            coco_evaluator.coco_eval["bbox"].eval,
                            output_dir / "eval" / name,
                        )
        torch.cuda.empty_cache()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))
745
+
746
+
747
if __name__ == "__main__":
    # Script entry point: extend the shared argument parser, parse the
    # command line, make sure the output directory exists, then run main().
    parser = argparse.ArgumentParser(
        "Deformable DETR training and evaluation script",
        parents=[get_args_parser()],
    )
    args = parser.parse_args()
    if args.output_dir:
        output_path = Path(args.output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
    main(args)
perception_models/apps/detection/DETA_pe/models/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ from .deformable_detr import build
11
+
12
+
13
def build_model(args):
    """Build the model for ``args``.

    Thin wrapper: forwards ``args`` to ``build`` (imported from
    ``.deformable_detr``) and returns its result unchanged.
    """
    built = build(args)
    return built
15
+
perception_models/apps/detection/DETA_pe/models/assigner.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Jeffrey Ouyang-Zhang
3
+
4
+ from typing import List
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from util.box_ops import (
10
+ box_cxcywh_to_xyxy,
11
+ box_iou,
12
+ box_xyxy_to_cxcywh,
13
+ generalized_box_iou,
14
+ )
15
+
16
+
17
+ # from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/wrappers.py#L100
18
def nonzero_tuple(x):
    """
    Return the indices of the non-zero entries of ``x``, one index tensor per
    dimension (the ``as_tuple=True`` flavour of ``torch.nonzero``).

    Under TorchScript the ``as_tuple`` form is avoided
    (https://github.com/pytorch/pytorch/issues/38718): the index matrix from
    ``nonzero()`` is split column-wise instead, after promoting a 0-dim input
    to 1-dim.
    """
    if torch.jit.is_scripting():
        index_matrix = x.unsqueeze(0).nonzero() if x.dim() == 0 else x.nonzero()
        return index_matrix.unbind(1)
    else:
        return x.nonzero(as_tuple=True)
29
+
30
+
31
+ # from https://github.com/facebookresearch/detectron2/blob/9921a2caa585d4fa66c4b534b6fab6e74d89b582/detectron2/modeling/matcher.py#L9
32
class Matcher(object):
    """
    Assign each predicted element (e.g. a box) to at most one ground-truth
    element; a ground-truth element may receive zero or more predictions.

    Matching is driven by an MxN match-quality matrix describing how well
    each (ground-truth, prediction) pair agrees -- for boxes this is
    typically the pairwise IoU.

    Calling the matcher returns (a) a length-N vector with, for every
    prediction n in [0, N), the index m in [0, M) of its matched ground
    truth, and (b) a length-N vector with the label of each prediction.
    """

    def __init__(
        self,
        thresholds: List[float],
        labels: List[int],
        allow_low_quality_matches: bool = False,
    ):
        """
        Args:
            thresholds (list): ascending thresholds that split predictions
                into quality levels.
            labels (list): the label given to predictions in each level; one
                of {-1, 0, 1}, meaning {ignore, negative class, positive
                class}.
            allow_low_quality_matches (bool): if True, additionally mark as
                positive the predictions selected by
                set_low_quality_matches_, even when their best quality is
                below the top threshold.

        Example: with thresholds = [0.3, 0.5] and labels = [0, -1, 1],
            iou < 0.3         -> label 0  (false positives during training)
            0.3 <= iou < 0.5  -> label -1 (ignored)
            0.5 <= iou        -> label 1  (true positives)
        """
        # Work on a copy so the caller's list is not modified, then bracket
        # it with -inf / +inf so every value falls into exactly one level.
        bounds = thresholds[:]
        assert bounds[0] > 0
        bounds.insert(0, -float("inf"))
        bounds.append(float("inf"))
        # Currently torchscript does not support all + generator
        assert all([lo <= hi for lo, hi in zip(bounds[:-1], bounds[1:])]), bounds
        assert all([label in [-1, 0, 1] for label in labels])
        assert len(labels) == len(bounds) - 1
        self.thresholds = bounds
        self.labels = labels
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix):
        """
        Args:
            match_quality_matrix (Tensor[float]): MxN pairwise quality between
                M ground-truth elements and N predictions. All entries must
                be >= 0 (required by the `torch.nonzero` based selection in
                :meth:`set_low_quality_matches_`).

        Returns:
            matches (Tensor[int64]): length N; matches[i] is the matched
                ground-truth index in [0, M)
            match_labels (Tensor[int8]): length N; match_labels[i] tells
                whether prediction i is a true positive, false positive or
                ignored
        """
        assert match_quality_matrix.dim() == 2
        if match_quality_matrix.numel() == 0:
            num_predictions = match_quality_matrix.size(1)
            # Without gt boxes IOU is defined as 0, so every prediction gets
            # `self.labels[0]` (usually the background class 0). To ignore
            # instead, use labels=[-1,0,-1,1] with appropriate thresholds.
            return (
                match_quality_matrix.new_full(
                    (num_predictions,), 0, dtype=torch.int64
                ),
                match_quality_matrix.new_full(
                    (num_predictions,), self.labels[0], dtype=torch.int8
                ),
            )

        assert torch.all(match_quality_matrix >= 0)

        # Rows are gt, columns are predictions: reduce over rows (dim 0) to
        # get the best gt candidate of each prediction.
        best_quality, matches = match_quality_matrix.max(dim=0)

        match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)

        level_bounds = zip(self.labels, self.thresholds[:-1], self.thresholds[1:])
        for level_label, lower, upper in level_bounds:
            in_level = (best_quality >= lower) & (best_quality < upper)
            match_labels[in_level] = level_label

        if self.allow_low_quality_matches:
            self.set_low_quality_matches_(match_labels, match_quality_matrix)

        return matches, match_labels

    def set_low_quality_matches_(self, match_labels, match_quality_matrix):
        """
        Mark extra positives for predictions that only have low-quality
        matches: for every ground-truth G, the predictions sharing the
        maximum overlap with G (ties included) are labelled positive in
        ``match_labels`` (modified in place).

        This function implements the RPN assignment case (i) in Sec. 3.1.2 of
        :paper:`Faster R-CNN`.
        """
        # Best quality reachable by each gt (row-wise maximum).
        best_per_gt = match_quality_matrix.max(dim=1)[0]
        # All predictions attaining that best value, low or not, including
        # ties. Qualities must be positive for the `torch.nonzero` lookup.
        attains_best = match_quality_matrix == best_per_gt[:, None]
        _, best_pred_inds = nonzero_tuple(attains_best)
        # If an anchor became positive only through a low-quality match with
        # gt_A while overlapping gt_B more, its matched index stays gt_B.
        # This follows Detectron and was found to have no significant impact.
        match_labels[best_pred_inds] = 1
156
+
157
+
158
+ # from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/sampling.py#L9
159
def subsample_labels(
    labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int
):
    """
    Randomly pick up to `num_samples` entries of `labels`, a mix of positives
    and negatives. As many positives as possible are taken without exceeding
    `positive_fraction * num_samples`; the remaining slots are filled with
    negatives.

    Args:
        labels (Tensor): (N, ) label vector with values:
            * -1: ignore
            * bg_label: background ("negative") class
            * otherwise: one or more foreground ("positive") classes
        num_samples (int): upper bound on the total number of sampled indices.
        positive_fraction (float): positives sampled is
            `min(num_positives, int(positive_fraction * num_samples))`;
            negatives sampled is
            `min(num_negatives, num_samples - num_positives_sampled)`.
            So a shortage of positives is made up with negatives, and if
            negatives run short too, fewer than `num_samples` are returned.
        bg_label (int): label index of background ("negative") class.

    Returns:
        pos_idx, neg_idx (Tensor):
            1D index vectors whose combined length is `num_samples` or fewer.
    """
    is_foreground = (labels != -1) & (labels != bg_label)
    fg_candidates = nonzero_tuple(is_foreground)[0]
    bg_candidates = nonzero_tuple(labels == bg_label)[0]

    # Cap each quota by what is actually available.
    fg_quota = min(fg_candidates.numel(), int(num_samples * positive_fraction))
    bg_quota = min(bg_candidates.numel(), num_samples - fg_quota)

    # Random subset of each pool (positives are drawn first).
    fg_order = torch.randperm(fg_candidates.numel(), device=fg_candidates.device)
    bg_order = torch.randperm(bg_candidates.numel(), device=bg_candidates.device)

    return fg_candidates[fg_order[:fg_quota]], bg_candidates[bg_order[:bg_quota]]
205
+
206
+
207
def sample_topk_per_gt(pr_inds, gt_inds, iou, k):
    """Replace each ground truth's matches by its highest-IoU predictions.

    For every distinct ground-truth index in ``gt_inds`` that occurs ``c``
    times, the ``min(c, k)`` predictions with the largest IoU for that ground
    truth are returned. The incoming ``pr_inds`` are only used when
    ``gt_inds`` is empty, in which case both inputs are returned unchanged.

    Args:
        pr_inds (Tensor): prediction indices of the current positive matches.
        gt_inds (Tensor): ground-truth index of each of those matches.
        iou (Tensor): (num_gt, num_predictions) quality matrix.
        k (int): maximum number of predictions kept per ground truth; values
            larger than ``num_predictions`` are clamped.

    Returns:
        (Tensor, Tensor): selected prediction indices and the matching
        ground-truth indices, grouped by ascending ground-truth index and,
        within a group, by descending IoU.
    """
    if len(gt_inds) == 0:
        return pr_inds, gt_inds

    # How many matches each distinct gt currently holds.
    gt_unique, counts = gt_inds.unique(return_counts=True)

    # topk raises when k exceeds the size of the reduced dimension.
    k = min(k, iou.size(1))
    topk_pr = iou[gt_unique].topk(k, dim=1).indices

    # keep[i, j] is True for the first counts[i] entries of row i; boolean
    # indexing walks row-major, i.e. the same order as concatenating the
    # per-row prefixes.
    slots = torch.arange(k, device=topk_pr.device)
    keep = slots[None, :] < counts.to(topk_pr.device)[:, None]

    pr_selected = topk_pr[keep]
    gt_selected = gt_unique[:, None].expand(-1, k)[keep.to(gt_unique.device)]
    return pr_selected, gt_selected
219
+
220
+
221
+ # modified from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/roi_heads/roi_heads.py#L123
222
class Stage2Assigner(nn.Module):
    """IoU-based assignment of ground truth to the ``init_reference`` boxes.

    Per image: compute the IoU between ground-truth boxes and
    ``outputs["init_reference"]``, label predictions with a ``Matcher``
    (threshold 0.6, low-quality matches allowed), subsample them, and keep up
    to ``max_k`` top-IoU predictions per ground truth.
    """

    def __init__(self, num_queries, max_k=4):
        super().__init__()
        self.k = max_k
        self.batch_size_per_image = num_queries
        self.positive_fraction = 0.25
        self.bg_label = 400  # number > 91 to filter out later
        self.proposal_matcher = Matcher(
            thresholds=[0.6], labels=[0, 1], allow_low_quality_matches=True
        )

    def _sample_proposals(
        self,
        matched_idxs: torch.Tensor,
        matched_labels: torch.Tensor,
        gt_classes: torch.Tensor,
    ):
        """
        Sample proposals from the N-proposal / M-groundtruth matching and
        give each one a classification label.

        Args:
            matched_idxs (Tensor): length N; best-matched gt index in [0, M)
                for each proposal.
            matched_labels (Tensor): length N; the matcher's label for each
                proposal.
            gt_classes (Tensor): a vector of length M.

        Returns:
            Tensor: indices of the sampled proposals, each in [0, N).
            Tensor: same length; the class assigned to each sampled proposal,
                either the class of its matched gt or ``self.bg_label`` for
                background.
        """
        if gt_classes.numel() > 0:
            # Class of the matched gt for every proposal ...
            proposal_classes = gt_classes[matched_idxs]
            # ... overridden for unmatched (matcher label 0) -> background
            proposal_classes[matched_labels == 0] = self.bg_label
            # ... and for ignored proposals (matcher label -1)
            proposal_classes[matched_labels == -1] = -1
        else:
            # No ground truth: everything is background.
            proposal_classes = torch.zeros_like(matched_idxs) + self.bg_label

        fg_idxs, bg_idxs = subsample_labels(
            proposal_classes,
            self.batch_size_per_image,
            self.positive_fraction,
            self.bg_label,
        )

        sampled_idxs = torch.cat([fg_idxs, bg_idxs], dim=0)
        return sampled_idxs, proposal_classes[sampled_idxs]

    def forward(self, outputs, targets, return_cost_matrix=False):
        """Return per-image ``(prediction_indices, gt_indices)`` pairs.

        When ``return_cost_matrix`` is True the per-image IoU matrices are
        returned as a second value.
        """
        # COCO categories are from 1 to 90. They set num_classes=91 and apply sigmoid.

        indices, ious = [], []
        for b, target in enumerate(targets):
            init_boxes = outputs["init_reference"][b].detach()
            iou, _ = box_iou(
                box_cxcywh_to_xyxy(target["boxes"]),
                box_cxcywh_to_xyxy(init_boxes),
            )
            # proposal_id -> highest_iou_gt_id, proposal_id -> [1 if iou > 0.6, 0 ow]
            matched_idxs, matched_labels = self.proposal_matcher(iou)
            # list of sampled proposal_ids, sampled_id -> [0, num_classes)+[bg_label]
            sampled_idxs, sampled_gt_classes = self._sample_proposals(
                matched_idxs, matched_labels, target["labels"]
            )
            is_foreground = sampled_gt_classes != self.bg_label
            pos_pr_inds = sampled_idxs[is_foreground]
            pos_gt_inds = matched_idxs[pos_pr_inds]
            indices.append(self.postprocess_indices(pos_pr_inds, pos_gt_inds, iou))
            ious.append(iou)

        if return_cost_matrix:
            return indices, ious
        return indices

    def postprocess_indices(self, pr_inds, gt_inds, iou):
        """Keep up to ``self.k`` top-IoU predictions per ground truth."""
        return sample_topk_per_gt(pr_inds, gt_inds, iou, self.k)
306
+
307
+
308
+ # modified from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/proposal_generator/rpn.py#L181
309
class Stage1Assigner(nn.Module):
    """IoU-based assignment of ground truth to ``outputs["anchors"]``.

    Per image: compute the IoU between ground-truth boxes and the anchors,
    label anchors with a ``Matcher`` (negative below ``t_low``, ignored
    between ``t_low`` and ``t_high``, positive from ``t_high``; low-quality
    matches allowed), subsample the labels, and keep up to ``max_k`` top-IoU
    anchors per ground truth.
    """

    def __init__(self, t_low=0.3, t_high=0.7, max_k=4):
        super().__init__()
        self.positive_fraction = 0.5
        self.batch_size_per_image = 256
        self.k = max_k
        self.t_low = t_low
        self.t_high = t_high
        self.anchor_matcher = Matcher(
            thresholds=[t_low, t_high],
            labels=[0, -1, 1],
            allow_low_quality_matches=True,
        )

    def _subsample_labels(self, label):
        """
        Randomly sample a subset of positive and negative examples, and overwrite
        the label vector to the ignore value (-1) for all elements that are not
        included in the sample.

        Args:
            label (Tensor): a vector of -1, 0, 1. Will be modified in-place and returned.
        """
        pos_idx, neg_idx = subsample_labels(
            label, self.batch_size_per_image, self.positive_fraction, 0
        )
        # Fill with the ignore label (-1), then set positive and negative labels
        label.fill_(-1)
        label.scatter_(0, pos_idx, 1)
        label.scatter_(0, neg_idx, 0)
        return label

    def forward(self, outputs, targets):
        """Return one ``(anchor_indices, gt_indices)`` pair per image.

        Images without ground-truth boxes yield a pair of empty long tensors
        on the anchors' device.
        """
        indices = []
        for b, target in enumerate(targets):
            anchors = outputs["anchors"][b]
            if len(target["boxes"]) == 0:
                indices.append(
                    (
                        torch.tensor([], dtype=torch.long, device=anchors.device),
                        torch.tensor([], dtype=torch.long, device=anchors.device),
                    )
                )
                continue
            iou, _ = box_iou(
                box_cxcywh_to_xyxy(target["boxes"]),
                box_cxcywh_to_xyxy(anchors),
            )
            # proposal_id -> highest_iou_gt_id, proposal_id -> [1 if iou > 0.7, 0 if iou < 0.3, -1 ow]
            matched_idxs, matched_labels = self.anchor_matcher(iou)
            matched_labels = self._subsample_labels(matched_labels)

            # Allocate directly on the anchors' device instead of CPU + copy.
            all_pr_inds = torch.arange(len(anchors), device=anchors.device)

            pos_pr_inds = all_pr_inds[matched_labels == 1]
            pos_gt_inds = matched_idxs[pos_pr_inds]
            pos_pr_inds, pos_gt_inds = self.postprocess_indices(
                pos_pr_inds, pos_gt_inds, iou
            )
            indices.append(
                (pos_pr_inds.to(anchors.device), pos_gt_inds.to(anchors.device))
            )
        return indices

    def postprocess_indices(self, pr_inds, gt_inds, iou):
        """Keep up to ``self.k`` top-IoU anchors per ground truth."""
        return sample_topk_per_gt(pr_inds, gt_inds, iou, self.k)
perception_models/apps/detection/DETA_pe/models/backbone.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ Backbone modules.
12
+ """
13
+ from collections import OrderedDict
14
+ from functools import partial
15
+ from typing import Dict, List
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import torchvision
20
+ from torch import nn
21
+ from torch.cuda.amp import autocast
22
+ from torchvision.models._utils import IntermediateLayerGetter
23
+ from util.misc import is_main_process, NestedTensor
24
+
25
+ from .position_encoding import build_position_encoding
26
+ from .swin import get_swinl
27
+ from .pev1 import get_pev1_and_fpn_backbone
28
+
29
+
30
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n, eps=1e-5):
        super().__init__()
        # Everything is stored as a buffer, so the optimizer never sees it.
        initial_values = (
            ("weight", torch.ones),
            ("bias", torch.zeros),
            ("running_mean", torch.zeros),
            ("running_var", torch.ones),
        )
        for buffer_name, make in initial_values:
            self.register_buffer(buffer_name, make(n))
        self.eps = eps

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # This module has no "num_batches_tracked" buffer; drop the entry so
        # that the parent loader does not see it.
        state_dict.pop(prefix + "num_batches_tracked", None)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x):
        # Reshape first (fuser-friendly), then fold the normalization into a
        # single per-channel scale and shift.
        broadcast_shape = (1, -1, 1, 1)
        weight = self.weight.reshape(*broadcast_shape)
        offset = self.bias.reshape(*broadcast_shape)
        variance = self.running_var.reshape(*broadcast_shape)
        mean = self.running_mean.reshape(*broadcast_shape)
        scale = weight * (variance + self.eps).rsqrt()
        shift = offset - mean * scale
        return x * scale + shift
82
+
83
+
84
class BackboneBase(nn.Module):
    """Expose selected stages of a backbone as mask-carrying NestedTensors.

    Parameters whose name mentions none of ``layer2``/``layer3``/``layer4``
    are always frozen; when ``train_backbone`` is false every parameter is
    frozen. The hard-coded ``num_channels`` match the values the subclass
    asserts on (resnet18/34 are rejected there).
    """

    def __init__(
        self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool
    ):
        super().__init__()
        trainable_stages = ("layer2", "layer3", "layer4")
        for param_name, param in backbone.named_parameters():
            in_trainable_stage = any(stage in param_name for stage in trainable_stages)
            if not (train_backbone and in_trainable_stage):
                param.requires_grad_(False)

        if return_interm_layers:
            # layer1's output is not returned.
            return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
            strides, num_channels = [8, 16, 32], [512, 1024, 2048]
        else:
            return_layers = {"layer4": "0"}
            strides, num_channels = [32], [2048]
        self.strides = strides
        self.num_channels = num_channels
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)

    def forward(self, tensor_list: NestedTensor):
        """Run the backbone and resize the input mask to each feature map."""
        feature_maps = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for level_name, feature in feature_maps.items():
            padding_mask = tensor_list.mask
            assert padding_mask is not None
            # Add a leading dim for interpolate, then strip it again.
            resized = F.interpolate(
                padding_mask[None].float(), size=feature.shape[-2:]
            )
            out[level_name] = NestedTensor(feature, resized.to(torch.bool)[0])
        return out
118
+
119
+
120
class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""

    def __init__(
        self,
        name: str,
        train_backbone: bool,
        return_interm_layers: bool,
        dilation: bool,
    ):
        """Build the torchvision model ``name`` and wrap it in ``BackboneBase``.

        Args:
            name: attribute name of a constructor in ``torchvision.models``.
            train_backbone: forwarded to ``BackboneBase``.
            return_interm_layers: forwarded to ``BackboneBase``.
            dilation: replace the stride of the last stage with dilation; the
                last reported stride is halved accordingly.

        Raises:
            AssertionError: for ``resnet18``/``resnet34``, whose channel counts
                differ from the values hard coded in ``BackboneBase``.
        """
        # Validate before building the network so that an unsupported name
        # fails fast instead of first constructing (and, on the main process,
        # loading pretrained weights for) a model that is thrown away.
        assert name not in ("resnet18", "resnet34"), "number of channels are hard coded"
        norm_layer = FrozenBatchNorm2d
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(),
            norm_layer=norm_layer,
        )
        super().__init__(backbone, train_backbone, return_interm_layers)
        if dilation:
            self.strides[-1] = self.strides[-1] // 2
140
+
141
+
142
class SwinBackbone(nn.Module):
    """Swin backbone (built by ``get_swinl``) exposing ``res3``-``res5`` features."""

    def __init__(self):
        # we skip R50 FrozenBatchNorm2d, dilation, train l{2,3,4} only
        super().__init__()
        self.body = get_swinl()
        self.features = ["res3", "res4", "res5"]
        self.strides = [8, 16, 32]
        self.num_channels = [384, 768, 1536]

    def forward(self, tensor_list: NestedTensor):
        """Run the backbone and attach a resized padding mask to each feature.

        Returns:
            Dict mapping each name in ``self.features`` to a ``NestedTensor``
            holding the feature map and the mask interpolated to its size.

        Raises:
            AssertionError: if ``tensor_list.mask`` is ``None``.
        """
        xs = self.body(tensor_list.tensors)
        mask = tensor_list.mask
        # Check before indexing: ``None[None]`` would raise a TypeError and the
        # assertion could never trigger.
        assert mask is not None
        m = mask[None]
        out: Dict[str, NestedTensor] = {}
        for name in self.features:
            resized = F.interpolate(m.float(), size=xs[name].shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(xs[name], resized)
        return out
160
+
161
+
162
class PEv1Backbone(nn.Module):
    """Backbone built by ``get_pev1_and_fpn_backbone`` with optional autocast.

    ``args.bf16`` takes precedence over ``args.fp16``; when either is set the
    body runs under autocast on inputs cast to that dtype and the resulting
    features are converted back to float32.
    """

    def __init__(self, args):
        super().__init__()
        self.body = get_pev1_and_fpn_backbone(args)
        self.features = self.body._out_features

        self.bf16 = args.bf16
        self.fp16 = args.fp16

        _out_feature_strides = self.body._out_feature_strides
        _out_feature_channels = self.body._out_feature_channels
        self.strides = [_out_feature_strides[f] for f in _out_feature_strides.keys()]
        self.num_channels = [
            _out_feature_channels[f] for f in _out_feature_channels.keys()
        ]

    def forward(self, tensor_list: NestedTensor):
        """Run the backbone and attach a resized padding mask to each feature.

        Returns:
            Dict mapping each name in ``self.features`` to a ``NestedTensor``
            holding the feature map and the mask interpolated to its size.

        Raises:
            AssertionError: if ``tensor_list.mask`` is ``None``.
        """
        if self.bf16:
            amp_dtype = torch.bfloat16
        elif self.fp16:
            amp_dtype = torch.float16
        else:
            amp_dtype = None

        if amp_dtype is None:
            xs = self.body(tensor_list.tensors)
        else:
            with autocast(dtype=amp_dtype):
                xs = self.body(tensor_list.tensors.to(amp_dtype))
            # Downstream modules expect float32 features.
            xs = {k: v.float() for k, v in xs.items()}

        mask = tensor_list.mask
        # Check before indexing: ``None[None]`` would raise a TypeError and the
        # assertion could never trigger.
        assert mask is not None
        m = mask[None]
        out: Dict[str, NestedTensor] = {}

        for name in self.features:
            resized = F.interpolate(m.float(), size=xs[name].shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(xs[name], resized)
        return out
200
+
201
+
202
class Joiner(nn.Sequential):
    """Pair a backbone (index 0) with a position-encoding module (index 1)."""

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)
        self.strides = backbone.strides
        self.num_channels = backbone.num_channels

    def forward(self, tensor_list: NestedTensor):
        """Return ``(features, position_encodings)`` as two parallel lists.

        Feature levels are ordered by the sorted names of the backbone's
        output dict; each position encoding is cast to its feature's dtype.
        """
        backbone, position_embedding = self[0], self[1]
        features = backbone(tensor_list)
        out: List[NestedTensor] = [features[key] for key in sorted(features)]

        # position encoding
        pos = [position_embedding(level).to(level.tensors.dtype) for level in out]

        return out, pos
220
+
221
+
222
def build_backbone(args):
    """Build the backbone selected by ``args.backbone`` joined with its position encoding.

    A name containing ``"swin"`` selects ``SwinBackbone``, one containing
    ``"pev1"`` selects ``PEv1Backbone``; anything else is treated as a
    torchvision model name for ``Backbone``.
    """
    position_embedding = build_position_encoding(args)
    train_backbone = args.lr_backbone > 0
    return_interm_layers = args.masks or (args.num_feature_levels > 1)

    backbone_name = args.backbone
    if "swin" in backbone_name:
        backbone = SwinBackbone()
    elif "pev1" in backbone_name:
        backbone = PEv1Backbone(args)
    else:
        backbone = Backbone(
            backbone_name, train_backbone, return_interm_layers, args.dilation
        )
    return Joiner(backbone, position_embedding)
perception_models/apps/detection/DETA_pe/models/deformable_detr.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ Deformable DETR model and criterion classes.
12
+ """
13
+ import copy
14
+ import math
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+ from torchvision.ops.boxes import batched_nms
20
+
21
+ from util import box_ops
22
+ from util.misc import (
23
+ accuracy,
24
+ get_world_size,
25
+ interpolate,
26
+ inverse_sigmoid,
27
+ is_dist_avail_and_initialized,
28
+ nested_tensor_from_tensor_list,
29
+ NestedTensor,
30
+ )
31
+
32
+ from .assigner import Stage1Assigner, Stage2Assigner
33
+
34
+ from .backbone import build_backbone
35
+ from .deformable_transformer import build_deforamble_transformer
36
+ from .matcher import build_matcher
37
+ from .segmentation import (
38
+ DETRsegm,
39
+ dice_loss,
40
+ PostProcessPanoptic,
41
+ PostProcessSegm,
42
+ sigmoid_focal_loss,
43
+ )
44
+ from .utils_fed_loss import get_fed_loss_inds, load_class_freq
45
+ from .utils_softnms import batched_soft_nms
46
+
47
+
48
+ def _get_clones(module, N):
49
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
50
+
51
+
52
class DeformableDETR(nn.Module):
    """This is the Deformable DETR module that performs object detection"""

    def __init__(
        self,
        backbone,
        transformer,
        num_classes,
        num_queries,
        num_feature_levels,
        aux_loss=True,
        with_box_refine=False,
        two_stage=False,
    ):
        """Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py.
                Must expose ``strides`` and ``num_channels`` and be indexable
                (``backbone[1]`` is used as the position encoding in forward).
            transformer: torch module of the transformer architecture; must expose
                ``d_model`` and ``decoder.num_layers``.
            num_classes: number of object classes (output size of the class head)
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         the model can detect in a single image.
            num_feature_levels: number of feature levels fed to the transformer. Levels beyond
                         the backbone outputs are produced by extra stride-2 convolutions.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            with_box_refine: iterative bounding box refinement
            two_stage: two-stage Deformable DETR
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.num_feature_levels = num_feature_levels
        if not two_stage:
            # Learned queries are only needed in the single-stage variant.
            self.query_embed = nn.Embedding(num_queries, hidden_dim * 2)
        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.strides)
            input_proj_list = []
            # One 1x1 projection per backbone output level (``_`` is the level index).
            for _ in range(num_backbone_outs):
                in_channels = backbone.num_channels[_]
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                )
            # Extra levels: stride-2 3x3 convs. The first one consumes the last
            # backbone level's channels, later ones consume ``hidden_dim``.
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(
                            in_channels, hidden_dim, kernel_size=3, stride=2, padding=1
                        ),
                        nn.GroupNorm(32, hidden_dim),
                    )
                )
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            self.input_proj = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1),
                        nn.GroupNorm(32, hidden_dim),
                    )
                ]
            )
        self.backbone = backbone
        self.aux_loss = aux_loss
        self.with_box_refine = with_box_refine
        self.two_stage = two_stage

        # Initialise the classification bias so that sigmoid(bias) == prior_prob.
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_embed.bias.data = torch.ones(num_classes) * bias_value
        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

        # if two-stage, the last class_embed and bbox_embed is for region proposal generation
        num_pred = (
            (transformer.decoder.num_layers + 1)
            if two_stage
            else transformer.decoder.num_layers
        )
        if with_box_refine:
            # Independent (deep-copied) heads per prediction layer.
            self.class_embed = _get_clones(self.class_embed, num_pred)
            self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
            nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
            # hack implementation for iterative bounding box refinement
            self.transformer.decoder.bbox_embed = self.bbox_embed
        else:
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
            # The same head object is repeated, so all layers share weights.
            self.class_embed = nn.ModuleList(
                [self.class_embed for _ in range(num_pred)]
            )
            self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
            self.transformer.decoder.bbox_embed = None
        if two_stage:
            # hack implementation for two-stage
            self.transformer.decoder.class_embed = self.class_embed
            for box_embed in self.bbox_embed:
                nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)

    def forward(self, samples: NestedTensor):
        """The forward expects a NestedTensor, which consists of:
           - samples.tensors: batched images, of shape [batch_size x 3 x H x W]
           - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
        Any other input is converted with ``nested_tensor_from_tensor_list``.

        It returns a dict with the following elements:
           - "pred_logits": the classification logits for all queries, taken from the last decoder layer.
                            Shape= [batch_size x num_queries x num_classes] (the class head
                            outputs ``num_classes`` values; there is no extra no-object logit).
           - "pred_boxes": The normalized boxes coordinates for all queries (sigmoid outputs in [0, 1]).
                           NOTE(review): presumably (center_x, center_y, width, height) relative to
                           each image, as PostProcess converts them with box_cxcywh_to_xyxy -- confirm.
           - "init_reference": the initial reference points returned by the transformer.
           - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
                            dictionaries containing "pred_logits" and "pred_boxes" for each
                            decoder layer except the last.
           - "enc_outputs": Only returned when two_stage is set. Dict with the encoder's
                            "pred_logits", "pred_boxes" (sigmoid applied) and "anchors".
        """
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        srcs = []
        masks = []
        # Project every backbone level to the transformer width.
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            assert mask is not None
        if self.num_feature_levels > len(srcs):
            # Synthesize the remaining levels with the extra stride-2 projections.
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    # First extra level starts from the raw last backbone feature.
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(
                    torch.bool
                )[0]
                # backbone[1] is the position-encoding module of the Joiner.
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                pos.append(pos_l)

        query_embeds = None
        if not self.two_stage:
            query_embeds = self.query_embed.weight
        (
            hs,
            init_reference,
            inter_references,
            enc_outputs_class,
            enc_outputs_coord_unact,
            anchors,
        ) = self.transformer(srcs, masks, pos, query_embeds)

        outputs_classes = []
        outputs_coords = []
        # One class/box prediction per decoder layer (first dim of ``hs``).
        for lvl in range(hs.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            outputs_class = self.class_embed[lvl](hs[lvl])
            tmp = self.bbox_embed[lvl](hs[lvl])
            # The box head predicts an offset in logit space w.r.t. the reference.
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference
            outputs_coord = tmp.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)
        outputs_class = torch.stack(outputs_classes)
        outputs_coord = torch.stack(outputs_coords)

        out = {
            "pred_logits": outputs_class[-1],
            "pred_boxes": outputs_coord[-1],
            "init_reference": init_reference,
        }
        if self.aux_loss:
            out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)

        if self.two_stage:
            enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
            out["enc_outputs"] = {
                "pred_logits": enc_outputs_class,
                "pred_boxes": enc_outputs_coord,
                "anchors": anchors,
            }
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        """Pair up the per-layer predictions of every decoder layer but the last."""
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [
            {"pred_logits": a, "pred_boxes": b}
            for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
        ]
257
+
258
+
259
class SetCriterion(nn.Module):
    """This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute an assignment between ground truth boxes and the outputs of the model
           (the matcher, or the stage-1/stage-2 assigners when enabled)
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(
        self,
        num_classes,
        matcher,
        weight_dict,
        losses,
        focal_alpha=0.25,
        num_queries=300,
        assign_first_stage=False,
        assign_second_stage=False,
        use_fed_loss=False,
    ):
        """Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            focal_alpha: alpha in Focal Loss
            num_queries: passed to Stage2Assigner when assign_second_stage is set
            assign_first_stage: use Stage1Assigner instead of the matcher for the encoder outputs
            assign_second_stage: use Stage2Assigner instead of the matcher for the decoder outputs
            use_fed_loss: restrict the classification loss to a sampled subset of classes
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_alpha = focal_alpha
        self.assign_first_stage = assign_first_stage
        self.assign_second_stage = assign_second_stage

        if self.assign_first_stage:
            self.stg1_assigner = Stage1Assigner()
        if self.assign_second_stage:
            self.stg2_assigner = Stage2Assigner(num_queries)

        self.use_fed_loss = use_fed_loss
        if self.use_fed_loss:
            # Announce once (the message used to be printed three times).
            print("Using federated loss")
            self.register_buffer("fed_loss_weight", load_class_freq(freq_weight=0.5))

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (sigmoid focal loss)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        When ``log`` is true a "class_error" entry is added for logging.
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [t["labels"][J] for t, (_, J) in zip(targets, indices)]
        )
        # Unmatched queries get the extra index ``num_classes``.
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device,
        )
        target_classes[idx] = target_classes_o

        # One-hot with one extra column that absorbs the unmatched index; the
        # column is dropped below so unmatched queries become all-zero rows.
        target_classes_onehot = torch.zeros(
            [src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
            dtype=src_logits.dtype,
            layout=src_logits.layout,
            device=src_logits.device,
        )
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

        target_classes_onehot = target_classes_onehot[:, :, :-1]
        if self.use_fed_loss:
            inds = (
                get_fed_loss_inds(
                    gt_classes=target_classes_o - 1,
                    num_sample_cats=50,
                    weight=self.fed_loss_weight,
                    C=target_classes_onehot.shape[2] - 1,
                )
                + 1
            )  # pay attention to the -1 and +1
            loss_ce = (
                sigmoid_focal_loss(
                    src_logits[:, :, inds],
                    target_classes_onehot[:, :, inds],
                    num_boxes,
                    alpha=self.focal_alpha,
                    gamma=2,
                )
                * src_logits.shape[1]
            )
        else:
            loss_ce = (
                sigmoid_focal_loss(
                    src_logits,
                    target_classes_onehot,
                    num_boxes,
                    alpha=self.focal_alpha,
                    gamma=2,
                )
                * src_logits.shape[1]
            )
        losses = {"loss_ce": loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs["pred_logits"]
        device = pred_logits.device
        tgt_lengths = torch.as_tensor(
            [len(v["labels"]) for v in targets], device=device
        )
        # Count the number of predictions whose argmax is NOT the last class index
        # (treated here as "no-object").
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {"cardinality_error": card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
        Boxes are converted with box_cxcywh_to_xyxy for the GIoU term.
        """
        assert "pred_boxes" in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat(
            [t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0
        )

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")

        losses = {}
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        # Only the diagonal is needed: pair k of predictions with pair k of targets.
        loss_giou = 1 - torch.diag(
            box_ops.generalized_box_iou(
                box_ops.box_cxcywh_to_xyxy(src_boxes),
                box_ops.box_cxcywh_to_xyxy(target_boxes),
            )
        )
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the masks: the focal loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)

        src_masks = outputs["pred_masks"]

        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(
            [t["masks"] for t in targets]
        ).decompose()
        target_masks = target_masks.to(src_masks)

        src_masks = src_masks[src_idx]
        # upsample predictions to the target size
        src_masks = interpolate(
            src_masks[:, None],
            size=target_masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks[tgt_idx].flatten(1)

        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
        }
        return losses

    def _get_src_permutation_idx(self, indices):
        """Return ``(batch_idx, src_idx)`` selecting the matched predictions."""
        # permute predictions following indices
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
        )
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Return ``(batch_idx, tgt_idx)`` selecting the matched targets."""
        # permute targets following indices
        batch_idx = torch.cat(
            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
        )
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        """Dispatch to the loss named ``loss`` ("labels", "cardinality", "boxes" or "masks")."""
        loss_map = {
            "labels": self.loss_labels,
            "cardinality": self.loss_cardinality,
            "boxes": self.loss_boxes,
            "masks": self.loss_masks,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def _suffixed_losses(self, outputs, targets, indices, num_boxes, suffix):
        """Compute every non-mask loss for an auxiliary output, appending ``suffix`` to each key."""
        suffixed = {}
        for loss in self.losses:
            if loss == "masks":
                # Intermediate masks losses are too costly to compute, we ignore them.
                continue
            # Logging is enabled only for the last layer
            kwargs = {"log": False} if loss == "labels" else {}
            l_dict = self.get_loss(
                loss, outputs, targets, indices, num_boxes, **kwargs
            )
            suffixed.update({k + suffix: v for k, v in l_dict.items()})
        return suffixed

    def forward(self, outputs, targets):
        """This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        Returns:
             dict of losses; auxiliary decoder layers use the key suffix "_<layer index>"
             and the encoder outputs use the suffix "_enc".
        """
        outputs_without_aux = {
            k: v
            for k, v in outputs.items()
            if k != "aux_outputs" and k != "enc_outputs"
        }

        # Retrieve the matching between the outputs of the last layer and the targets
        if self.assign_second_stage:
            indices = self.stg2_assigner(outputs_without_aux, targets)
        else:
            indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes accross all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor(
            [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                if not self.assign_second_stage:
                    # With the stage-2 assigner the last-layer indices are reused.
                    indices = self.matcher(aux_outputs, targets)
                losses.update(
                    self._suffixed_losses(
                        aux_outputs, targets, indices, num_boxes, f"_{i}"
                    )
                )

        if "enc_outputs" in outputs:
            enc_outputs = outputs["enc_outputs"]
            # The encoder proposals are class-agnostic: every label becomes 0.
            bin_targets = copy.deepcopy(targets)
            for bt in bin_targets:
                bt["labels"] = torch.zeros_like(bt["labels"])
            if self.assign_first_stage:
                indices = self.stg1_assigner(enc_outputs, bin_targets)
            else:
                indices = self.matcher(enc_outputs, bin_targets)
            losses.update(
                self._suffixed_losses(
                    enc_outputs, bin_targets, indices, num_boxes, "_enc"
                )
            )

        return losses
555
+
556
+
557
class PostProcess(nn.Module):
    """This module converts the model's output into the format expected by the coco api"""

    @torch.no_grad()
    def forward(self, outputs, target_sizes, num_topk=100):
        """Perform the computation
        Parameters:
            outputs: raw outputs of the model ("pred_logits" and "pred_boxes")
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augment, but before padding
            num_topk: maximum number of (query, class) pairs kept per image. It is clamped to
                      num_queries * num_classes so that small models do not make topk fail.
        Returns:
            one dict per image with "scores", "labels" and "boxes" (absolute xyxy coordinates)
        """
        out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        num_classes = out_logits.shape[2]
        # Rank all (query, class) pairs of an image jointly.
        prob = out_logits.sigmoid().view(out_logits.shape[0], -1)
        # torch.topk raises when k exceeds the dimension size.
        k = min(num_topk, prob.shape[1])
        scores, topk_indexes = torch.topk(prob, k, dim=1)
        # Flat index -> (query index, class index).
        topk_boxes = topk_indexes // num_classes
        labels = topk_indexes % num_classes
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        results = [
            {"scores": s, "labels": l, "boxes": b}
            for s, l, b in zip(scores, labels, boxes)
        ]

        return results
595
+
596
+
597
class NMSPostProcess(nn.Module):
    """Convert the model's output into COCO-API result dicts, applying class-wise NMS."""

    @torch.no_grad()
    def forward(
        self,
        outputs,
        target_sizes,
        num_topk=100,
        soft_nms=False,
        nms_thresh=0.7,
        method="quad",
        quad_scale=1.0,
    ):
        """Perform the computation

        Parameters:
            outputs: raw outputs of the model; ``"pred_logits"`` (unpacked as
                batch x queries x classes) and ``"pred_boxes"`` are read.
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
                For evaluation, this must be the original image size (before any data augmentation)
                For visualization, this should be the image size after data augment, but before padding
            num_topk: maximum number of detections returned per image.
            soft_nms: use ``batched_soft_nms`` (which rescales scores) instead
                of hard ``batched_nms``.
            nms_thresh: threshold forwarded to the NMS routine.
            method, quad_scale: forwarded to ``batched_soft_nms`` only.

        Returns:
            One dict per image with the keys ``"scores"``, ``"labels"`` and
            ``"boxes"`` (absolute xyxy coordinates), each holding at most
            ``num_topk`` entries.
        """
        out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]
        bs, n_queries, n_cls = out_logits.shape

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = out_logits.sigmoid()

        # Every (query, class) pair is a detection candidate.
        all_scores = prob.view(bs, n_queries * n_cls)
        all_indexes = torch.arange(n_queries * n_cls, device=out_logits.device)[
            None
        ].repeat(bs, 1)
        all_boxes = all_indexes // n_cls
        all_labels = all_indexes % n_cls

        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        boxes = torch.gather(boxes, 1, all_boxes.unsqueeze(-1).repeat(1, 1, 4))

        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        # Candidates are pre-filtered by score before NMS; the soft variant
        # uses the smaller budget.
        pre_nms_limit = 2000 if soft_nms else 10000

        results = []
        for b in range(bs):
            box = boxes[b]
            score = all_scores[b]
            lbls = all_labels[b]

            if n_queries * n_cls > pre_nms_limit:
                pre_topk = score.topk(pre_nms_limit).indices
                box = box[pre_topk]
                score = score[pre_topk]
                lbls = lbls[pre_topk]

            if soft_nms:
                # Apply soft-NMS to get indices and updated scores.
                # NOTE(review): assumes batched_soft_nms returns the kept
                # indices and their rescored values as two aligned,
                # score-ordered tensors -- confirm against its definition.
                keep_inds, updated_scores = batched_soft_nms(
                    box,
                    score,
                    lbls,
                    nms_thresh,
                    method=method,
                    quad_scale=quad_scale,
                )
                # Truncate each tensor individually: slicing the returned
                # pair itself would leave both tensors at full length.
                keep_inds = keep_inds[:num_topk]
                kept_scores = updated_scores[:num_topk]
            else:
                keep_inds = batched_nms(box, score, lbls, nms_thresh)[:num_topk]
                kept_scores = score[keep_inds]

            results.append(
                {
                    "scores": kept_scores,
                    "labels": lbls[keep_inds],
                    "boxes": box[keep_inds],
                }
            )

        return results
686
+
687
+
688
class MLP(nn.Module):
    """Plain feed-forward stack of linear layers with ReLU between them (an FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        """Create ``num_layers`` linear layers mapping input_dim -> hidden_dim ... -> output_dim."""
        super().__init__()
        self.num_layers = num_layers
        # Layer widths in order: the input, (num_layers - 1) hidden widths, the output.
        widths = [input_dim, *([hidden_dim] * (num_layers - 1)), output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )

    def forward(self, x):
        """Apply every layer in turn; all but the last are followed by a ReLU."""
        last = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last:
                x = F.relu(x)
        return x
703
+
704
+
705
def build(args):
    """Assemble the detector, its training criterion and its post-processors.

    Reads the model, loss and dataset options from ``args`` and returns the
    tuple ``(model, criterion, postprocessors)``, where ``postprocessors`` is
    a dict keyed by ``"bbox"`` and, depending on the options, ``"segm"`` and
    ``"panoptic"``.
    """
    # Size of the class-id space per dataset; anything not listed is treated as COCO.
    classes_per_dataset = {
        "coco_panoptic": 250,
        "voc": 20,
        "objects365": 366,
        "lvis": 1204,
    }
    num_classes = classes_per_dataset.get(args.dataset_file, 91)
    device = torch.device(args.device)

    backbone = build_backbone(args)
    transformer = build_deforamble_transformer(args)
    model = DeformableDETR(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=args.num_queries,
        num_feature_levels=args.num_feature_levels,
        aux_loss=args.aux_loss,
        with_box_refine=args.with_box_refine,
        two_stage=args.two_stage,
    )
    if args.masks:
        model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
    matcher = build_matcher(args)

    weight_dict = {
        "loss_ce": args.cls_loss_coef,
        "loss_bbox": args.bbox_loss_coef,
        "loss_giou": args.giou_loss_coef,
    }
    if args.masks:
        weight_dict["loss_mask"] = args.mask_loss_coef
        weight_dict["loss_dice"] = args.dice_loss_coef
    # TODO this is a hack
    if args.aux_loss:
        # Repeat every weight once per intermediate decoder layer ("_0",
        # "_1", ...) and once for the encoder proposals ("_enc").
        suffixes = [f"_{i}" for i in range(args.dec_layers - 1)] + ["_enc"]
        aux_weight_dict = {
            name + suffix: weight
            for suffix in suffixes
            for name, weight in weight_dict.items()
        }
        weight_dict.update(aux_weight_dict)

    losses = ["labels", "boxes", "cardinality"]
    if args.masks:
        losses += ["masks"]
    criterion = SetCriterion(
        num_classes,
        matcher,
        weight_dict,
        losses,
        focal_alpha=args.focal_alpha,
        num_queries=args.num_queries,
        assign_first_stage=args.assign_first_stage,
        assign_second_stage=args.assign_second_stage,
        use_fed_loss=args.use_fed_loss,
    )
    criterion.to(device)

    # The NMS post-processor is selected whenever second-stage assignment is on.
    bbox_postprocessor = (
        NMSPostProcess() if args.assign_second_stage else PostProcess()
    )
    postprocessors = {"bbox": bbox_postprocessor}
    if args.masks:
        postprocessors["segm"] = PostProcessSegm()
        if args.dataset_file == "coco_panoptic":
            is_thing_map = {i: i <= 90 for i in range(201)}
            postprocessors["panoptic"] = PostProcessPanoptic(
                is_thing_map, threshold=0.85
            )

    return model, criterion, postprocessors
perception_models/apps/detection/DETA_pe/models/deformable_transformer.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ import copy
11
+ from typing import Optional, List
12
+ import math
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from torch import nn, Tensor
17
+ from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
18
+
19
+ from util.misc import inverse_sigmoid
20
+ from models.ops.modules import MSDeformAttn
21
+
22
+ from torchvision.ops.boxes import batched_nms
23
+ from util.box_ops import box_cxcywh_to_xyxy
24
+
25
class DeformableTransformer(nn.Module):
    """Multi-scale deformable encoder/decoder with optional two-stage proposals.

    With ``two_stage=True`` the encoder memory is scored by the extra heads
    stored at index ``decoder.num_layers`` of ``decoder.class_embed`` /
    ``decoder.bbox_embed`` and the selected proposals seed the decoder
    queries. With ``assign_first_stage=True`` the proposals are additionally
    filtered by a per-level NMS before the top ones are kept.
    """

    def __init__(self, d_model=256, nhead=8,
                 num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
                 activation="relu", return_intermediate_dec=False,
                 num_feature_levels=4, dec_n_points=4, enc_n_points=4,
                 two_stage=False, two_stage_num_proposals=300,
                 assign_first_stage=False):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead
        self.two_stage = two_stage
        self.two_stage_num_proposals = two_stage_num_proposals
        self.assign_first_stage = assign_first_stage

        encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
                                                          dropout, activation,
                                                          num_feature_levels, nhead, enc_n_points)
        self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)

        decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
                                                          dropout, activation,
                                                          num_feature_levels, nhead, dec_n_points)
        self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec)

        # One learned embedding per feature level, added to the positional encoding.
        self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))

        if two_stage:
            self.enc_output = nn.Linear(d_model, d_model)
            self.enc_output_norm = nn.LayerNorm(d_model)
            self.pos_trans = nn.Linear(d_model * 2, d_model * 2)
            self.pos_trans_norm = nn.LayerNorm(d_model * 2)
            self.pix_trans = nn.Linear(d_model, d_model)
            self.pix_trans_norm = nn.LayerNorm(d_model)
        else:
            self.reference_points = nn.Linear(d_model, 2)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise all matrices, then apply the module-specific initialisers."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # MSDeformAttn has its own scheme, which must win over the generic one above.
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        if not self.two_stage:
            xavier_uniform_(self.reference_points.weight.data, gain=1.0)
            constant_(self.reference_points.bias.data, 0.)
        normal_(self.level_embed)

    def get_proposal_pos_embed(self, proposals):
        """Return a sine/cosine embedding of unactivated proposal boxes.

        ``proposals`` is indexed as (N, L, 4). Each of the four coordinates is
        expanded to ``d_model // 2`` features, so the last dimension of the
        result is ``2 * d_model``, the input width of ``pos_trans``.
        """
        # Derived from d_model (128 for the default d_model=256) so the
        # embedding always matches pos_trans's input width.
        num_pos_feats = self.d_model // 2
        temperature = 10000
        scale = 2 * math.pi

        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
        # N, L, 4
        proposals = proposals.sigmoid() * scale
        # N, L, 4, num_pos_feats
        pos = proposals[:, :, :, None] / dim_t
        # N, L, 4, num_pos_feats / 2, 2 -> flattened to N, L, 4 * num_pos_feats
        pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
        return pos

    def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
        """Build one anchor-like proposal per encoder token.

        Returns:
            output_memory: ``memory`` with padded / invalid tokens zeroed,
                passed through ``enc_output`` and ``enc_output_norm``.
            output_proposals: inverse-sigmoid (cx, cy, w, h) proposals, set to
                ``inf`` for padded tokens and for proposals outside
                (0.01, 0.99).
            level_ids: feature-level index of every flattened token.
        """
        N_, _, _ = memory.shape
        proposals = []
        _cur = 0
        level_ids = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
            # Unpadded extent, counted along the first column / first row.
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

            grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
                                            torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

            # Cell centres, normalised by the valid (unpadded) extent.
            scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
            grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
            # The proposal size doubles with every level.
            wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
            proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
            proposals.append(proposal)
            _cur += (H_ * W_)
            level_ids.append(grid.new_ones(H_ * W_, dtype=torch.long) * lvl)
        output_proposals = torch.cat(proposals, 1)
        output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
        # Inverse sigmoid, so that predicted offsets can be added in logit space.
        output_proposals = torch.log(output_proposals / (1 - output_proposals))
        output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
        output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))

        output_memory = memory
        output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
        output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
        output_memory = self.enc_output_norm(self.enc_output(output_memory))
        level_ids = torch.cat(level_ids)
        return output_memory, output_proposals, level_ids

    def get_valid_ratio(self, mask):
        """Return the (width, height) fraction of ``mask`` that is not padding.

        ``mask`` is unpacked as (batch, H, W) and is True on padded positions.
        """
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def forward(self, srcs, masks, pos_embeds, query_embed=None):
        """Encode the multi-scale features and decode the object queries.

        Parameters:
            srcs: per-level feature maps, each unpacked as (bs, c, h, w).
            masks: per-level padding masks, each (bs, h, w), True on padding.
            pos_embeds: per-level positional encodings, flattened the same
                way as ``srcs``.
            query_embed: learned query embeddings; required unless
                ``two_stage`` is set. It is split into two chunks of width
                ``c`` along dim 1.

        Returns:
            ``(hs, init_reference_out, inter_references_out,
            enc_outputs_class, enc_outputs_coord_unact, anchors)``; the last
            three are ``None`` when ``two_stage`` is off, and ``anchors`` is
            the sigmoid of the encoder proposals otherwise.
        """
        assert self.two_stage or query_embed is not None

        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            src = src.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        # encoder
        memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)

        # prepare input for decoder
        bs, _, c = memory.shape
        if self.two_stage:
            output_memory, output_proposals, level_ids = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)

            # hack implementation for two-stage Deformable DETR
            enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
            enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals

            topk = self.two_stage_num_proposals
            # Only the first class logit is used as the proposal score.
            proposal_logit = enc_outputs_class[..., 0]

            if self.assign_first_stage:
                proposal_boxes = box_cxcywh_to_xyxy(enc_outputs_coord_unact.sigmoid().float()).clamp(0, 1)
                topk_proposals = []
                for b in range(bs):
                    prop_boxes_b = proposal_boxes[b]
                    prop_logits_b = proposal_logit[b]
                    prop_scores_b = prop_logits_b.sigmoid()

                    # pre-nms per-level topk; k is clamped because torch.topk
                    # raises when asked for more entries than there are tokens.
                    pre_nms_topk = min(1000, prop_logits_b.shape[0])
                    pre_nms_inds = []
                    for lvl in range(len(spatial_shapes)):
                        lvl_mask = level_ids == lvl
                        pre_nms_inds.append(torch.topk(prop_scores_b * lvl_mask, pre_nms_topk)[1])
                    pre_nms_inds = torch.cat(pre_nms_inds)

                    # nms on topk indices
                    post_nms_inds = batched_nms(prop_boxes_b[pre_nms_inds], prop_logits_b[pre_nms_inds], level_ids[pre_nms_inds], 0.9)
                    keep_inds = pre_nms_inds[post_nms_inds]

                    if len(keep_inds) < self.two_stage_num_proposals:
                        print(f'[WARNING] nms proposals ({len(keep_inds)}) < {self.two_stage_num_proposals}, running naive topk')
                        keep_inds = torch.topk(proposal_logit[b], topk)[1]

                    # keep top Q/L indices for L levels
                    q_per_l = topk // len(spatial_shapes)
                    is_level_ordered = level_ids[keep_inds][None] == torch.arange(len(spatial_shapes), device=level_ids.device)[:, None]  # LS
                    keep_inds_mask = is_level_ordered & (is_level_ordered.cumsum(1) <= q_per_l)  # LS
                    keep_inds_mask = keep_inds_mask.any(0)  # S

                    # pad to Q indices (might let ones filtered from pre-nms sneak by... unlikely because we pick high conf anyways)
                    if keep_inds_mask.sum() < topk:
                        num_to_add = topk - keep_inds_mask.sum()
                        pad_inds = (~keep_inds_mask).nonzero()[:num_to_add]
                        keep_inds_mask[pad_inds] = True

                    # index
                    keep_inds_topk = keep_inds[keep_inds_mask]
                    topk_proposals.append(keep_inds_topk)
                topk_proposals = torch.stack(topk_proposals)
            else:
                topk_proposals = torch.topk(proposal_logit, topk, dim=1)[1]

            topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            # Proposals act as fixed anchors for the decoder: no gradient flows back.
            topk_coords_unact = topk_coords_unact.detach()
            reference_points = topk_coords_unact.sigmoid()
            init_reference_out = reference_points
            pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
            query_embed, tgt = torch.split(pos_trans_out, c, dim=2)

            topk_feats = torch.stack([output_memory[b][topk_proposals[b]] for b in range(bs)]).detach()
            tgt = tgt + self.pix_trans_norm(self.pix_trans(topk_feats))
        else:
            query_embed, tgt = torch.split(query_embed, c, dim=1)
            query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
            tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
            reference_points = self.reference_points(query_embed).sigmoid()
            init_reference_out = reference_points

        # decoder
        hs, inter_references = self.decoder(tgt, reference_points, memory,
                                            spatial_shapes, level_start_index, valid_ratios, query_embed, mask_flatten)

        inter_references_out = inter_references
        if self.two_stage:
            return hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact, output_proposals.sigmoid()
        return hs, init_reference_out, inter_references_out, None, None, None
242
+
243
+
244
class DeformableTransformerEncoderLayer(nn.Module):
    """One encoder block: deformable self-attention followed by a feed-forward net."""

    def __init__(self,
                 d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # Deformable self-attention sub-layer.
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Position-wise feed-forward sub-layer.
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Add ``pos`` to ``tensor`` unless ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_ffn(self, src):
        """Residual feed-forward update followed by layer normalisation."""
        hidden = self.activation(self.linear1(src))
        update = self.linear2(self.dropout2(hidden))
        return self.norm2(src + self.dropout3(update))

    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
        """Run the attention sub-layer, then the feed-forward sub-layer."""
        # The positional encoding is added to the queries only; the values
        # handed to the attention are the raw ``src``.
        queries = self.with_pos_embed(src, pos)
        attended = self.self_attn(queries, reference_points, src, spatial_shapes, level_start_index, padding_mask)
        src = self.norm1(src + self.dropout1(attended))

        return self.forward_ffn(src)
284
+
285
+
286
class DeformableTransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep copies of one encoder layer."""

    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Return the centre of every feature-map cell as a reference point.

        ``valid_ratios`` is indexed as (batch, level, 2) with (width, height)
        ratios in the last dimension. The result is indexed as
        (batch, total cells, level, 2) with (x, y) in the last dimension.
        """
        per_level = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            rows = torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device)
            cols = torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)
            ref_y, ref_x = torch.meshgrid(rows, cols)
            # Normalise the centres by the valid (unpadded) extent of this level.
            height_scale = valid_ratios[:, None, lvl, 1] * H_
            width_scale = valid_ratios[:, None, lvl, 0] * W_
            ref_y = ref_y.reshape(-1)[None] / height_scale
            ref_x = ref_x.reshape(-1)[None] / width_scale
            per_level.append(torch.stack((ref_x, ref_y), -1))
        stacked = torch.cat(per_level, 1)
        # Re-express every point in the frame of each level.
        return stacked[:, :, None] * valid_ratios[:, None]

    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
        """Feed ``src`` through all layers, sharing one set of reference points."""
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
        output = src
        for layer in self.layers:
            output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)

        return output
314
+
315
+
316
class DeformableTransformerDecoderLayer(nn.Module):
    """One decoder block: query self-attention, deformable cross-attention, then an FFN."""

    def __init__(self, d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # Deformable cross-attention into the encoder memory.
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Standard multi-head self-attention among the queries.
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # Position-wise feed-forward sub-layer.
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Add ``pos`` to ``tensor`` unless ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_ffn(self, tgt):
        """Residual feed-forward update followed by layer normalisation."""
        hidden = self.activation(self.linear1(tgt))
        update = self.linear2(self.dropout3(hidden))
        return self.norm3(tgt + self.dropout4(update))

    def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask=None):
        """Apply self-attention, cross-attention and the FFN, in that order."""
        # Self-attention: the inputs are transposed for nn.MultiheadAttention
        # and the result is transposed back.
        with_pos = self.with_pos_embed(tgt, query_pos)
        self_attended = self.self_attn(
            with_pos.transpose(0, 1), with_pos.transpose(0, 1), tgt.transpose(0, 1)
        )[0]
        tgt = self.norm2(tgt + self.dropout2(self_attended.transpose(0, 1)))

        # Cross-attention: the queries sample ``src`` around their reference points.
        cross_attended = self.cross_attn(
            self.with_pos_embed(tgt, query_pos),
            reference_points,
            src, src_spatial_shapes, level_start_index, src_padding_mask,
        )
        tgt = self.norm1(tgt + self.dropout1(cross_attended))

        return self.forward_ffn(tgt)
368
+
369
+
370
class DeformableTransformerDecoder(nn.Module):
    """Stack of decoder layers with optional iterative reference-point refinement."""

    def __init__(self, decoder_layer, num_layers, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate
        # hack implementation for iterative bounding box refinement and two-stage Deformable DETR:
        # both heads start as None and are read as per-layer collections when set.
        self.bbox_embed = None
        self.class_embed = None

    def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios,
                query_pos=None, src_padding_mask=None):
        """Decode ``tgt`` layer by layer.

        Returns the stacked per-layer outputs and reference points when
        ``return_intermediate`` is set, otherwise only those of the last layer.
        """
        output = tgt
        collected_outputs = []
        collected_references = []

        for lid, layer in enumerate(self.layers):
            ref_dim = reference_points.shape[-1]
            # Boxes (4 values) are scaled by the valid ratios repeated twice,
            # plain points (2 values) by the ratios themselves.
            if ref_dim == 4:
                ratios = torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
            else:
                assert ref_dim == 2
                ratios = src_valid_ratios[:, None]
            reference_points_input = reference_points[:, :, None] * ratios

            output = layer(output, query_pos, reference_points_input, src,
                           src_spatial_shapes, src_level_start_index, src_padding_mask)

            # hack implementation for iterative bounding box refinement
            if self.bbox_embed is not None:
                delta = self.bbox_embed[lid](output)
                if ref_dim == 4:
                    refined = delta + inverse_sigmoid(reference_points)
                else:
                    # Only the first two entries are offset by the current
                    # point; the update is written into ``delta`` in place.
                    refined = delta
                    refined[..., :2] = delta[..., :2] + inverse_sigmoid(reference_points)
                # The next layer treats the refined reference as a constant.
                reference_points = refined.sigmoid().detach()

            if self.return_intermediate:
                collected_outputs.append(output)
                collected_references.append(reference_points)

        if self.return_intermediate:
            return torch.stack(collected_outputs), torch.stack(collected_references)

        return output, reference_points
416
+
417
+
418
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    copies = (copy.deepcopy(module) for _ in range(N))
    return nn.ModuleList(copies)
420
+
421
+
422
def _get_activation_fn(activation):
    """Return an activation function given a string.

    Accepts ``"relu"``, ``"gelu"`` and ``"glu"``.

    Raises:
        RuntimeError: if ``activation`` is none of the accepted names.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    # The message lists every accepted name, including "glu".
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
431
+
432
+
433
def build_deforamble_transformer(args):
    """Create a ``DeformableTransformer`` configured from the parsed options in ``args``.

    The activation is fixed to ``"relu"`` and intermediate decoder outputs
    are always returned; everything else is read from ``args``.
    """
    config = {
        "d_model": args.hidden_dim,
        "nhead": args.nheads,
        "num_encoder_layers": args.enc_layers,
        "num_decoder_layers": args.dec_layers,
        "dim_feedforward": args.dim_feedforward,
        "dropout": args.dropout,
        "activation": "relu",
        "return_intermediate_dec": True,
        "num_feature_levels": args.num_feature_levels,
        "dec_n_points": args.dec_n_points,
        "enc_n_points": args.enc_n_points,
        "two_stage": args.two_stage,
        # One first-stage proposal per decoder query.
        "two_stage_num_proposals": args.num_queries,
        "assign_first_stage": args.assign_first_stage,
    }
    return DeformableTransformer(**config)
450
+
451
+
perception_models/apps/detection/DETA_pe/models/matcher.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ Modules to compute the matching cost and solve the corresponding LSAP.
12
+ """
13
+ import torch
14
+ from scipy.optimize import linear_sum_assignment
15
+ from torch import nn
16
+
17
+ from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
18
+
19
+
20
class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self,
                 cost_class: float = 1,
                 cost_bbox: float = 1,
                 cost_giou: float = 1,
                 focal_alpha: float = 0.25,
                 focal_gamma: float = 2.0):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
            focal_alpha: Balancing factor of the focal-style classification cost
            focal_gamma: Focusing exponent of the focal-style classification cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

    def forward(self, outputs, targets):
        """ Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        with torch.no_grad():
            bs, num_queries = outputs["pred_logits"].shape[:2]

            # We flatten to compute the cost matrices in a batch
            out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
            out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

            # Also concat the target labels and boxes
            tgt_ids = torch.cat([v["labels"] for v in targets])
            tgt_bbox = torch.cat([v["boxes"] for v in targets])

            # Compute the classification cost: positive focal term minus
            # negative focal term, read at the target class columns.
            alpha = self.focal_alpha
            gamma = self.focal_gamma
            neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
            pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
            cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

            # Compute the L1 cost between boxes
            cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

            # Compute the giou cost between boxes
            cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox),
                                             box_cxcywh_to_xyxy(tgt_bbox))

            # Final cost matrix
            C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
            C = C.view(bs, num_queries, -1).cpu()

            # Split the target axis per image and solve one assignment per
            # image on its own slice of the cost matrix.
            sizes = [len(v["boxes"]) for v in targets]
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
            return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
97
+
98
+
99
def build_matcher(args):
    """Create a ``HungarianMatcher`` whose cost weights come from the ``set_cost_*`` options of ``args``."""
    cost_weights = {
        "cost_class": args.set_cost_class,
        "cost_bbox": args.set_cost_bbox,
        "cost_giou": args.set_cost_giou,
    }
    return HungarianMatcher(**cost_weights)
perception_models/apps/detection/DETA_pe/models/ops/functions/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from .ms_deform_attn_func import ms_deform_attn_core_pytorch, MSDeformAttnFunction
perception_models/apps/detection/DETA_pe/models/ops/functions/ms_deform_attn_func.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from __future__ import absolute_import, division, print_function
10
+
11
+ import MultiScaleDeformableAttention as MSDA
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch.autograd import Function
16
+ from torch.autograd.function import once_differentiable
17
+
18
+
19
class MSDeformAttnFunction(Function):
    """Autograd wrapper around the compiled ``MultiScaleDeformableAttention`` ops.

    ``forward`` delegates to ``MSDA.ms_deform_attn_forward`` and ``backward``
    to ``MSDA.ms_deform_attn_backward``. Gradients are produced only for
    ``value``, ``sampling_locations`` and ``attention_weights``.
    """

    @staticmethod
    def forward(
        ctx,
        value,
        value_spatial_shapes,
        value_level_start_index,
        sampling_locations,
        attention_weights,
        im2col_step,
    ):
        # im2col_step is a plain Python value, so it is stashed on ctx directly
        # rather than through save_for_backward.
        ctx.im2col_step = im2col_step
        kernel_inputs = (
            value,
            value_spatial_shapes,
            value_level_start_index,
            sampling_locations,
            attention_weights,
        )
        output = MSDA.ms_deform_attn_forward(*kernel_inputs, ctx.im2col_step)
        ctx.save_for_backward(*kernel_inputs)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        (
            saved_value,
            saved_shapes,
            saved_start_index,
            saved_locations,
            saved_weights,
        ) = ctx.saved_tensors
        grads = MSDA.ms_deform_attn_backward(
            saved_value,
            saved_shapes,
            saved_start_index,
            saved_locations,
            saved_weights,
            grad_output,
            ctx.im2col_step,
        )
        grad_value, grad_sampling_loc, grad_attn_weight = grads

        # One entry per forward() argument; the shape tensors, the level start
        # index and im2col_step receive no gradient.
        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
69
+
70
+
71
def ms_deform_attn_core_pytorch(
    value, value_spatial_shapes, sampling_locations, attention_weights
):
    """Pure-PyTorch multi-scale deformable attention (debug/test reference).

    The original authors note this path is for debugging and testing only and
    that the CUDA version should be used otherwise.

    :param value               4-D tensor unpacked as (N, S, M, D)
    :param value_spatial_shapes iterable of (H, W) pairs, one per level
    :param sampling_locations  6-D tensor unpacked as (N, Lq, M, L, P, 2)
    :param attention_weights   tensor reshaped to (N*M, 1, Lq, L*P)

    :return output             contiguous tensor of shape (N, Lq, M*D)
    """
    batch, _, heads, head_dim = value.shape
    _, num_queries, heads, num_levels, num_points, _ = sampling_locations.shape

    # Cut the flattened spatial axis back into one chunk per feature level.
    level_sizes = [H_ * W_ for H_, W_ in value_spatial_shapes]
    per_level_values = value.split(level_sizes, dim=1)

    # Map locations from [0, 1] to the [-1, 1] range used by grid_sample.
    grids = 2 * sampling_locations - 1

    sampled_per_level = []
    for level, (height, width) in enumerate(value_spatial_shapes):
        # (N, H*W, M, D) -> (N, H*W, M*D) -> (N, M*D, H*W) -> (N*M, D, H, W)
        level_value = per_level_values[level].flatten(2).transpose(1, 2)
        level_value = level_value.reshape(batch * heads, head_dim, height, width)
        # (N, Lq, M, P, 2) -> (N, M, Lq, P, 2) -> (N*M, Lq, P, 2)
        level_grid = grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # (N*M, D, Lq, P)
        sampled_per_level.append(
            F.grid_sample(
                level_value,
                level_grid,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=False,
            )
        )

    # (N, Lq, M, L, P) -> (N, M, Lq, L, P) -> (N*M, 1, Lq, L*P)
    weights = attention_weights.transpose(1, 2).reshape(
        batch * heads, 1, num_queries, num_levels * num_points
    )
    # (N*M, D, Lq, L, P) -> (N*M, D, Lq, L*P)
    stacked = torch.stack(sampled_per_level, dim=-2).flatten(-2)
    output = (stacked * weights).sum(-1).view(batch, heads * head_dim, num_queries)
    return output.transpose(1, 2).contiguous()
perception_models/apps/detection/DETA_pe/models/ops/make.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # ------------------------------------------------------------------------------------------------
3
+ # Deformable DETR
4
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------------------------------
7
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ # ------------------------------------------------------------------------------------------------
9
+
10
+ python setup.py build install
perception_models/apps/detection/DETA_pe/models/ops/modules/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from .ms_deform_attn import MSDeformAttn
perception_models/apps/detection/DETA_pe/models/ops/modules/ms_deform_attn.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from __future__ import absolute_import, division, print_function
10
+
11
+ import math
12
+
13
+ import warnings
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from torch import nn
18
+ from torch.nn.init import constant_, xavier_uniform_
19
+
20
+ from ..functions import ms_deform_attn_core_pytorch, MSDeformAttnFunction
21
+
22
+
23
+ def _is_power_of_2(n):
24
+ if (not isinstance(n, int)) or (n < 0):
25
+ raise ValueError(
26
+ "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))
27
+ )
28
+ return (n & (n - 1) == 0) and n != 0
29
+
30
+
31
class MSDeformAttn(nn.Module):
    """Multi-scale deformable attention layer.

    Each query predicts ``n_points`` sampling offsets and attention weights
    per head and per feature level; the sampled values are aggregated by
    ``MSDeformAttnFunction`` and passed through an output projection.
    """

    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
        """
        Multi-Scale Deformable Attention Module
        :param d_model      hidden dimension
        :param n_levels     number of feature levels
        :param n_heads      number of attention heads
        :param n_points     number of sampling points per attention head per feature level
        :raises ValueError  if d_model is not divisible by n_heads
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                "d_model must be divisible by n_heads, but got {} and {}".format(
                    d_model, n_heads
                )
            )
        _d_per_head = d_model // n_heads
        # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_head):
            warnings.warn(
                "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
                "which is more efficient in our CUDA implementation."
            )

        # Forwarded unchanged to MSDeformAttnFunction on every call.
        self.im2col_step = 64

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise the projections.

        Offset and attention-weight layers get zero weights. The offset bias
        is set so that each head starts along its own direction (angles evenly
        spaced over the circle) and the i-th point is scaled by ``i + 1``.
        Value and output projections use Xavier-uniform weights and zero bias.
        """
        constant_(self.sampling_offsets.weight.data, 0.0)
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (
            2.0 * math.pi / self.n_heads
        )
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        # Normalise each direction by its largest absolute component, then
        # replicate it for every level and point.
        grid_init = (
            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
            .view(self.n_heads, 1, 1, 2)
            .repeat(1, self.n_levels, self.n_points, 1)
        )
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.0)
        constant_(self.attention_weights.bias.data, 0.0)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.0)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.0)

    def forward(
        self,
        query,
        reference_points,
        input_flatten,
        input_spatial_shapes,
        input_level_start_index,
        input_padding_mask=None,
    ):
        # Raw docstring: the LaTeX-style "\sum" and "\cdot" below would
        # otherwise be parsed as invalid string escape sequences.
        r"""
        :param query                       (N, Length_{query}, C)
        :param reference_points            (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
                                        or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
        :param input_flatten               (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
        :param input_spatial_shapes        (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
        :param input_level_start_index     (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
        :param input_padding_mask          (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements

        :return output                     (N, Length_{query}, C)
        :raises ValueError                 if the last dim of reference_points is neither 2 nor 4
        """
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            # Zero the projected values at padded positions.
            value = value.masked_fill(input_padding_mask[..., None], float(0))
        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(query).view(
            N, Len_q, self.n_heads, self.n_levels, self.n_points, 2
        )
        attention_weights = self.attention_weights(query).view(
            N, Len_q, self.n_heads, self.n_levels * self.n_points
        )
        # Softmax is taken jointly over all levels and points of a head.
        attention_weights = F.softmax(attention_weights, -1).view(
            N, Len_q, self.n_heads, self.n_levels, self.n_points
        )
        # N, Len_q, n_heads, n_levels, n_points, 2
        if reference_points.shape[-1] == 2:
            # Offsets are divided by (W_l, H_l) of each level.
            offset_normalizer = torch.stack(
                [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1
            )
            sampling_locations = (
                reference_points[:, :, None, :, None, :]
                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
            )
        elif reference_points.shape[-1] == 4:
            # Offsets are scaled by half the reference box size and by
            # 1 / n_points.
            sampling_locations = (
                reference_points[:, :, None, :, None, :2]
                + sampling_offsets
                / self.n_points
                * reference_points[:, :, None, :, None, 2:]
                * 0.5
            )
        else:
            raise ValueError(
                "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
                    reference_points.shape[-1]
                )
            )
        output = MSDeformAttnFunction.apply(
            value,
            input_spatial_shapes,
            input_level_start_index,
            sampling_locations,
            attention_weights,
            self.im2col_step,
        )

        output = self.output_proj(output)
        return output
perception_models/apps/detection/DETA_pe/models/ops/setup.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ import os
10
+ import glob
11
+
12
+ import torch
13
+
14
+ from torch.utils.cpp_extension import CUDA_HOME
15
+ from torch.utils.cpp_extension import CppExtension
16
+ from torch.utils.cpp_extension import CUDAExtension
17
+
18
+ from setuptools import find_packages
19
+ from setuptools import setup
20
+
21
+ requirements = ["torch", "torchvision"]
22
+
23
def get_extensions():
    """Describe the ``MultiScaleDeformableAttention`` extension for setuptools.

    Collects the C++ sources under ``src/`` and ``src/cpu/`` and the CUDA
    sources under ``src/cuda/``, located relative to this file, and returns
    them as a single ``CUDAExtension`` compiled with ``WITH_CUDA`` defined.

    :return: one-element list holding the extension
    :raises NotImplementedError: if ``torch.cuda.is_available()`` is false or
        ``CUDA_HOME`` is ``None``
    """
    # Fail fast: there is no CPU-only build of this extension.
    if not (torch.cuda.is_available() and CUDA_HOME is not None):
        raise NotImplementedError(
            "CUDA is not available: building MultiScaleDeformableAttention "
            "requires a CUDA-enabled PyTorch and a CUDA toolkit (CUDA_HOME)."
        )

    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "src")

    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))

    # The glob patterns are rooted at an absolute directory, so every path is
    # already absolute and needs no further joining.
    sources = main_file + source_cpu + source_cuda

    define_macros = [("WITH_CUDA", None)]
    extra_compile_args = {
        "cxx": [],
        "nvcc": [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ],
    }

    return [
        CUDAExtension(
            "MultiScaleDeformableAttention",
            sources,
            include_dirs=[extensions_dir],
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
61
+
62
# Package metadata and build configuration. get_extensions() is evaluated when
# this script runs, so it raises before setup() is called if CUDA is missing.
setup(
    name="MultiScaleDeformableAttention",
    version="1.0",
    author="Weijie Su",
    url="https://github.com/fundamentalvision/Deformable-DETR",
    description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
    packages=find_packages(exclude=("configs", "tests",)),
    ext_modules=get_extensions(),
    # BuildExtension is PyTorch's build_ext command for C++/CUDA extensions.
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
perception_models/apps/detection/DETA_pe/models/ops/src/cpu/ms_deform_attn_cpu.cpp ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+
13
+ #include <ATen/ATen.h>
14
+ #include <ATen/cuda/CUDAContext.h>
15
+
16
+
17
// CPU entry point for the forward pass of multi-scale deformable attention.
// No CPU kernel is implemented: every call raises through AT_ERROR and none
// of the arguments are read.
at::Tensor
ms_deform_attn_cpu_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    AT_ERROR("Not implement on cpu");
}
28
+
29
// CPU entry point for the backward pass of multi-scale deformable attention.
// No CPU kernel is implemented: every call raises through AT_ERROR and none
// of the arguments are read.
std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{
    AT_ERROR("Not implement on cpu");
}
41
+
perception_models/apps/detection/DETA_pe/models/ops/src/cpu/ms_deform_attn_cpu.h ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ at::Tensor
15
+ ms_deform_attn_cpu_forward(
16
+ const at::Tensor &value,
17
+ const at::Tensor &spatial_shapes,
18
+ const at::Tensor &level_start_index,
19
+ const at::Tensor &sampling_loc,
20
+ const at::Tensor &attn_weight,
21
+ const int im2col_step);
22
+
23
+ std::vector<at::Tensor>
24
+ ms_deform_attn_cpu_backward(
25
+ const at::Tensor &value,
26
+ const at::Tensor &spatial_shapes,
27
+ const at::Tensor &level_start_index,
28
+ const at::Tensor &sampling_loc,
29
+ const at::Tensor &attn_weight,
30
+ const at::Tensor &grad_output,
31
+ const int im2col_step);
32
+
33
+
perception_models/apps/detection/DETA_pe/models/ops/src/cuda/ms_deform_attn_cuda.cu ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+ #include "cuda/ms_deform_im2col_cuda.cuh"
13
+
14
+ #include <ATen/ATen.h>
15
+ #include <ATen/cuda/CUDAContext.h>
16
+ #include <cuda.h>
17
+ #include <cuda_runtime.h>
18
+
19
+
20
// Forward pass of multi-scale deformable attention on CUDA.
//
// Sizes are read as: value (batch, spatial_size, num_heads, channels),
// spatial_shapes (num_levels, ...), sampling_loc dim 1 = num_query and
// dim 4 = num_point. spatial_shapes and level_start_index are accessed as
// int64 data. The batch is processed in groups of
// min(batch, im2col_step) samples, which must divide the batch evenly.
//
// Returns a tensor of shape (batch, num_query, num_heads * channels).
// Raises if any input is non-contiguous or not on CUDA, if the batch or
// im2col_step is not positive, or if the step does not divide the batch.
at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    TORCH_CHECK(value.is_contiguous(), "value tensor has to be contiguous");
    TORCH_CHECK(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    TORCH_CHECK(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    TORCH_CHECK(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    TORCH_CHECK(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");

    TORCH_CHECK(value.is_cuda(), "value must be a CUDA tensor");
    TORCH_CHECK(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
    TORCH_CHECK(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
    TORCH_CHECK(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    TORCH_CHECK(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    // Guard the modulo and the divisions below against a zero step
    // (empty batch or non-positive im2col_step).
    TORCH_CHECK(im2col_step_ > 0,
        "batch(", batch, ") and im2col_step(", im2col_step, ") must both be positive");
    TORCH_CHECK(batch % im2col_step_ == 0,
        "batch(", batch, ") must be divisible by im2col_step(", im2col_step_, ")");

    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());

    const int batch_n = im2col_step_;
    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
    // 64-bit element counts so the pointer offsets below cannot overflow int.
    const int64_t per_value_size = static_cast<int64_t>(spatial_size) * num_heads * channels;
    const int64_t per_sample_loc_size = static_cast<int64_t>(num_query) * num_heads * num_levels * num_point * 2;
    const int64_t per_attn_weight_size = static_cast<int64_t>(num_query) * num_heads * num_levels * num_point;
    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        // Output slice for the n-th group of im2col_step_ samples.
        auto columns = output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data_ptr<int64_t>(),
                level_start_index.data_ptr<int64_t>(),
                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                columns.data_ptr<scalar_t>());

        }));
    }

    output = output.view({batch, num_query, num_heads*channels});

    return output;
}
81
+
82
+
83
// Backward pass of multi-scale deformable attention on CUDA.
//
// Takes the same inputs as the forward pass plus grad_output, which is
// viewed as (batch / step, step, num_query, num_heads, channels). The batch
// is processed in groups of min(batch, im2col_step) samples.
//
// Returns {grad_value, grad_sampling_loc, grad_attn_weight}, each created
// with zeros_like of the corresponding input.
// Raises if any input is non-contiguous or not on CUDA, if the batch or
// im2col_step is not positive, or if the step does not divide the batch.
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{

    TORCH_CHECK(value.is_contiguous(), "value tensor has to be contiguous");
    TORCH_CHECK(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    TORCH_CHECK(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    TORCH_CHECK(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    TORCH_CHECK(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
    TORCH_CHECK(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");

    TORCH_CHECK(value.is_cuda(), "value must be a CUDA tensor");
    TORCH_CHECK(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
    TORCH_CHECK(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
    TORCH_CHECK(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    TORCH_CHECK(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
    TORCH_CHECK(grad_output.is_cuda(), "grad_output must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    // Guard the modulo and the divisions below against a zero step
    // (empty batch or non-positive im2col_step).
    TORCH_CHECK(im2col_step_ > 0,
        "batch(", batch, ") and im2col_step(", im2col_step, ") must both be positive");
    TORCH_CHECK(batch % im2col_step_ == 0,
        "batch(", batch, ") must be divisible by im2col_step(", im2col_step_, ")");

    auto grad_value = at::zeros_like(value);
    auto grad_sampling_loc = at::zeros_like(sampling_loc);
    auto grad_attn_weight = at::zeros_like(attn_weight);

    const int batch_n = im2col_step_;
    // 64-bit element counts so the pointer offsets below cannot overflow int.
    const int64_t per_value_size = static_cast<int64_t>(spatial_size) * num_heads * channels;
    const int64_t per_sample_loc_size = static_cast<int64_t>(num_query) * num_heads * num_levels * num_point * 2;
    const int64_t per_attn_weight_size = static_cast<int64_t>(num_query) * num_heads * num_levels * num_point;
    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});

    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        // Incoming gradient slice for the n-th group of im2col_step_ samples.
        auto grad_output_g = grad_output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
                grad_output_g.data_ptr<scalar_t>(),
                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data_ptr<int64_t>(),
                level_start_index.data_ptr<int64_t>(),
                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
                grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);

        }));
    }

    return {
        grad_value, grad_sampling_loc, grad_attn_weight
    };
}
perception_models/apps/detection/DETA_pe/models/ops/src/cuda/ms_deform_attn_cuda.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ at::Tensor ms_deform_attn_cuda_forward(
15
+ const at::Tensor &value,
16
+ const at::Tensor &spatial_shapes,
17
+ const at::Tensor &level_start_index,
18
+ const at::Tensor &sampling_loc,
19
+ const at::Tensor &attn_weight,
20
+ const int im2col_step);
21
+
22
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
23
+ const at::Tensor &value,
24
+ const at::Tensor &spatial_shapes,
25
+ const at::Tensor &level_start_index,
26
+ const at::Tensor &sampling_loc,
27
+ const at::Tensor &attn_weight,
28
+ const at::Tensor &grad_output,
29
+ const int im2col_step);
30
+
perception_models/apps/detection/DETA_pe/models/ops/src/cuda/ms_deform_im2col_cuda.cuh ADDED
@@ -0,0 +1,1327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************
7
+ * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
8
+ * Copyright (c) 2018 Microsoft
9
+ **************************************************************************
10
+ */
11
+
12
+ #include <cstdio>
13
+ #include <algorithm>
14
+ #include <cstring>
15
+
16
+ #include <ATen/ATen.h>
17
+ #include <ATen/cuda/CUDAContext.h>
18
+
19
+ #include <THC/THCAtomics.cuh>
20
+
21
+ #define CUDA_KERNEL_LOOP(i, n) \
22
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
23
+ i < (n); \
24
+ i += blockDim.x * gridDim.x)
25
+
26
+ const int CUDA_NUM_THREADS = 1024;
27
// Number of blocks of `num_threads` threads needed to cover N work items
// (integer division of N + num_threads - 1 by num_threads).
inline int GET_BLOCKS(const int N, const int num_threads)
{
    const int padded_count = N + num_threads - 1;
    return padded_count / num_threads;
}
31
+
32
+
33
// Bilinearly interpolates one scalar from `bottom_data` at the fractional
// position (h, w) for head `m` and channel `c`.
//
// Indexing treats `bottom_data` as a row-major [height][width][nheads][channels]
// array: element (y, x, m, c) is at ((y * width + x) * nheads + m) * channels + c.
// Corners that fail their bounds test contribute zero.
// NOTE(review): only the lower bound is tested for the low corner and only the
// upper bound for the high corner, so this presumably relies on the caller
// restricting h to (-1, height) and w to (-1, width) -- confirm at the call site.
template <typename scalar_t>
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
                                                   const int &height, const int &width, const int &nheads, const int &channels,
                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
    // Integer corners surrounding the sample position.
    const int h_low = floor(h);
    const int w_low = floor(w);
    const int h_high = h_low + 1;
    const int w_high = w_low + 1;

    // Fractional distances to the low corner (lh, lw) and their complements.
    const scalar_t lh = h - h_low;
    const scalar_t lw = w - w_low;
    const scalar_t hh = 1 - lh, hw = 1 - lw;

    // Element strides for one step along width and along height.
    const int w_stride = nheads * channels;
    const int h_stride = width * w_stride;
    const int h_low_ptr_offset = h_low * h_stride;
    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
    const int w_low_ptr_offset = w_low * w_stride;
    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
    // Offset of (head m, channel c) inside one spatial location.
    const int base_ptr = m * channels + c;

    // v1: (h_low, w_low) corner.
    scalar_t v1 = 0;
    if (h_low >= 0 && w_low >= 0)
    {
        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
        v1 = bottom_data[ptr1];
    }
    // v2: (h_low, w_high) corner.
    scalar_t v2 = 0;
    if (h_low >= 0 && w_high <= width - 1)
    {
        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
        v2 = bottom_data[ptr2];
    }
    // v3: (h_high, w_low) corner.
    scalar_t v3 = 0;
    if (h_high <= height - 1 && w_low >= 0)
    {
        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
        v3 = bottom_data[ptr3];
    }
    // v4: (h_high, w_high) corner.
    scalar_t v4 = 0;
    if (h_high <= height - 1 && w_high <= width - 1)
    {
        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
        v4 = bottom_data[ptr4];
    }

    // Bilinear weights: each corner is weighted by the complementary
    // distances along both axes.
    const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

    const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
    return val;
}
85
+
86
+
87
// Backward pass of one bilinear sample at (h, w) for head m, channel c.
//
// Side effects:
//   * atomically adds (corner weight * top_grad * attn_weight) into
//     `grad_value` at each in-range corner;
//   * overwrites grad_attn_weight[0] with top_grad * (interpolated value);
//   * overwrites grad_sampling_loc[0] / [1] with the gradient along w / h,
//     scaled by width / height respectively.
// The three scalar outputs are plain stores, so each caller must pass storage
// that no other thread writes concurrently (the kernels in this file pass a
// per-thread slot of a shared-memory cache).
//
// `bottom_data` / `grad_value` are indexed as
//   y * (width * nheads * channels) + x * (nheads * channels) + m * channels + c.
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
                                               const int &height, const int &width, const int &nheads, const int &channels,
                                               const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                               const scalar_t &top_grad,
                                               const scalar_t &attn_weight,
                                               scalar_t* &grad_value,
                                               scalar_t* grad_sampling_loc,
                                               scalar_t* grad_attn_weight)
{
  // Integer corners of the cell containing (h, w).
  const int y_lo = floor(h);
  const int x_lo = floor(w);
  const int y_hi = y_lo + 1;
  const int x_hi = x_lo + 1;

  // Fractional distance from the low corner and its complement, per axis.
  const scalar_t frac_y = h - y_lo;
  const scalar_t frac_x = w - x_lo;
  const scalar_t inv_y = 1 - frac_y;
  const scalar_t inv_x = 1 - frac_x;

  const int col_stride = nheads * channels;
  const int row_stride = width * col_stride;
  const int row_lo_off = y_lo * row_stride;
  const int row_hi_off = row_lo_off + row_stride;
  const int col_lo_off = x_lo * col_stride;
  const int col_hi_off = col_lo_off + col_stride;
  const int slot = m * channels + c;

  // Bilinear weights of the four corners.
  const scalar_t wt_tl = inv_y * inv_x;
  const scalar_t wt_tr = inv_y * frac_x;
  const scalar_t wt_bl = frac_y * inv_x;
  const scalar_t wt_br = frac_y * frac_x;

  // Upstream gradient already multiplied by the attention weight.
  const scalar_t scaled_grad = top_grad * attn_weight;

  // d(interpolated value)/dh and /dw, accumulated corner by corner.
  scalar_t d_dh = 0;
  scalar_t d_dw = 0;

  const bool top_ok = y_lo >= 0;
  const bool left_ok = x_lo >= 0;
  const bool bottom_ok = y_hi <= height - 1;
  const bool right_ok = x_hi <= width - 1;

  scalar_t v_tl = 0;
  if (top_ok && left_ok)
  {
    const int idx = row_lo_off + col_lo_off + slot;
    v_tl = bottom_data[idx];
    d_dh -= inv_x * v_tl;
    d_dw -= inv_y * v_tl;
    atomicAdd(grad_value + idx, wt_tl * scaled_grad);
  }
  scalar_t v_tr = 0;
  if (top_ok && right_ok)
  {
    const int idx = row_lo_off + col_hi_off + slot;
    v_tr = bottom_data[idx];
    d_dh -= frac_x * v_tr;
    d_dw += inv_y * v_tr;
    atomicAdd(grad_value + idx, wt_tr * scaled_grad);
  }
  scalar_t v_bl = 0;
  if (bottom_ok && left_ok)
  {
    const int idx = row_hi_off + col_lo_off + slot;
    v_bl = bottom_data[idx];
    d_dh += inv_x * v_bl;
    d_dw -= frac_y * v_bl;
    atomicAdd(grad_value + idx, wt_bl * scaled_grad);
  }
  scalar_t v_br = 0;
  if (bottom_ok && right_ok)
  {
    const int idx = row_hi_off + col_hi_off + slot;
    v_br = bottom_data[idx];
    d_dh += frac_x * v_br;
    d_dw += frac_y * v_br;
    atomicAdd(grad_value + idx, wt_br * scaled_grad);
  }

  const scalar_t sampled = (wt_tl * v_tl + wt_tr * v_tr + wt_bl * v_bl + wt_br * v_br);
  grad_attn_weight[0] = top_grad * sampled;
  grad_sampling_loc[0] = width * d_dw * scaled_grad;
  grad_sampling_loc[1] = height * d_dh * scaled_grad;
}
160
+
161
+
162
// Backward pass of one bilinear sample at (h, w) for head m, channel c, with
// every output accumulated through atomicAdd.
//
// Side effects:
//   * atomically adds (corner weight * top_grad * attn_weight) into
//     `grad_value` at each in-range corner;
//   * atomically adds top_grad * (interpolated value) into grad_attn_weight[0];
//   * atomically adds the gradient along w / h, scaled by width / height,
//     into grad_sampling_loc[0] / [1].
//
// `bottom_data` / `grad_value` are indexed as
//   y * (width * nheads * channels) + x * (nheads * channels) + m * channels + c.
template <typename scalar_t>
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
                                                  const int &height, const int &width, const int &nheads, const int &channels,
                                                  const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                                  const scalar_t &top_grad,
                                                  const scalar_t &attn_weight,
                                                  scalar_t* &grad_value,
                                                  scalar_t* grad_sampling_loc,
                                                  scalar_t* grad_attn_weight)
{
  // Integer corners of the cell containing (h, w).
  const int y_lo = floor(h);
  const int x_lo = floor(w);
  const int y_hi = y_lo + 1;
  const int x_hi = x_lo + 1;

  // Fractional distance from the low corner and its complement, per axis.
  const scalar_t frac_y = h - y_lo;
  const scalar_t frac_x = w - x_lo;
  const scalar_t inv_y = 1 - frac_y;
  const scalar_t inv_x = 1 - frac_x;

  const int col_stride = nheads * channels;
  const int row_stride = width * col_stride;
  const int row_lo_off = y_lo * row_stride;
  const int row_hi_off = row_lo_off + row_stride;
  const int col_lo_off = x_lo * col_stride;
  const int col_hi_off = col_lo_off + col_stride;
  const int slot = m * channels + c;

  // Bilinear weights of the four corners.
  const scalar_t wt_tl = inv_y * inv_x;
  const scalar_t wt_tr = inv_y * frac_x;
  const scalar_t wt_bl = frac_y * inv_x;
  const scalar_t wt_br = frac_y * frac_x;

  // Upstream gradient already multiplied by the attention weight.
  const scalar_t scaled_grad = top_grad * attn_weight;

  // d(interpolated value)/dh and /dw, accumulated corner by corner.
  scalar_t d_dh = 0;
  scalar_t d_dw = 0;

  const bool top_ok = y_lo >= 0;
  const bool left_ok = x_lo >= 0;
  const bool bottom_ok = y_hi <= height - 1;
  const bool right_ok = x_hi <= width - 1;

  scalar_t v_tl = 0;
  if (top_ok && left_ok)
  {
    const int idx = row_lo_off + col_lo_off + slot;
    v_tl = bottom_data[idx];
    d_dh -= inv_x * v_tl;
    d_dw -= inv_y * v_tl;
    atomicAdd(grad_value + idx, wt_tl * scaled_grad);
  }
  scalar_t v_tr = 0;
  if (top_ok && right_ok)
  {
    const int idx = row_lo_off + col_hi_off + slot;
    v_tr = bottom_data[idx];
    d_dh -= frac_x * v_tr;
    d_dw += inv_y * v_tr;
    atomicAdd(grad_value + idx, wt_tr * scaled_grad);
  }
  scalar_t v_bl = 0;
  if (bottom_ok && left_ok)
  {
    const int idx = row_hi_off + col_lo_off + slot;
    v_bl = bottom_data[idx];
    d_dh += inv_x * v_bl;
    d_dw -= frac_y * v_bl;
    atomicAdd(grad_value + idx, wt_bl * scaled_grad);
  }
  scalar_t v_br = 0;
  if (bottom_ok && right_ok)
  {
    const int idx = row_hi_off + col_hi_off + slot;
    v_br = bottom_data[idx];
    d_dh += frac_x * v_br;
    d_dw += frac_y * v_br;
    atomicAdd(grad_value + idx, wt_br * scaled_grad);
  }

  const scalar_t sampled = (wt_tl * v_tl + wt_tr * v_tr + wt_bl * v_bl + wt_br * v_br);
  atomicAdd(grad_attn_weight, top_grad * sampled);
  atomicAdd(grad_sampling_loc, width * d_dw * scaled_grad);
  atomicAdd(grad_sampling_loc + 1, height * d_dh * scaled_grad);
}
235
+
236
+
237
// Forward kernel: each work item `index` in [0, n) produces one output element
// data_col[index] as the attention-weighted sum of bilinear samples over all
// levels and points.
//
// `index` is decomposed as ((b * num_query + q) * num_heads + m) * channels + c.
// For each level l, data_spatial_shapes[2l] / [2l + 1] are read as that
// level's height / width and data_level_start_index[l] as its first row in
// `data_value` (row size = num_heads * channels).
// data_sampling_loc holds (w, h) pairs and data_attn_weight one weight per
// sample, both laid out per (b, q, m) as num_levels * num_point entries.
// Locations are multiplied by the level size and shifted by -0.5, so they are
// presumably normalized coordinates — confirm against the caller.
// `batch_size` is not read by the kernel body.
template <typename scalar_t>
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Split the flat index into channel, head and batch; the query index
    // itself is not needed, only skipped over to reach the batch index.
    const int c_col = index % channels;
    const int sampling_index = index / channels;
    const int m_col = sampling_index % num_heads;
    const int b_col = sampling_index / num_heads / num_query;

    // First attention weight of this (b, q, m); advanced once per sample.
    int data_weight_ptr = sampling_index * num_levels * num_point;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    scalar_t acc = 0;
    for (int l_col = 0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int shape_idx = l_col << 1;
      const int spatial_h = data_spatial_shapes[shape_idx];
      const int spatial_w = data_spatial_shapes[shape_idx + 1];
      // Must stay a non-const lvalue: the sampler takes the pointer by reference.
      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);

      for (int p_col = 0; p_col < num_point; ++p_col, ++data_weight_ptr)
      {
        // Locations are stored as two scalars (w, h) per attention weight.
        const int loc_idx = data_weight_ptr << 1;
        const scalar_t loc_w = data_sampling_loc[loc_idx];
        const scalar_t loc_h = data_sampling_loc[loc_idx + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        // Samples entirely outside the level contribute nothing.
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          acc += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
        }
      }
    }
    data_col[index] = acc;
  }
}
300
+
301
// Backward kernel (static shared memory, serial reduction by thread 0).
//
// Each work item `index` in [0, n) — decomposed as
// ((b * num_query + q) * num_heads + m) * channels + c — back-propagates
// grad_col[index] through every (level, point) sample:
//   * grad_value receives atomic adds from ms_deform_attn_col2im_bilinear;
//   * the per-thread location / attention-weight gradients are staged in
//     shared memory, summed over the whole block by thread 0, and stored
//     (plain store, not accumulated) to grad_sampling_loc / grad_attn_weight.
//
// Template parameter blockSize sizes the shared caches and bounds the
// reduction; it presumably must equal blockDim.x — confirm against the launcher.
// NOTE(review): the block-wide sum is written to the slot selected by thread
// 0's index, so all threads of a block presumably share the same (b, q, m)
// (i.e. blockDim.x == channels) — confirm against the launcher.
// NOTE(review): __syncthreads() sits inside the strided loop, so every thread
// of a block has to run the same number of iterations.
// `batch_size` is not read by the kernel body.
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  // Per-thread staging slots: two location gradients and one weight gradient.
  __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
  __shared__ scalar_t cache_grad_attn_weight[blockSize];
  const unsigned int tid = threadIdx.x;

  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    _temp /= num_query;  // skip the query index; only the batch index is needed
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Output cursors are locals derived from the kernel arguments, so every
    // loop iteration starts from the base pointers instead of from a pointer
    // already advanced by an earlier iteration of the same thread.
    scalar_t *grad_loc_out = grad_sampling_loc + data_loc_w_ptr;
    scalar_t *grad_weight_out = grad_attn_weight + data_weight_ptr;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      // Both pointers are passed by reference below, so they stay non-const lvalues.
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        // Clear this thread's slots so out-of-range samples contribute zero.
        cache_grad_sampling_loc[tid << 1] = 0;
        cache_grad_sampling_loc[(tid << 1) + 1] = 0;
        cache_grad_attn_weight[tid] = 0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc + (tid << 1), cache_grad_attn_weight + tid);
        }

        __syncthreads();  // all slots written before thread 0 reads them
        if (tid == 0)
        {
          scalar_t _grad_w = cache_grad_sampling_loc[0];
          scalar_t _grad_h = cache_grad_sampling_loc[1];
          scalar_t _grad_a = cache_grad_attn_weight[0];
          for (unsigned int t = 1; t < blockSize; ++t)
          {
            _grad_w += cache_grad_sampling_loc[t << 1];
            _grad_h += cache_grad_sampling_loc[(t << 1) + 1];
            _grad_a += cache_grad_attn_weight[t];
          }

          grad_loc_out[0] = _grad_w;
          grad_loc_out[1] = _grad_h;
          grad_weight_out[0] = _grad_a;
        }
        __syncthreads();  // reduction finished before the slots are reused

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_weight_out += 1;
        grad_loc_out += 2;
      }
    }
  }
}
404
+
405
+
406
// Backward kernel (static shared memory, pairwise tree reduction).
//
// Each work item `index` in [0, n) — decomposed as
// ((b * num_query + q) * num_heads + m) * channels + c — back-propagates
// grad_col[index] through every (level, point) sample:
//   * grad_value receives atomic adds from ms_deform_attn_col2im_bilinear;
//   * the per-thread location / attention-weight gradients are staged in
//     shared memory, reduced by repeated halving, and stored (plain store,
//     not accumulated) by thread 0 to grad_sampling_loc / grad_attn_weight.
//
// The halving loop starts at blockSize / 2 and only ever adds element
// tid + s into element tid, so it covers every slot only when blockSize is a
// power of two. blockSize presumably must equal blockDim.x — confirm against
// the launcher.
// NOTE(review): the block-wide sum is written to the slot selected by thread
// 0's index, so all threads of a block presumably share the same (b, q, m)
// (i.e. blockDim.x == channels) — confirm against the launcher.
// NOTE(review): __syncthreads() sits inside the strided loop, so every thread
// of a block has to run the same number of iterations.
// `batch_size` is not read by the kernel body.
template <typename scalar_t, unsigned int blockSize>
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  // Per-thread staging slots: two location gradients and one weight gradient.
  __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
  __shared__ scalar_t cache_grad_attn_weight[blockSize];
  const unsigned int tid = threadIdx.x;

  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    _temp /= num_query;  // skip the query index; only the batch index is needed
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Output cursors are locals derived from the kernel arguments, so every
    // loop iteration starts from the base pointers instead of from a pointer
    // already advanced by an earlier iteration of the same thread.
    scalar_t *grad_loc_out = grad_sampling_loc + data_loc_w_ptr;
    scalar_t *grad_weight_out = grad_attn_weight + data_weight_ptr;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      // Both pointers are passed by reference below, so they stay non-const lvalues.
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        // Clear this thread's slots so out-of-range samples contribute zero.
        cache_grad_sampling_loc[tid << 1] = 0;
        cache_grad_sampling_loc[(tid << 1) + 1] = 0;
        cache_grad_attn_weight[tid] = 0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc + (tid << 1), cache_grad_attn_weight + tid);
        }

        __syncthreads();  // all slots written before the reduction starts

        // Tree reduction: each step folds the upper half into the lower half.
        for (unsigned int s = blockSize / 2; s > 0; s >>= 1)
        {
          if (tid < s)
          {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          grad_loc_out[0] = cache_grad_sampling_loc[0];
          grad_loc_out[1] = cache_grad_sampling_loc[1];
          grad_weight_out[0] = cache_grad_attn_weight[0];
        }
        __syncthreads();  // totals consumed before the slots are reused

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_weight_out += 1;
        grad_loc_out += 2;
      }
    }
  }
}
511
+
512
+
513
// Backward kernel (dynamic shared memory, serial reduction by thread 0).
//
// Each work item `index` in [0, n) — decomposed as
// ((b * num_query + q) * num_heads + m) * channels + c — back-propagates
// grad_col[index] through every (level, point) sample:
//   * grad_value receives atomic adds from ms_deform_attn_col2im_bilinear;
//   * the per-thread location / attention-weight gradients are staged in
//     shared memory, summed over blockDim.x threads by thread 0, and stored
//     (plain store, not accumulated) to grad_sampling_loc / grad_attn_weight.
//
// The dynamic shared buffer is used as 2 * blockDim.x scalars for the
// location gradients followed by blockDim.x scalars for the weight gradients,
// so the launch must provide at least 3 * blockDim.x * sizeof(scalar_t) bytes.
// NOTE(review): the block-wide sum is written to the slot selected by thread
// 0's index, so all threads of a block presumably share the same (b, q, m)
// (i.e. blockDim.x == channels) — confirm against the launcher.
// NOTE(review): __syncthreads() sits inside the strided loop, so every thread
// of a block has to run the same number of iterations.
// `batch_size` is not read by the kernel body.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  // Carve the dynamic shared buffer into the two per-thread staging areas.
  extern __shared__ int _s[];
  scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
  scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
  const unsigned int tid = threadIdx.x;

  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    _temp /= num_query;  // skip the query index; only the batch index is needed
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Output cursors are locals derived from the kernel arguments, so every
    // loop iteration starts from the base pointers instead of from a pointer
    // already advanced by an earlier iteration of the same thread.
    scalar_t *grad_loc_out = grad_sampling_loc + data_loc_w_ptr;
    scalar_t *grad_weight_out = grad_attn_weight + data_weight_ptr;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      // Both pointers are passed by reference below, so they stay non-const lvalues.
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        // Clear this thread's slots so out-of-range samples contribute zero.
        cache_grad_sampling_loc[tid << 1] = 0;
        cache_grad_sampling_loc[(tid << 1) + 1] = 0;
        cache_grad_attn_weight[tid] = 0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc + (tid << 1), cache_grad_attn_weight + tid);
        }

        __syncthreads();  // all slots written before thread 0 reads them
        if (tid == 0)
        {
          scalar_t _grad_w = cache_grad_sampling_loc[0];
          scalar_t _grad_h = cache_grad_sampling_loc[1];
          scalar_t _grad_a = cache_grad_attn_weight[0];
          for (unsigned int t = 1; t < blockDim.x; ++t)
          {
            _grad_w += cache_grad_sampling_loc[t << 1];
            _grad_h += cache_grad_sampling_loc[(t << 1) + 1];
            _grad_a += cache_grad_attn_weight[t];
          }

          grad_loc_out[0] = _grad_w;
          grad_loc_out[1] = _grad_h;
          grad_weight_out[0] = _grad_a;
        }
        __syncthreads();  // reduction finished before the slots are reused

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_weight_out += 1;
        grad_loc_out += 2;
      }
    }
  }
}
617
+
618
// Backward kernel (dynamic shared memory, tree reduction for any block size).
//
// Each work item `index` in [0, n) — decomposed as
// ((b * num_query + q) * num_heads + m) * channels + c — back-propagates
// grad_col[index] through every (level, point) sample:
//   * grad_value receives atomic adds from ms_deform_attn_col2im_bilinear;
//   * the per-thread location / attention-weight gradients are staged in
//     shared memory, tree-reduced over blockDim.x threads, and stored (plain
//     store, not accumulated) by thread 0 to grad_sampling_loc /
//     grad_attn_weight.
//
// The dynamic shared buffer is used as 2 * blockDim.x scalars for the
// location gradients followed by blockDim.x scalars for the weight gradients,
// so the launch must provide at least 3 * blockDim.x * sizeof(scalar_t) bytes.
// NOTE(review): the block-wide sum is written to the slot selected by thread
// 0's index, so all threads of a block presumably share the same (b, q, m)
// (i.e. blockDim.x == channels) — confirm against the launcher.
// NOTE(review): __syncthreads() sits inside the strided loop, so every thread
// of a block has to run the same number of iterations.
// `batch_size` is not read by the kernel body.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  // Carve the dynamic shared buffer into the two per-thread staging areas.
  extern __shared__ int _s[];
  scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
  scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
  const unsigned int tid = threadIdx.x;

  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    _temp /= num_query;  // skip the query index; only the batch index is needed
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Output cursors are locals derived from the kernel arguments, so every
    // loop iteration starts from the base pointers instead of from a pointer
    // already advanced by an earlier iteration of the same thread.
    scalar_t *grad_loc_out = grad_sampling_loc + data_loc_w_ptr;
    scalar_t *grad_weight_out = grad_attn_weight + data_weight_ptr;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      // Both pointers are passed by reference below, so they stay non-const lvalues.
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        // Clear this thread's slots so out-of-range samples contribute zero.
        cache_grad_sampling_loc[tid << 1] = 0;
        cache_grad_sampling_loc[(tid << 1) + 1] = 0;
        cache_grad_attn_weight[tid] = 0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc + (tid << 1), cache_grad_attn_weight + tid);
        }

        __syncthreads();  // all slots written before the reduction starts

        // Tree reduction over `spre` live slots, halved each step (s = spre / 2).
        for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; s >>= 1, spre >>= 1)
        {
          if (tid < s)
          {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            // When spre is odd, slot 2 * s has no partner; the only thread
            // for which tid + 2 * s < spre (thread 0) folds it in here.
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          grad_loc_out[0] = cache_grad_sampling_loc[0];
          grad_loc_out[1] = cache_grad_sampling_loc[1];
          grad_weight_out[0] = cache_grad_attn_weight[0];
        }
        __syncthreads();  // totals consumed before the slots are reused

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_weight_out += 1;
        grad_loc_out += 2;
      }
    }
  }
}
730
+
731
// Backward kernel (dynamic shared memory, tree reduction, atomic final write).
//
// Each work item `index` in [0, n) — decomposed as
// ((b * num_query + q) * num_heads + m) * channels + c — back-propagates
// grad_col[index] through every (level, point) sample:
//   * grad_value receives atomic adds from ms_deform_attn_col2im_bilinear;
//   * the per-thread location / attention-weight gradients are staged in
//     shared memory, tree-reduced over blockDim.x threads, and then
//     atomically ADDED by thread 0 into grad_sampling_loc / grad_attn_weight.
//
// Because the block total is accumulated rather than stored, the outputs are
// presumably expected to be zero-initialised and may receive contributions
// from more than one block — confirm against the caller.
// The dynamic shared buffer is used as 2 * blockDim.x scalars for the
// location gradients followed by blockDim.x scalars for the weight gradients,
// so the launch must provide at least 3 * blockDim.x * sizeof(scalar_t) bytes.
// NOTE(review): the block total goes to the slot selected by thread 0's
// index, so all threads of a block presumably share the same (b, q, m) —
// confirm against the launcher.
// NOTE(review): __syncthreads() sits inside the strided loop, so every thread
// of a block has to run the same number of iterations.
// `batch_size` is not read by the kernel body.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  // Carve the dynamic shared buffer into the two per-thread staging areas.
  extern __shared__ int _s[];
  scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
  scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
  const unsigned int tid = threadIdx.x;

  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    _temp /= num_query;  // skip the query index; only the batch index is needed
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Output cursors are locals derived from the kernel arguments, so every
    // loop iteration starts from the base pointers instead of from a pointer
    // already advanced by an earlier iteration of the same thread.
    scalar_t *grad_loc_out = grad_sampling_loc + data_loc_w_ptr;
    scalar_t *grad_weight_out = grad_attn_weight + data_weight_ptr;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      // Both pointers are passed by reference below, so they stay non-const lvalues.
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        // Clear this thread's slots so out-of-range samples contribute zero.
        cache_grad_sampling_loc[tid << 1] = 0;
        cache_grad_sampling_loc[(tid << 1) + 1] = 0;
        cache_grad_attn_weight[tid] = 0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc + (tid << 1), cache_grad_attn_weight + tid);
        }

        __syncthreads();  // all slots written before the reduction starts

        // Tree reduction over `spre` live slots, halved each step (s = spre / 2).
        for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0; s >>= 1, spre >>= 1)
        {
          if (tid < s)
          {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            // When spre is odd, slot 2 * s has no partner; the only thread
            // for which tid + 2 * s < spre (thread 0) folds it in here.
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          atomicAdd(grad_loc_out, cache_grad_sampling_loc[0]);
          atomicAdd(grad_loc_out + 1, cache_grad_sampling_loc[1]);
          atomicAdd(grad_weight_out, cache_grad_attn_weight[0]);
        }
        __syncthreads();  // totals consumed before the slots are reused

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_weight_out += 1;
        grad_loc_out += 2;
      }
    }
  }
}
843
+
844
+
845
// Backward kernel without shared-memory staging: every gradient is
// accumulated directly through atomicAdd.
//
// Each work item `index` in [0, n) — decomposed as
// ((b * num_query + q) * num_heads + m) * channels + c — back-propagates
// grad_col[index] through every (level, point) sample by calling
// ms_deform_attn_col2im_bilinear_gm, which atomically adds into grad_value,
// grad_sampling_loc and grad_attn_weight. Since results are accumulated, the
// three outputs are presumably expected to be zero-initialised by the caller
// — confirm.
//
// For each level l, data_spatial_shapes[2l] / [2l + 1] are read as that
// level's height / width and data_level_start_index[l] as its first row in
// data_value / grad_value (row size = num_heads * channels).
// `batch_size` is not read by the kernel body.
template <typename scalar_t>
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
                                                const scalar_t *grad_col,
                                                const scalar_t *data_value,
                                                const int64_t *data_spatial_shapes,
                                                const int64_t *data_level_start_index,
                                                const scalar_t *data_sampling_loc,
                                                const scalar_t *data_attn_weight,
                                                const int batch_size,
                                                const int spatial_size,
                                                const int num_heads,
                                                const int channels,
                                                const int num_levels,
                                                const int num_query,
                                                const int num_point,
                                                scalar_t *grad_value,
                                                scalar_t *grad_sampling_loc,
                                                scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    _temp /= num_query;  // skip the query index; only the batch index is needed
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    // Output cursors are locals derived from the kernel arguments, so every
    // loop iteration starts from the base pointers instead of from a pointer
    // already advanced by an earlier iteration of the same thread.
    scalar_t *grad_loc_out = grad_sampling_loc + data_loc_w_ptr;
    scalar_t *grad_weight_out = grad_attn_weight + data_weight_ptr;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col = 0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      // Both pointers are passed by reference below, so they stay non-const lvalues.
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col = 0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Samples entirely outside the level contribute no gradient.
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear_gm(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            grad_loc_out, grad_weight_out);
        }
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_weight_out += 1;
        grad_loc_out += 2;
      }
    }
  }
}
921
+
922
+
923
// Host launcher for the forward kernel ms_deformable_im2col_gpu_kernel.
//
// Launches one work item per output element
// (batch_size * num_query * num_heads * channels) on `stream`, using blocks of
// CUDA_NUM_THREADS threads, and forwards every argument unchanged to the
// kernel, which writes its results to `data_col`.
//
// A launch error is reported with printf only; nothing is thrown or returned.
// When there are no output elements the function returns without launching,
// because a grid of zero blocks is an invalid launch configuration and would
// otherwise be reported as an error.
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream,
                              const scalar_t* data_value,
                              const int64_t* data_spatial_shapes,
                              const int64_t* data_level_start_index,
                              const scalar_t* data_sampling_loc,
                              const scalar_t* data_attn_weight,
                              const int batch_size,
                              const int spatial_size,
                              const int num_heads,
                              const int channels,
                              const int num_levels,
                              const int num_query,
                              const int num_point,
                              scalar_t* data_col)
{
  const int num_kernels = batch_size * num_query * num_heads * channels;
  if (num_kernels == 0)
  {
    // Empty input: nothing to compute.
    return;
  }

  const int num_threads = CUDA_NUM_THREADS;
  ms_deformable_im2col_gpu_kernel<scalar_t>
      <<<GET_BLOCKS(num_kernels, num_threads), num_threads,
          0, stream>>>(
      num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
      batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);

  // Launch-time errors (e.g. bad configuration) surface here.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }

}
955
+
956
// Host-side launcher for the backward (col2im) kernels.
//
// Selects one kernel variant from `channels` and launches it on `stream`:
//   channels > 1024, multiple of 1024 -> shm_reduce_v2_multi_blocks (dynamic shared memory)
//   channels > 1024, otherwise        -> gm kernel (no shared memory requested)
//   channels in {1,2,4,8,16,32}       -> shm_blocksize_aware_reduce_v1<scalar_t, channels>
//   channels in {64,...,1024} (pow 2) -> shm_blocksize_aware_reduce_v2<scalar_t, channels>
//   any other channels < 64           -> shm_reduce_v1 (dynamic shared memory)
//   any other channels <= 1024        -> shm_reduce_v2 (dynamic shared memory)
//
// Every variant receives exactly the same argument list and launch geometry;
// only the kernel symbol and the dynamic shared-memory size differ. A launch
// error is reported on stdout only.
template <typename scalar_t>
void ms_deformable_col2im_cuda(cudaStream_t stream,
                               const scalar_t* grad_col,
                               const scalar_t* data_value,
                               const int64_t * data_spatial_shapes,
                               const int64_t * data_level_start_index,
                               const scalar_t * data_sampling_loc,
                               const scalar_t * data_attn_weight,
                               const int batch_size,
                               const int spatial_size,
                               const int num_heads,
                               const int channels,
                               const int num_levels,
                               const int num_query,
                               const int num_point,
                               scalar_t* grad_value,
                               scalar_t* grad_sampling_loc,
                               scalar_t* grad_attn_weight)
{
  // Block size is min(channels, CUDA_NUM_THREADS).
  const int num_threads = (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = num_kernels;

// Launches the kernel named by the variadic part (variadic so that template
// argument lists containing commas survive macro expansion) with SHM_BYTES
// bytes of dynamic shared memory and the common argument list.
#define MS_DEFORM_COL2IM_LAUNCH(SHM_BYTES, ...)                              \
  __VA_ARGS__                                                                \
      <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,           \
         (SHM_BYTES), stream>>>(                                             \
          num_kernels,                                                       \
          grad_col,                                                          \
          data_value,                                                        \
          data_spatial_shapes,                                               \
          data_level_start_index,                                            \
          data_sampling_loc,                                                 \
          data_attn_weight,                                                  \
          batch_size,                                                        \
          spatial_size,                                                      \
          num_heads,                                                         \
          channels,                                                          \
          num_levels,                                                        \
          num_query,                                                         \
          num_point,                                                         \
          grad_value,                                                        \
          grad_sampling_loc,                                                 \
          grad_attn_weight)

// Three scalar_t slots per thread, as requested by the non-templated
// shared-memory reduction kernels.
#define MS_DEFORM_COL2IM_DYN_SHM (num_threads*3*sizeof(scalar_t))

  if (channels > 1024)
  {
    if ((channels & 1023) == 0)
    {
      MS_DEFORM_COL2IM_LAUNCH(MS_DEFORM_COL2IM_DYN_SHM,
          ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>);
    }
    else
    {
      MS_DEFORM_COL2IM_LAUNCH(0, ms_deformable_col2im_gpu_kernel_gm<scalar_t>);
    }
  }
  else
  {
    switch (channels)
    {
      case 1:
        MS_DEFORM_COL2IM_LAUNCH(0,
            ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>);
        break;
      case 2:
        MS_DEFORM_COL2IM_LAUNCH(0,
            ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>);
        break;
      case 4:
        MS_DEFORM_COL2IM_LAUNCH(0,
            ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>);
        break;
      case 8:
        MS_DEFORM_COL2IM_LAUNCH(0,
            ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>);
        break;
      case 16:
        MS_DEFORM_COL2IM_LAUNCH(0,
            ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>);
        break;
      case 32:
        MS_DEFORM_COL2IM_LAUNCH(0,
            ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>);
        break;
      case 64:
        MS_DEFORM_COL2IM_LAUNCH(0,
            ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>);
        break;
      case 128:
        MS_DEFORM_COL2IM_LAUNCH(0,
            ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>);
        break;
      case 256:
        MS_DEFORM_COL2IM_LAUNCH(0,
            ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>);
        break;
      case 512:
        MS_DEFORM_COL2IM_LAUNCH(0,
            ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>);
        break;
      case 1024:
        MS_DEFORM_COL2IM_LAUNCH(0,
            ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>);
        break;
      default:
        if (channels < 64)
        {
          MS_DEFORM_COL2IM_LAUNCH(MS_DEFORM_COL2IM_DYN_SHM,
              ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>);
        }
        else
        {
          MS_DEFORM_COL2IM_LAUNCH(MS_DEFORM_COL2IM_DYN_SHM,
              ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>);
        }
    }
  }

#undef MS_DEFORM_COL2IM_DYN_SHM
#undef MS_DEFORM_COL2IM_LAUNCH

  // cudaGetLastError() picks up launch-configuration failures.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess)
  {
    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(launch_status));
  }
}
perception_models/apps/detection/DETA_pe/models/ops/src/ms_deform_attn.h ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+
13
+ #include "cpu/ms_deform_attn_cpu.h"
14
+
15
+ #ifdef WITH_CUDA
16
+ #include "cuda/ms_deform_attn_cuda.h"
17
+ #endif
18
+
19
+
20
+ at::Tensor
21
+ ms_deform_attn_forward(
22
+ const at::Tensor &value,
23
+ const at::Tensor &spatial_shapes,
24
+ const at::Tensor &level_start_index,
25
+ const at::Tensor &sampling_loc,
26
+ const at::Tensor &attn_weight,
27
+ const int im2col_step)
28
+ {
29
+ if (value.type().is_cuda())
30
+ {
31
+ #ifdef WITH_CUDA
32
+ return ms_deform_attn_cuda_forward(
33
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
34
+ #else
35
+ AT_ERROR("Not compiled with GPU support");
36
+ #endif
37
+ }
38
+ AT_ERROR("Not implemented on the CPU");
39
+ }
40
+
41
+ std::vector<at::Tensor>
42
+ ms_deform_attn_backward(
43
+ const at::Tensor &value,
44
+ const at::Tensor &spatial_shapes,
45
+ const at::Tensor &level_start_index,
46
+ const at::Tensor &sampling_loc,
47
+ const at::Tensor &attn_weight,
48
+ const at::Tensor &grad_output,
49
+ const int im2col_step)
50
+ {
51
+ if (value.type().is_cuda())
52
+ {
53
+ #ifdef WITH_CUDA
54
+ return ms_deform_attn_cuda_backward(
55
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
56
+ #else
57
+ AT_ERROR("Not compiled with GPU support");
58
+ #endif
59
+ }
60
+ AT_ERROR("Not implemented on the CPU");
61
+ }
62
+
perception_models/apps/detection/DETA_pe/models/ops/src/vision.cpp ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include "ms_deform_attn.h"
12
+
13
// Python bindings for the extension module. The module name is supplied by
// the build through TORCH_EXTENSION_NAME. Each m.def call takes the
// Python-visible name, the C++ function from ms_deform_attn.h, and a docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
  m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
perception_models/apps/detection/DETA_pe/models/ops/test.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from __future__ import absolute_import
10
+ from __future__ import print_function
11
+ from __future__ import division
12
+
13
+ import time
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.autograd import gradcheck
17
+
18
+ from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
19
+
20
+
21
# Problem sizes shared by every check below. The names are used as tensor
# dimensions in the check functions: value is (N, S, M, D), sampling locations
# are (N, Lq, M, L, P, 2) and attention weights are (N, Lq, M, L, P).
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2

# One (H, W) row per feature level, kept on the GPU.
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()

# Start offset of each level in the flattened value tensor: a leading zero
# followed by the running sum of H*W for all levels but the last.
level_start_index = torch.cat(
    (shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1])
)

# Total number of flattened positions across all levels.
S = sum((H * W).item() for H, W in shapes)


torch.manual_seed(3)
29
+
30
+
31
@torch.no_grad()
def check_forward_equal_with_pytorch_double():
    """Compare the CUDA forward against the PyTorch reference in float64.

    Random inputs are generated on the GPU, cast to double, and run through
    both ``ms_deform_attn_core_pytorch`` and ``MSDeformAttnFunction``. The
    ``torch.allclose`` verdict and the maximum absolute/relative errors are
    printed; nothing is returned or asserted.
    """
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    # Normalise so the weights sum to one over the last two dimensions.
    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
    im2col_step = 2

    value_d = value.double()
    locations_d = sampling_locations.double()
    weights_d = attention_weights.double()

    output_pytorch = ms_deform_attn_core_pytorch(
        value_d, shapes, locations_d, weights_d
    ).detach().cpu()
    output_cuda = MSDeformAttnFunction.apply(
        value_d, shapes, level_start_index, locations_d, weights_d, im2col_step
    ).detach().cpu()

    abs_diff = (output_cuda - output_pytorch).abs()
    fwdok = torch.allclose(output_cuda, output_pytorch)
    max_abs_err = abs_diff.max()
    max_rel_err = (abs_diff / output_pytorch.abs()).max()

    print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
45
+
46
+
47
@torch.no_grad()
def check_forward_equal_with_pytorch_float():
    """Compare the CUDA forward against the PyTorch reference in float32.

    Same procedure as the double-precision check but without casting, and
    with looser ``torch.allclose`` tolerances (rtol=1e-2, atol=1e-3). The
    verdict and the maximum absolute/relative errors are printed; nothing is
    returned or asserted.
    """
    value = torch.rand(N, S, M, D).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    # Normalise so the weights sum to one over the last two dimensions.
    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
    im2col_step = 2

    output_pytorch = ms_deform_attn_core_pytorch(
        value, shapes, sampling_locations, attention_weights
    ).detach().cpu()
    output_cuda = MSDeformAttnFunction.apply(
        value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step
    ).detach().cpu()

    abs_diff = (output_cuda - output_pytorch).abs()
    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
    max_abs_err = abs_diff.max()
    max_rel_err = (abs_diff / output_pytorch.abs()).max()

    print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
61
+
62
+
63
def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
    """Run ``torch.autograd.gradcheck`` on ``MSDeformAttnFunction``.

    Args:
        channels: size of the last dimension of the random value tensor.
        grad_value: whether the value tensor requires grad.
        grad_sampling_loc: whether the sampling locations require grad.
        grad_attn_weight: whether the attention weights require grad.

    The gradcheck result is printed; nothing is returned.
    """
    value = torch.rand(N, S, M, channels).cuda() * 0.01
    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
    # Normalise so the weights sum to one over the last two dimensions.
    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
    im2col_step = 2

    # The flags are set on the float32 tensors before the cast to double below.
    value.requires_grad = grad_value
    sampling_locations.requires_grad = grad_sampling_loc
    attention_weights.requires_grad = grad_attn_weight

    gradcheck_inputs = (
        value.double(),
        shapes,
        level_start_index,
        sampling_locations.double(),
        attention_weights.double(),
        im2col_step,
    )
    gradok = gradcheck(MSDeformAttnFunction.apply, gradcheck_inputs)

    print(f'* {gradok} check_gradient_numerical(D={channels})')
79
+
80
+
81
if __name__ == '__main__':
    # Forward agreement with the PyTorch reference, in both precisions.
    check_forward_equal_with_pytorch_double()
    check_forward_equal_with_pytorch_float()

    # Numerical gradient check for a range of channel counts, with gradients
    # enabled for all three differentiable inputs.
    for channels in (30, 32, 64, 71, 1025, 2048, 3096):
        check_gradient_numerical(
            channels,
            grad_value=True,
            grad_sampling_loc=True,
            grad_attn_weight=True,
        )
87
+
88
+
89
+
perception_models/apps/detection/DETA_pe/models/pev1.py ADDED
@@ -0,0 +1,686 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from collections import OrderedDict
3
+ from functools import partial
4
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+ from torch import broadcast_tensors, einsum, nn
10
+ from torch.nn.parameter import Parameter
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+ from .utils_d2 import (
14
+ add_decomposed_rel_pos,
15
+ PatchEmbed,
16
+ window_partition,
17
+ window_unpartition,
18
+ )
19
+
20
+
21
def get_abs_pos(abs_pos, has_cls_token, hw, tile=False):
    """Return absolute positional embeddings laid out on an ``(h, w)`` grid.

    Args:
        abs_pos: positional embeddings indexed as ``(1, num_positions, C)``.
            After the optional class-token entry is removed,
            ``num_positions`` must be a perfect square.
        has_cls_token: if true, the first position is dropped before resizing.
        hw: target ``(h, w)`` grid size.
        tile: if truthy, a size mismatch is resolved by repeating the source
            grid and cropping to ``(h, w)``; otherwise bicubic interpolation
            is used.

    Returns:
        Tensor of shape ``(1, h, w, C)``.

    Raises:
        AssertionError: if the number of positions is not a perfect square.
    """
    h, w = hw
    if has_cls_token:
        abs_pos = abs_pos[:, 1:]
    xy_num = abs_pos.shape[1]
    size = int(math.sqrt(xy_num))
    assert size * size == xy_num

    # Source grid already matches the target: only a reshape is needed.
    if size == h and size == w:
        return abs_pos.reshape(1, h, w, -1)

    grid = abs_pos.reshape(1, size, size, -1)
    if tile:
        # Repeat one time more than strictly needed along each axis, then crop.
        return grid.tile([1, h // size + 1, w // size + 1, 1])[:, :h, :w, :]

    # F.interpolate expects channels-first input; convert back afterwards.
    resized = F.interpolate(
        grid.permute(0, 3, 1, 2),
        size=(h, w),
        mode="bicubic",
        align_corners=False,
    )
    return resized.permute(0, 2, 3, 1)
46
+
47
+
48
+ # broadcat, as tortoise-tts was using it
49
def broadcat(tensors, dim=-1):
    """Broadcast ``tensors`` against each other, then concatenate along ``dim``."""
    return torch.cat(broadcast_tensors(*tensors), dim=dim)
52
+
53
+
54
+ # rotary embedding helper functions
55
def rotate_half(x):
    """Map each adjacent pair ``(a, b)`` of the last dimension to ``(-b, a)``.

    The last dimension is grouped into consecutive pairs, each pair is
    rotated, and the result is flattened back to the input's shape.
    """
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    return rearrange(rotated, "... d r -> ... (d r)")
60
+
61
+
62
class VisionRotaryEmbeddingFast(nn.Module):
    """Precomputed 2-D rotary position embedding for a square token grid.

    The constructor builds cosine and sine tables for every cell of an
    ``ft_seq_len x ft_seq_len`` grid and registers them as the buffers
    ``freqs_cos`` and ``freqs_sin``. With ``F`` base frequencies both buffers
    have shape ``(ft_seq_len * ft_seq_len, 4 * F)``: each base frequency is
    duplicated for a channel pair, and the tables of the two grid axes are
    concatenated.

    Args:
        dim: controls the number of base frequencies (``dim // 2``) for the
            ``"lang"`` and ``"pixel"`` schedules.
        pt_seq_len: grid side length the positions are scaled to.
        ft_seq_len: grid side length actually tabulated; defaults to
            ``pt_seq_len``.
        custom_freqs: optional 1-D tensor of base frequencies; when given it
            overrides ``freqs_for``.
        freqs_for: one of ``"lang"``, ``"pixel"`` or ``"constant"``.
        theta: base of the ``"lang"`` schedule.
        max_freq: upper bound parameter of the ``"pixel"`` schedule.
        num_freqs: number of frequencies for the ``"constant"`` schedule.

    Raises:
        ValueError: if ``freqs_for`` is not a known schedule.
    """

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for="lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        # Compare against None explicitly: truth-testing a multi-element
        # tensor raises a RuntimeError.
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (
                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
            )
        elif freqs_for == "pixel":
            # math.pi: no bare `pi` name is imported in this module.
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * math.pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f"unknown modality {freqs_for}")

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        t = (
            torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + 1
        )  # + 1 is hacking vev0 pt code

        # Outer product: one row of scaled frequencies per grid coordinate.
        freqs = torch.einsum("..., f -> ... f", t, freqs)
        # Duplicate every frequency for its channel pair; same result as
        # einops repeat(freqs, "... n -> ... (n r)", r=2).
        freqs = freqs.repeat_interleave(2, dim=-1)
        # Combine the two grid axes by broadcasting and concatenating on the
        # channel axis (follows vev0 pt code).
        freqs = torch.cat(
            torch.broadcast_tensors(freqs[None, :, :], freqs[:, None, :]), dim=-1
        )

        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        print("======== shape of rope freq", self.freqs_cos.shape, "========")

    def forward(self, tt):
        """Rotate ``tt`` using the cached tables.

        ``tt`` must broadcast against the ``(positions, channels)`` buffers in
        its last two dimensions.
        """
        return tt * self.freqs_cos + rotate_half(tt) * self.freqs_sin
111
+
112
+
113
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16.

    The normalisation is computed in float32 and the result is cast back to
    the dtype of the input.
    """

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        # weight/bias are None when the layer has no affine parameters
        # (e.g. elementwise_affine=False); F.layer_norm accepts None for both.
        weight = self.weight.type(torch.float32) if self.weight is not None else None
        bias = self.bias.type(torch.float32) if self.bias is not None else None
        ret = F.layer_norm(
            x.type(torch.float32),
            self.normalized_shape,
            weight,
            bias,
            self.eps,
        )
        return ret.type(orig_type)
127
+
128
+
129
class QuickGELU(nn.Module):
    """Activation computing ``x * sigmoid(1.702 * x)`` element-wise."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
132
+
133
+
134
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    A Bernoulli keep/drop decision is drawn once per entry of the first
    dimension and broadcast over all remaining dimensions.

    Args:
        x: input tensor; the first dimension indexes samples.
        drop_prob: probability of dropping a sample.
        training: when false, ``x`` is returned untouched.
        scale_by_keep: when true (and the keep probability is positive), kept
            samples are divided by the keep probability.

    Returns:
        ``x`` itself when ``drop_prob`` is 0 or ``training`` is false,
        otherwise ``x`` multiplied by the per-sample mask.
    """
    if not training or drop_prob == 0.0:
        return x

    keep_prob = 1 - drop_prob
    # One mask entry per sample, with singleton trailing dims so it broadcasts
    # over tensors of any rank.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
156
+
157
+
158
class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (stochastic depth per sample).

    Args:
        drop_prob: probability of dropping a sample.
        scale_by_keep: forwarded to :func:`drop_path`.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # The module's own training flag decides whether dropping is active.
        return drop_path(
            x,
            self.drop_prob,
            self.training,
            self.scale_by_keep,
        )

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob,3):0.3f}"
171
+
172
+
173
class Attention(nn.Module):
    r"""Multi-head self-attention with an optional rotary embedding.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: stored on the module; not applied in :meth:`forward`.
        bias: whether the input and output projections carry a bias.
        add_bias_kv: if true, ``bias_k``/``bias_v`` parameters are created
            (they are not used in :meth:`forward`).
        kdim: key width, defaults to ``embed_dim``.
        vdim: value width, defaults to ``embed_dim``.
        rope: optional callable applied to the per-head queries and keys.

    Notes:
        * The projection parameters are allocated uninitialised
          (``torch.empty``); they are expected to be filled in afterwards,
          e.g. by loading a checkpoint.
        * :meth:`forward` reads ``in_proj_weight``, which only exists when
          ``kdim`` and ``vdim`` both equal ``embed_dim``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        add_bias_kv: bool = False,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        rope=None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter("in_proj_bias", None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.rope = rope

        self.scale = self.head_dim ** (-0.5)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``(b, s, h * d)`` to ``(b, h, s, d)``."""
        batch, seq, _ = t.shape
        return t.reshape(batch, seq, self.num_heads, -1).permute(0, 2, 1, 3)

    def forward(self, query, attn_mask: Optional[torch.Tensor] = None):
        """Apply self-attention to ``query`` of shape ``(batch, seq, embed_dim)``.

        ``attn_mask`` is accepted for interface compatibility but is not
        applied to the attention scores.

        Returns:
            Tensor of shape ``(batch, seq, embed_dim)``.
        """
        batch, seq, embed_dim = query.shape

        proj = F.linear(query, self.in_proj_weight, self.in_proj_bias)
        # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
        proj = (
            proj.unflatten(-1, (3, embed_dim))
            .unsqueeze(0)
            .transpose(0, -2)
            .squeeze(-2)
            .contiguous()
        )
        q_, k_, v_ = proj[0], proj[1], proj[2]

        # Use "q_" so that we don't accidentally quit in pdb :)
        q_ = self._split_heads(q_)
        k_ = self._split_heads(k_)
        v_ = self._split_heads(v_)

        # rope is optional: the constructor default is None, in which case
        # queries and keys are used as projected.
        if self.rope is not None:
            q_ = self.rope(q_).type_as(v_)
            k_ = self.rope(k_).type_as(v_)

        attn = (q_ * self.scale) @ k_.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        x_ = attn @ v_

        # (b, h, s, d) -> (b, s, h * d)
        x_ = x_.permute(0, 2, 1, 3).reshape(batch, seq, -1)

        return F.linear(x_, self.out_proj.weight, self.out_proj.bias)
255
+
256
+
257
class LayerScale(nn.Module):
    """Learnable per-channel scaling of the last dimension.

    Args:
        dim: number of channels (length of the ``gamma`` parameter).
        init_values: initial value of every entry of ``gamma``.
        inplace: if true, the input tensor is scaled in place and returned.
    """

    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
270
+
271
+
272
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block operating on ``(B, H, W, C)`` feature maps.

    The block applies ``LayerNorm -> attention`` and ``LayerNorm -> MLP``,
    each wrapped in an optional LayerScale and DropPath and added back to the
    residual stream. When ``window_size > 0`` the attention runs on windows
    produced by ``window_partition``.

    Args:
        d_model: channel width.
        n_head: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``d_model``.
        act_layer: activation class used inside the MLP.
        drop_path: stochastic-depth probability (0 disables it).
        window_size: window side length; 0 means attention over all tokens.
        rope: rotary embedding passed to the attention layer.
        attn_mask: tensor forwarded to the attention layer on every call.
        init_values: LayerScale initial value; values <= 0 disable LayerScale.

    ``norm_layer``, ``use_rel_pos``, ``rel_pos_zero_init`` and ``input_size``
    are accepted but not read by this class.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio=4.0,
        act_layer=nn.GELU,
        norm_layer=LayerNorm,
        drop_path=0.0,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        window_size=0,
        rope=None,
        input_size=None,
        attn_mask=None,
        init_values=0.0,
    ):
        super().__init__()

        def make_layer_scale():
            # LayerScale is only instantiated for a positive initial value.
            if init_values > 0.0:
                return LayerScale(d_model, init_values=init_values)
            return nn.Identity()

        hidden_dim = int(d_model * mlp_ratio)

        # Submodules are registered in the same order as before so that
        # parameter ordering and random initialisation are unchanged.
        self.attn = Attention(embed_dim=d_model, num_heads=n_head, rope=rope)
        self.ls_1 = make_layer_scale()
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, hidden_dim)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(hidden_dim, d_model)),
                ]
            )
        )

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.ls_2 = make_layer_scale()
        self.window_size = window_size

    def attention_nhwc(self, x: torch.Tensor):
        """Flatten ``(B, H, W, C)`` to tokens, run attention, restore the grid."""
        if self.attn_mask is not None:
            # Keep the stored mask on the same dtype/device as the activations.
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        batch, height, width, _ = x.shape
        tokens = x.reshape(batch, height * width, -1)
        tokens = self.attn(tokens, attn_mask=self.attn_mask)
        return tokens.reshape(batch, height, width, -1)

    def forward(self, x: torch.Tensor):
        residual = x
        use_windows = self.window_size > 0

        x = self.ln_1(x)
        # Window partition
        if use_windows:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attention_nhwc(x)
        # Reverse window partition
        if use_windows:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = residual + self.drop_path(self.ls_1(x))
        mlp_out = self.mlp(self.ln_2(x))
        return x + self.drop_path(self.ls_2(mlp_out))
347
+
348
+
349
class Transformer(nn.Module):
    """Stack of :class:`ResidualAttentionBlock` layers.

    Args:
        embed_dim: channel width of every block.
        depth: number of blocks.
        num_heads: attention heads per block.
        mlp_ratio, act_layer, norm_layer, use_rel_pos, rel_pos_zero_init,
            init_values, attn_mask: forwarded to every block.
        drop_path_rate: maximum stochastic-depth rate; per-block rates grow
            linearly from 0 to this value.
        window_size: window side length for blocks listed in
            ``window_block_indexes``; all other blocks use 0.
        window_block_indexes: indices of the blocks that use windowed
            attention (and ``rope_win`` instead of ``rope_glb``).
        img_size, patch_size: used to derive the ``input_size`` handed to the
            blocks.
        rope_win: rotary embedding for windowed blocks.
        rope_glb: rotary embedding for the remaining blocks.
        use_act_checkpoint: enable activation checkpointing.
        act_checkpoint_ratio: blocks with ``idx / depth`` at or below this
            ratio are checkpointed.
        return_layer: block indices whose outputs are collected in addition to
            the last block. ``None`` means ``[-1]``. Indices are compared
            against non-negative block indices as-is.
    """

    def __init__(
        self,
        embed_dim: int,
        depth: int,
        num_heads: int,
        mlp_ratio=4.0,
        act_layer=nn.GELU,
        norm_layer=LayerNorm,
        drop_path_rate=0.0,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        window_size=0,
        window_block_indexes=(),
        img_size=1024,
        patch_size=16,
        rope_win=None,
        rope_glb=None,
        use_act_checkpoint=False,
        act_checkpoint_ratio=1.0,
        attn_mask=None,
        init_values=0.0,
        return_layer=None,
    ):
        super().__init__()
        self.use_act_checkpoint = use_act_checkpoint
        self.act_checkpoint_ratio = act_checkpoint_ratio

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.resblocks = nn.ModuleList()
        for i in range(depth):
            block = ResidualAttentionBlock(
                embed_dim,
                num_heads,
                attn_mask=attn_mask,
                drop_path=dpr[i],
                mlp_ratio=mlp_ratio,
                act_layer=act_layer,
                norm_layer=norm_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i in window_block_indexes else 0,
                rope=rope_win if i in window_block_indexes else rope_glb,
                input_size=(img_size // patch_size, img_size // patch_size),
                init_values=init_values,
            )
            self.resblocks.append(block)

        # A fresh list per instance: a list literal as the default argument
        # would be one object shared by every Transformer.
        self.return_layer = [-1] if return_layer is None else return_layer

    def forward(self, x: torch.Tensor):
        """Run all blocks.

        Returns:
            ``(x, x_list)`` where ``x`` is the final output and ``x_list``
            holds the outputs of the blocks named in ``return_layer`` plus the
            last block, in block order.
        """
        x_list = []
        for idx, blk in enumerate(self.resblocks):
            if (
                self.use_act_checkpoint
                and (idx / len(self.resblocks)) <= self.act_checkpoint_ratio
            ):
                x = checkpoint(blk, x)
            else:
                x = blk(x)

            if idx in self.return_layer or idx == len(self.resblocks) - 1:
                x_list.append(x)

        return x, x_list
416
+
417
+
418
class PEv1_simpleFPN(nn.Module):
    """ViT-style backbone followed by a three-level simple feature pyramid.

    ``forward`` patchifies the image with ``conv1``, optionally adds an
    absolute positional embedding, runs ``ln_pre`` and the ``Transformer``,
    and derives three maps from the final feature map:

    * ``"p3"``: 2x upsampled via ``ConvTranspose2d`` (``embed_dim // 2`` channels)
    * ``"p4"``: the feature map itself (``embed_dim`` channels)
    * ``"p5"``: 2x downsampled via ``MaxPool2d`` (``embed_dim`` channels)

    Args (those not forwarded verbatim to ``Transformer``):
        img_size: Used for the global rotary grid (``img_size // patch_size``)
            and stored as ``_square_pad``.
        patch_size: Kernel size and stride of the patch-embedding convolution.
        in_chans: Input image channels.
        use_abs_pos: Create a learnable absolute positional embedding.
        pretrain_img_size, pretrain_use_cls_token: Determine the number of
            positions in the absolute positional embedding.
        rope, qkv_bias, residual_block_indexes, tta_rope, out_feature:
            Accepted for interface compatibility; not read by this class.
        pt_hw_seq_len, intp_freq: Passed to ``VisionRotaryEmbeddingFast``.
        tile_posemb: Forwarded to ``get_abs_pos`` at forward time.
        return_layer: Forwarded to ``Transformer``.
    """

    def __init__(
        self,
        img_size=1024,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        use_abs_pos=True,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        rope=True,
        pt_hw_seq_len=16,
        intp_freq=True,
        window_size=0,
        window_block_indexes=(),
        residual_block_indexes=(),
        use_act_checkpoint=False,
        act_checkpoint_ratio=1.0,
        pretrain_img_size=336,
        pretrain_use_cls_token=True,
        out_feature="last_feat",
        tile_posemb=False,
        init_values=0.0,
        tta_rope=False,
        # Immutable default: a list default would be shared by every instance.
        return_layer=(-1,),
    ):
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token

        self.conv1 = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) * (
                pretrain_img_size // patch_size
            )
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.positional_embedding = nn.Parameter(
                torch.zeros(1, num_positions, embed_dim)
            )
            # Reported once (the original emitted this identical line three times).
            print("positional_embedding:", self.positional_embedding.shape)
        else:
            self.positional_embedding = None

        self.tile_posemb = tile_posemb

        self.ln_pre = LayerNorm(embed_dim)

        half_head_dim = embed_dim // num_heads // 2
        hw_seq_len = img_size // patch_size

        # One rotary embedding for windowed blocks, one for global blocks.
        self.rope_win = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=window_size if intp_freq else None,
        )
        self.rope_glb = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=hw_seq_len if intp_freq else None,
        )

        # Private copy so the caller's sequence is never shared with the model.
        return_layer = list(return_layer)

        self.transformer = Transformer(
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            act_layer=act_layer,
            norm_layer=norm_layer,
            drop_path_rate=drop_path_rate,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            window_size=window_size,
            window_block_indexes=window_block_indexes,
            rope_win=self.rope_win,
            rope_glb=self.rope_glb,
            img_size=img_size,
            patch_size=patch_size,
            use_act_checkpoint=use_act_checkpoint,
            act_checkpoint_ratio=act_checkpoint_ratio,
            init_values=init_values,
            return_layer=return_layer,
        )

        if self.positional_embedding is not None:
            nn.init.trunc_normal_(self.positional_embedding, std=0.02)

        self.return_layer = return_layer
        # In our method, we don't use backbone feature with stride 4
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim // 2, kernel_size=2, stride=2),
        )
        self.fpn2 = nn.Identity()
        self.fpn3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.apply(self._init_weights)

        # Output metadata. The earlier single-feature ("out_feature") values
        # were always overwritten by these, so only the final values are set.
        strides = [patch_size // 2, patch_size, patch_size * 2]
        self._out_features = ["p{}".format(int(math.log2(s))) for s in strides]
        self._out_feature_strides = {
            "p3": 8,
            "p4": 16,
            "p5": 32,
        }
        self._out_feature_channels = {
            "p3": embed_dim // 2,
            "p4": embed_dim,
            "p5": embed_dim,
        }
        self._size_divisibility = strides[-1]
        self._square_pad = img_size

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` (trunc-normal weight, zero bias) and ``nn.LayerNorm`` (unit weight, zero bias)."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Return a dict with keys ``"p3"``, ``"p4"``, ``"p5"`` built from the last transformer output."""
        x = self.conv1(x)
        x = x.permute(0, 2, 3, 1)  # (b, c, h, w) --> (b, h, w, c)

        if self.positional_embedding is not None:
            x = x + get_abs_pos(
                self.positional_embedding,
                self.pretrain_use_cls_token,
                (x.shape[1], x.shape[2]),
                self.tile_posemb,
            )
        x = self.ln_pre(x)

        # Only the final feature map feeds the pyramid; the per-layer list is unused here.
        x, _ = self.transformer(x)

        xp = x.permute(0, 3, 1, 2)  # (b, h, w, c) --> (b, c, h, w)

        ops = (self.fpn1, self.fpn2, self.fpn3)
        return {"p{}".format(level + 3): op(xp) for level, op in enumerate(ops)}
583
+
584
+
585
def get_pev1_and_fpn_backbone(args):
    """Build a ``PEv1_simpleFPN`` backbone from parsed arguments.

    Attributes read from ``args``: ``lsj_img_size_max``, ``lsj_img_size``,
    ``backbone_use_act_checkpoint``, ``backbone_act_checkpoint_ratio``,
    ``backbone_init_values``, ``backbone_tile_posemb``, ``backbone_tta_rope``,
    ``backbone_multi_layer``, ``backbone_dp``, ``backbone_size``,
    ``backbone_layers`` and ``backbone_path``.

    Supported ``backbone_size`` values are ``"G"``, ``"Gwin384"`` and
    ``"Gwin512"``. They share one architecture and differ only in the
    pre-training image size and the attention window size.

    Returns:
        The constructed model. If ``args.backbone_path`` is truthy, the
        ``"model"`` entry of that checkpoint is loaded with ``strict=False``.

    Raises:
        ValueError: if ``args.backbone_size`` is not supported.
        AssertionError: if ``args.backbone_layers`` differs from the variant's depth.
    """
    if args.lsj_img_size_max > 0:
        img_size = args.lsj_img_size_max
    else:
        img_size = args.lsj_img_size
    use_act_checkpoint = args.backbone_use_act_checkpoint
    act_checkpoint_ratio = args.backbone_act_checkpoint_ratio
    init_values = args.backbone_init_values
    tile_posemb = args.backbone_tile_posemb
    tta_rope = args.backbone_tta_rope
    multi_layer = args.backbone_multi_layer
    backbone_dp = args.backbone_dp

    # variant name -> (pretrain_img_size, window_size); everything else is shared.
    variant_sizes = {
        "G": (224, 14),
        "Gwin384": (384, 24),
        "Gwin512": (512, 32),
    }
    if args.backbone_size not in variant_sizes:
        raise ValueError("Unsupported backbone size")
    pretrain_img_size, window_size = variant_sizes[args.backbone_size]

    embed_dim, depth, num_heads, mlp_ratio, dp = 1536, 50, 16, 8960 / 1536, 0.5
    patch_size = 16
    # Windowed attention everywhere except blocks 12, 24, 36 and 49.
    window_block_indexes = (
        list(range(0, 12))
        + list(range(13, 24))
        + list(range(25, 36))
        + list(range(37, 49))
    )
    pretrain_use_cls_token = False
    return_layer = [12, 24, 36, 49] if multi_layer else [-1]

    # A non-negative --backbone_dp overrides the variant's default drop-path rate.
    if backbone_dp >= 0:
        dp = backbone_dp

    assert (
        depth == args.backbone_layers
    ), f"backbone depth {depth} and layers {args.backbone_layers}(from config) must be the same"

    model = PEv1_simpleFPN(
        use_act_checkpoint=use_act_checkpoint,
        act_checkpoint_ratio=act_checkpoint_ratio,
        pretrain_img_size=pretrain_img_size,
        pretrain_use_cls_token=pretrain_use_cls_token,
        img_size=img_size,
        patch_size=patch_size,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        drop_path_rate=dp,
        window_size=window_size,
        # NOTE(review): fixed at 16 although pretrain_img_size // patch_size is
        # 14 / 24 / 32 for the variants above; the original author flagged this
        # as a possible bug. Left unchanged.
        pt_hw_seq_len=16,  # Maybe a bug ?
        mlp_ratio=mlp_ratio,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        window_block_indexes=window_block_indexes,
        residual_block_indexes=[],
        use_rel_pos=True,
        out_feature="last_feat",
        tile_posemb=tile_posemb,
        init_values=init_values,
        tta_rope=tta_rope,
        return_layer=return_layer,
    )

    pretrained_backbone_path = args.backbone_path
    if pretrained_backbone_path:
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from trusted sources (or pass weights_only=True where supported).
        state_dict = torch.load(pretrained_backbone_path, map_location="cpu")
        load_info = model.load_state_dict(state_dict["model"], strict=False)
        print("Missing keys", load_info.missing_keys)
        print("Unexpected keys", load_info.unexpected_keys)
    else:
        print("Skip pretrained backbone loading")
    return model
perception_models/apps/detection/DETA_pe/models/position_encoding.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ Various positional encodings for the transformer.
12
+ """
13
+ import math
14
+ import torch
15
+ from torch import nn
16
+
17
+ from util.misc import NestedTensor
18
+
19
+
20
class PositionEmbeddingSine(nn.Module):
    """
    Fixed sine/cosine positional encoding for 2D feature maps, in the spirit of
    the "Attention is all you need" encoding but applied along both image axes.

    Positions are the cumulative counts of non-masked pixels along each axis,
    so masked pixels do not advance the coordinate.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, tensor_list: NestedTensor):
        """Return the encoding with ``2 * num_pos_feats`` channels (y half first, then x half)."""
        feats = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None

        valid = ~mask
        rows = valid.cumsum(1, dtype=torch.float32)
        cols = valid.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            # Divide by the last cumulative count along each axis, then scale.
            rows = (rows - 0.5) / (rows[:, -1:, :] + eps) * self.scale
            cols = (cols - 0.5) / (cols[:, :, -1:] + eps) * self.scale

        freq = torch.arange(self.num_pos_feats, dtype=torch.float32, device=feats.device)
        freq = self.temperature ** (2 * (freq // 2) / self.num_pos_feats)

        def interleave(coords):
            # sin on even feature slots, cos on odd ones, interleaved back together
            angles = coords[:, :, :, None] / freq
            paired = torch.stack(
                (angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos()), dim=4
            )
            return paired.flatten(3)

        pos_x = interleave(cols)
        pos_y = interleave(rows)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
57
+
58
+
59
class PositionEmbeddingLearned(nn.Module):
    """
    Learned absolute positional embedding: one table for rows and one for
    columns, each with 50 entries of ``num_pos_feats`` features.
    """

    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both tables with ``nn.init.uniform_`` (rows first, then columns)."""
        for table in (self.row_embed, self.col_embed):
            nn.init.uniform_(table.weight)

    def forward(self, tensor_list: NestedTensor):
        """Return the embedding with the column features in the first half of the channels and the row features in the second."""
        feats = tensor_list.tensors
        batch = feats.shape[0]
        height, width = feats.shape[-2:]

        col_vecs = self.col_embed(torch.arange(width, device=feats.device))
        row_vecs = self.row_embed(torch.arange(height, device=feats.device))

        # Tile each table over the other axis so both are (height, width, feats).
        grid = torch.cat(
            [
                col_vecs.unsqueeze(0).repeat(height, 1, 1),
                row_vecs.unsqueeze(1).repeat(1, width, 1),
            ],
            dim=-1,
        )
        return grid.permute(2, 0, 1).unsqueeze(0).repeat(batch, 1, 1, 1)
85
+
86
+
87
def build_position_encoding(args):
    """Create the positional-encoding module selected by ``args.position_embedding``.

    ``'v2'``/``'sine'`` gives ``PositionEmbeddingSine`` (normalised) and
    ``'v3'``/``'learned'`` gives ``PositionEmbeddingLearned``. Both are built
    with ``args.hidden_dim // 2`` features. Any other value raises ``ValueError``.
    """
    half_dim = args.hidden_dim // 2
    kind = args.position_embedding

    if kind in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        return PositionEmbeddingSine(half_dim, normalize=True)
    if kind in ('v3', 'learned'):
        return PositionEmbeddingLearned(half_dim)
    raise ValueError(f"not supported {args.position_embedding}")
perception_models/apps/detection/DETA_pe/models/segmentation.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ """
11
+ This file provides the definition of the convolutional heads used to predict masks, as well as the losses
12
+ """
13
+ import io
14
+ from collections import defaultdict
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from PIL import Image
20
+
21
+ import util.box_ops as box_ops
22
+ from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list
23
+
24
+ try:
25
+ from panopticapi.utils import id2rgb, rgb2id
26
+ except ImportError:
27
+ pass
28
+
29
+
30
class DETRsegm(nn.Module):
    """Wrap a detection model with a box-attention map and a convolutional mask head.

    Args:
        detr: The detection model. This class reads its ``backbone``,
            ``input_proj``, ``transformer``, ``query_embed``, ``class_embed``,
            ``bbox_embed``, ``aux_loss`` and ``num_queries`` attributes.
        freeze_detr: If true, ``requires_grad`` is switched off for every
            parameter registered at that point (the mask modules are created
            afterwards and stay trainable).
    """

    def __init__(self, detr, freeze_detr=False):
        super().__init__()
        self.detr = detr

        if freeze_detr:
            for param in self.parameters():
                param.requires_grad_(False)

        hidden_dim = detr.transformer.d_model
        nheads = detr.transformer.nhead
        self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0)
        self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim)

    def forward(self, samples: NestedTensor):
        """Return a dict with ``pred_logits``, ``pred_boxes``, ``pred_masks`` and, if enabled, ``aux_outputs``."""
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.detr.backbone(samples)

        # Only the last backbone level goes through the transformer.
        last_level = features[-1]
        bs = last_level.tensors.shape[0]
        src, mask = last_level.decompose()
        src_proj = self.detr.input_proj(src)
        hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1])

        outputs_class = self.detr.class_embed(hs)
        outputs_coord = self.detr.bbox_embed(hs).sigmoid()
        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
        if self.detr.aux_loss:
            out["aux_outputs"] = [
                {"pred_logits": logits, "pred_boxes": boxes}
                for logits, boxes in zip(outputs_class[:-1], outputs_coord[:-1])
            ]

        # FIXME h_boxes takes the last one computed, keep this in mind
        bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)

        # Backbone levels 2, 1, 0 are handed to the mask head in that order.
        fpn_inputs = [features[level].tensors for level in (2, 1, 0)]
        seg_masks = self.mask_head(src_proj, bbox_mask, fpn_inputs)

        out["pred_masks"] = seg_masks.view(
            bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]
        )
        return out
70
+
71
+
72
class MaskHeadSmallConv(nn.Module):
    """
    Small convolutional mask head built from conv + GroupNorm + ReLU stages.
    Resolution is increased FPN-style: at three points the running feature map
    is resized (nearest) to an adapted FPN level and added to it.
    """

    def __init__(self, dim, fpn_dims, context_dim):
        super().__init__()

        widths = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
        # NOTE: creation order is kept as-is; it fixes both the parameter names
        # and the order in which the random initialisers are drawn.
        self.lay1 = nn.Conv2d(dim, dim, 3, padding=1)
        self.gn1 = nn.GroupNorm(8, dim)
        self.lay2 = nn.Conv2d(dim, widths[1], 3, padding=1)
        self.gn2 = nn.GroupNorm(8, widths[1])
        self.lay3 = nn.Conv2d(widths[1], widths[2], 3, padding=1)
        self.gn3 = nn.GroupNorm(8, widths[2])
        self.lay4 = nn.Conv2d(widths[2], widths[3], 3, padding=1)
        self.gn4 = nn.GroupNorm(8, widths[3])
        self.lay5 = nn.Conv2d(widths[3], widths[4], 3, padding=1)
        self.gn5 = nn.GroupNorm(8, widths[4])
        self.out_lay = nn.Conv2d(widths[4], 1, 3, padding=1)

        self.dim = dim

        # 1x1 adapters bringing each FPN level to the width of the stage it joins.
        self.adapter1 = nn.Conv2d(fpn_dims[0], widths[1], 1)
        self.adapter2 = nn.Conv2d(fpn_dims[1], widths[2], 1)
        self.adapter3 = nn.Conv2d(fpn_dims[2], widths[3], 1)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_uniform_(module.weight, a=1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x, bbox_mask, fpns):
        """Return one single-channel map per (batch, query) pair at the resolution of ``fpns[2]``."""

        def expand(tensor, length):
            # Repeat every batch entry `length` times along a new axis, then fold it into the batch.
            return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)

        x = torch.cat([expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)

        x = F.relu(self.gn1(self.lay1(x)))
        x = F.relu(self.gn2(self.lay2(x)))

        stages = (
            (self.adapter1, self.lay3, self.gn3),
            (self.adapter2, self.lay4, self.gn4),
            (self.adapter3, self.lay5, self.gn5),
        )
        for level, (adapter, conv, norm) in enumerate(stages):
            cur_fpn = adapter(fpns[level])
            if cur_fpn.size(0) != x.size(0):
                cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
            x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
            x = F.relu(norm(conv(x)))

        return self.out_lay(x)
144
+
145
+
146
class MHAttentionMap(nn.Module):
    """2D multi-head attention that returns only the softmax weights (they are never multiplied by values)."""

    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0, bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)

        # Zero biases, Xavier weights (key projection initialised before query).
        for layer in (self.k_linear, self.q_linear):
            nn.init.zeros_(layer.bias)
        for layer in (self.k_linear, self.q_linear):
            nn.init.xavier_uniform_(layer.weight)
        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

    def forward(self, q, k, mask=None):
        """Return weights indexed ``(b, q, n, h, w)``, softmax-normalised jointly over heads and spatial positions.

        Positions where ``mask`` is true receive ``-inf`` before the softmax.
        """
        head_dim = self.hidden_dim // self.num_heads

        q = self.q_linear(q)
        # Apply the key projection as a 1x1 convolution over the feature map.
        k = F.conv2d(k, self.k_linear.weight[:, :, None, None], self.k_linear.bias)

        qh = q.view(q.shape[0], q.shape[1], self.num_heads, head_dim)
        kh = k.view(k.shape[0], self.num_heads, head_dim, k.shape[-2], k.shape[-1])
        weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)

        if mask is not None:
            weights.masked_fill_(mask[:, None, None], float("-inf"))
        weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights)
        return self.dropout(weights)
176
+
177
+
178
def dice_loss(inputs, targets, num_boxes):
    """
    DICE loss on sigmoid probabilities, with +1 smoothing in numerator and denominator.
    Args:
        inputs: A float tensor of logits, one prediction per example; it is
                passed through a sigmoid and flattened from dim 1 onwards.
        targets: A float tensor of binary labels (0 negative, 1 positive),
                 multiplied element-wise with the flattened probabilities.
        num_boxes: Normaliser the summed per-example loss is divided by.
    """
    probs = inputs.sigmoid().flatten(1)
    overlap = (probs * targets).sum(1)
    total = probs.sum(-1) + targets.sum(-1)
    per_example = 1 - (2 * overlap + 1) / (total + 1)
    return per_example.sum() / num_boxes
194
+
195
+
196
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Focal loss from RetinaNet (https://arxiv.org/abs/1708.02002) on logits.
    Args:
        inputs: A float tensor of logits, one prediction per example.
        targets: A float tensor with the same shape as inputs holding the
                 binary label of each element (0 negative, 1 positive).
        num_boxes: Normaliser the summed loss is divided by.
        alpha: Weighting factor balancing positive vs negative examples.
               Default = 0.25; a negative value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) balancing
               easy vs hard examples.
    Returns:
        Loss tensor: element-wise loss averaged over dim 1, summed, and
        divided by num_boxes.
    """
    prob = inputs.sigmoid()
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the probability the model assigns to the true label.
    p_t = prob * targets + (1 - prob) * (1 - targets)
    focal = bce * ((1 - p_t) ** gamma)

    if alpha >= 0:
        focal = (alpha * targets + (1 - alpha) * (1 - targets)) * focal

    return focal.mean(1).sum() / num_boxes
222
+
223
+
224
class PostProcessSegm(nn.Module):
    """Turn predicted mask logits into binary masks at each image's original size."""

    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold

    @torch.no_grad()
    def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
        """Add a ``"masks"`` entry to every dict in ``results`` and return ``results``.

        The logits in ``outputs["pred_masks"]`` are resized (bilinear) to the
        largest size in ``max_target_sizes``, thresholded on their sigmoid,
        cropped to each image's own entry of ``max_target_sizes`` and finally
        resized (nearest) to the matching entry of ``orig_target_sizes``.
        """
        assert len(orig_target_sizes) == len(max_target_sizes)
        max_h, max_w = max_target_sizes.max(0)[0].tolist()

        mask_logits = outputs["pred_masks"].squeeze(2)
        mask_logits = F.interpolate(mask_logits, size=(max_h, max_w), mode="bilinear", align_corners=False)
        binary_masks = (mask_logits.sigmoid() > self.threshold).cpu()

        per_image = zip(binary_masks, max_target_sizes, orig_target_sizes)
        for idx, (cur_mask, padded_size, orig_size) in enumerate(per_image):
            img_h, img_w = padded_size[0], padded_size[1]
            # Drop the batch padding before resizing to the original resolution.
            cropped = cur_mask[:, :img_h, :img_w].unsqueeze(1)
            results[idx]["masks"] = F.interpolate(
                cropped.float(), size=tuple(orig_size.tolist()), mode="nearest"
            ).byte()

        return results
245
+
246
+
247
class PostProcessPanoptic(nn.Module):
    """Convert raw model outputs into panoptic predictions in the format expected by the
    COCO panoptic API (a PNG-encoded segment-id image plus per-segment metadata)."""

    def __init__(self, is_thing_map, threshold=0.85):
        """
        Parameters:
           is_thing_map: This is a mapping whose keys are the class ids, and the values a boolean indicating
                          whether the class is a thing (True) or a stuff (False) class
           threshold: confidence threshold: segments with confidence lower than this will be deleted
        """
        super().__init__()
        self.threshold = threshold
        self.is_thing_map = is_thing_map

    def forward(self, outputs, processed_sizes, target_sizes=None):
        """ This function computes the panoptic prediction from the model's predictions.
        Parameters:
            outputs: This is a dict coming directly from the model. The keys "pred_logits", "pred_masks"
                     and "pred_boxes" are read here.
            processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the
                             model, ie the size after data augmentation but before batching.
            target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size
                          of each prediction. If left to None, it will default to the processed_sizes
        Returns:
            A list with one dict per image, holding "png_string" (the PNG bytes of the segment-id image)
            and "segments_info" (a list of dicts with "id", "isthing", "category_id" and "area").
        """
        if target_sizes is None:
            target_sizes = processed_sizes
        assert len(processed_sizes) == len(target_sizes)
        out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"]
        assert len(out_logits) == len(raw_masks) == len(target_sizes)
        preds = []

        def to_tuple(tup):
            # Sizes may arrive either as plain tuples or as tensors.
            if isinstance(tup, tuple):
                return tup
            return tuple(tup.cpu().tolist())

        for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
            out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
        ):
            # we filter empty queries and detection below threshold
            # (the last logit index is excluded -- presumably the "no object" class; TODO confirm)
            scores, labels = cur_logits.softmax(-1).max(-1)
            keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold)
            # NOTE(review): this recomputes the same softmax/max as `scores, labels` above.
            cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
            cur_scores = cur_scores[keep]
            cur_classes = cur_classes[keep]
            cur_masks = cur_masks[keep]
            cur_masks = interpolate(cur_masks[None], to_tuple(size), mode="bilinear").squeeze(0)
            cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep])

            h, w = cur_masks.shape[-2:]
            assert len(cur_boxes) == len(cur_classes)

            # It may be that we have several predicted masks for the same stuff class.
            # In the following, we track the list of masks ids for each stuff class (they are merged later on)
            cur_masks = cur_masks.flatten(1)
            stuff_equiv_classes = defaultdict(lambda: [])
            for k, label in enumerate(cur_classes):
                if not self.is_thing_map[label.item()]:
                    stuff_equiv_classes[label.item()].append(k)

            def get_ids_area(masks, scores, dedup=False):
                # This helper function creates the final panoptic segmentation image
                # It also returns the area of the masks that appears on the image
                # (closes over h, w, target_size and stuff_equiv_classes of the current image)

                # Per pixel, softmax over the kept masks; argmax below picks the winning mask id.
                m_id = masks.transpose(0, 1).softmax(-1)

                if m_id.shape[-1] == 0:
                    # We didn't detect any mask :(
                    m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
                else:
                    m_id = m_id.argmax(-1).view(h, w)

                if dedup:
                    # Merge the masks corresponding to the same stuff class
                    for equiv in stuff_equiv_classes.values():
                        if len(equiv) > 1:
                            for eq_id in equiv:
                                m_id.masked_fill_(m_id.eq(eq_id), equiv[0])

                final_h, final_w = to_tuple(target_size)

                # Encode ids as RGB, resize with nearest-neighbour so ids are never blended.
                seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy()))
                seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST)

                # Decode the resized image back to an id map to measure areas at the final size.
                np_seg_img = (
                    torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy()
                )
                m_id = torch.from_numpy(rgb2id(np_seg_img))

                area = []
                for i in range(len(scores)):
                    area.append(m_id.eq(i).sum().item())
                return area, seg_img

            area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
            if cur_classes.numel() > 0:
                # We now filter empty masks as long as we find some
                # (segments covering 4 pixels or fewer are dropped and the image is recomputed)
                while True:
                    filtered_small = torch.as_tensor(
                        [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
                    )
                    if filtered_small.any().item():
                        cur_scores = cur_scores[~filtered_small]
                        cur_classes = cur_classes[~filtered_small]
                        cur_masks = cur_masks[~filtered_small]
                        area, seg_img = get_ids_area(cur_masks, cur_scores)
                    else:
                        break

            else:
                # Nothing was kept: fall back to a single placeholder class id of 1.
                cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)

            segments_info = []
            for i, a in enumerate(area):
                cat = cur_classes[i].item()
                segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a})
            del cur_classes

            # Serialise the segment-id image to PNG bytes.
            with io.BytesIO() as out:
                seg_img.save(out, format="PNG")
                predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
            preds.append(predictions)
        return preds